-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[flang][fir]: Add conversion of fir.iterate_while
to scf.while
.
#152439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Terapines MLIR (terapines-osc-mlir) ChangesThis commmit is a supplement for #140374. Full diff: https://github.com/llvm/llvm-project/pull/152439.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 1902757e83bf3..b779a21089549 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -88,6 +88,91 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
}
};
+struct IterWhileConversion : public OpRewritePattern<fir::IterWhileOp> {
+ using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(fir::IterWhileOp iterWhileOp,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = iterWhileOp.getLoc();
+ Value lowerBound = iterWhileOp.getLowerBound();
+ Value upperBound = iterWhileOp.getUpperBound();
+ Value step = iterWhileOp.getStep();
+
+ Value okInit = iterWhileOp.getIterateIn();
+ ValueRange iterArgs = iterWhileOp.getInitArgs();
+
+ SmallVector<Value> initVals;
+ initVals.push_back(lowerBound);
+ initVals.push_back(okInit);
+ initVals.append(iterArgs.begin(), iterArgs.end());
+
+ SmallVector<Type> loopTypes;
+ loopTypes.push_back(lowerBound.getType());
+ loopTypes.push_back(okInit.getType());
+ for (auto val : iterArgs)
+ loopTypes.push_back(val.getType());
+
+ auto scfWhileOp = scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
+ rewriter.createBlock(&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(),
+ loopTypes,
+ SmallVector<Location>(loopTypes.size(), loc));
+
+ rewriter.createBlock(&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(),
+ loopTypes,
+ SmallVector<Location>(loopTypes.size(), loc));
+
+ {
+ rewriter.setInsertionPointToStart(&scfWhileOp.getBefore().front());
+ auto args = scfWhileOp.getBefore().getArguments();
+ auto iv = args[0];
+ auto ok = args[1];
+
+ Value inductionCmp = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::sle, iv, upperBound);
+ Value cmp = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, ok);
+
+ mlir::scf::ConditionOp::create(rewriter, loc, cmp, args);
+ }
+
+ {
+ rewriter.setInsertionPointToStart(&scfWhileOp.getAfter().front());
+ auto args = scfWhileOp.getAfter().getArguments();
+ auto iv = args[0];
+
+ mlir::IRMapping mapping;
+ for (auto [oldArg, newVal] :
+ llvm::zip(iterWhileOp.getBody()->getArguments(), args))
+ mapping.map(oldArg, newVal);
+
+ for (auto &op : iterWhileOp.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ auto resultOp =
+ cast<fir::ResultOp>(iterWhileOp.getBody()->getTerminator());
+ auto results = resultOp.getResults();
+
+ SmallVector<Value> yieldedVals;
+
+ Value nextIv = mlir::arith::AddIOp::create(rewriter, loc, iv, step);
+ yieldedVals.push_back(nextIv);
+
+ for (auto val : results.drop_front()) {
+ if (mapping.contains(val)) {
+ yieldedVals.push_back(mapping.lookup(val));
+ } else {
+ yieldedVals.push_back(val);
+ }
+ }
+
+ mlir::scf::YieldOp::create(rewriter, loc, yieldedVals);
+ }
+
+ rewriter.replaceOp(iterWhileOp, scfWhileOp);
+ return success();
+ }
+};
+
void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
Block &dstBlock) {
Operation *srcTerminator = srcBlock.getTerminator();
@@ -130,9 +215,10 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
void FIRToSCFPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
+ patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
+ patterns.getContext());
ConversionTarget target(getContext());
- target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
+ target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir
new file mode 100644
index 0000000000000..a5de48f2ba848
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/iter-while.fir
@@ -0,0 +1,99 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// CHECK-LABEL: func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 11 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 22 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant true
+// CHECK: %[[VAL_4:.*]] = arith.constant 123 : i16
+// CHECK: %[[VAL_5:.*]] = arith.constant 456 : i32
+// CHECK: %[[VAL_6:.*]]:4 = scf.while (%[[VAL_7:.*]] = %[[VAL_0]], %[[VAL_8:.*]] = %[[VAL_3]], %[[VAL_9:.*]] = %[[VAL_4]], %[[VAL_10:.*]] = %[[VAL_5]]) : (index, i1, i16, i32) -> (index, i1, i16, i32) {
+// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_11]], %[[VAL_8]] : i1
+// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32):
+// CHECK: %[[VAL_17:.*]] = arith.constant true
+// CHECK: %[[VAL_18:.*]] = arith.constant 22 : i16
+// CHECK: %[[VAL_19:.*]] = arith.constant 33 : i32
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: scf.yield %[[VAL_20]], %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : index, i1, i16, i32
+// CHECK: }
+// CHECK: return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32
+// CHECK: }
+func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+ %lo = arith.constant 11 : index
+ %up = arith.constant 22 : index
+ %step = arith.constant 2 : index
+ %ok = arith.constant 1 : i1
+ %val1 = arith.constant 123 : i16
+ %val2 = arith.constant 456 : i32
+
+ %res:4 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%v1 = %val1, %v2 = %val2) -> (index, i1, i16, i32) {
+ %new_c = arith.constant 1 : i1
+ %new_v1 = arith.constant 22 : i16
+ %new_v2 = arith.constant 33 : i32
+ fir.result %i, %new_c, %new_v1, %new_v2 : index, i1, i16, i32
+ }
+
+ return %res#0, %res#1, %res#2, %res#3 : index, i1, i16, i32
+}
+
+// CHECK-LABEL: func.func @test_simple_iterate_while_2(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: i32) -> (index, i1, i32) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]]:3 = scf.while (%[[VAL_2:.*]] = %[[ARG0]], %[[VAL_3:.*]] = %[[ARG2]], %[[VAL_4:.*]] = %[[ARG3]]) : (index, i1, i32) -> (index, i1, i32) {
+// CHECK: %[[VAL_5:.*]] = arith.cmpi sle, %[[VAL_2]], %[[ARG1]] : index
+// CHECK: %[[VAL_6:.*]] = arith.andi %[[VAL_5]], %[[VAL_3]] : i1
+// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32):
+// CHECK: %[[VAL_10:.*]] = arith.constant 123 : i32
+// CHECK: %[[VAL_11:.*]] = arith.constant true
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
+// CHECK: scf.yield %[[VAL_12]], %[[VAL_11]], %[[VAL_10]] : index, i1, i32
+// CHECK: }
+// CHECK: return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32
+// CHECK: }
+func.func @test_simple_iterate_while_2(%start: index, %stop: index, %cond: i1, %val: i32) -> (index, i1, i32) {
+ %step = arith.constant 1 : index
+
+ %res:3 = fir.iterate_while (%i = %start to %stop step %step) and (%ok = %cond) iter_args(%x = %val) -> (index, i1, i32) {
+ %new_x = arith.constant 123 : i32
+ %new_ok = arith.constant 1 : i1
+ fir.result %i, %new_ok, %new_x : index, i1, i32
+ }
+
+ return %res#0, %res#1, %res#2 : index, i1, i32
+}
+
+// CHECK-LABEL: func.func @test_zero_iterations() -> (index, i1, i8) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant true
+// CHECK: %[[VAL_4:.*]] = arith.constant 42 : i8
+// CHECK: %[[VAL_5:.*]]:3 = scf.while (%[[VAL_6:.*]] = %[[VAL_0]], %[[VAL_7:.*]] = %[[VAL_3]], %[[VAL_8:.*]] = %[[VAL_4]]) : (index, i1, i8) -> (index, i1, i8) {
+// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_9]], %[[VAL_7]] : i1
+// CHECK: scf.condition(%[[VAL_10]]) %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : index, i1, i8
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i8):
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_2]] : index
+// CHECK: scf.yield %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index, i1, i8
+// CHECK: }
+// CHECK: return %[[VAL_15:.*]]#0, %[[VAL_15]]#1, %[[VAL_15]]#2 : index, i1, i8
+// CHECK: }
+func.func @test_zero_iterations() -> (index, i1, i8) {
+ %lo = arith.constant 10 : index
+ %up = arith.constant 5 : index
+ %step = arith.constant 1 : index
+ %ok = arith.constant 1 : i1
+ %x = arith.constant 42 : i8
+
+ %res:3 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%xv = %x) -> (index, i1, i8) {
+ fir.result %i, %c, %xv : index, i1, i8
+ }
+
+ return %res#0, %res#1, %res#2 : index, i1, i8
+}
|
c9355af
to
0e1f67b
Compare
|
||
auto &afterBlock = *rewriter.createBlock( | ||
&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(), loopTypes, | ||
SmallVector<Location>(loopTypes.size(), loc)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the original block meets the relevant requirements, I think we can use scfWhileOp.getAfter().takeBody
instead of creating a new block.
0e1f67b
to
a07ddad
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
@@ -10,6 +10,7 @@ | |||
#include "flang/Optimizer/Transforms/Passes.h" | |||
#include "mlir/Dialect/SCF/IR/SCF.h" | |||
#include "mlir/Transforms/DialectConversion.h" | |||
#include <mlir/Support/LLVM.h> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use #include "mlir/Support/LLVM.h"
and sort the headers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This include is introduced by IDE, got it removed now, thanks 🤝
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG.
&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes, | ||
SmallVector<Location>(loopTypes.size(), loc)); | ||
|
||
auto beforeArgs = scfWhileOp.getBefore().getArguments(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont know what type this is... is it a vector of values? range of block arguments? -- the auto
s below would be okay if this one was specified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The auto
is replaced by explicit type now.
|
||
auto beforeArgs = scfWhileOp.getBefore().getArguments(); | ||
auto beforeIv = beforeArgs[0]; | ||
auto beforeOk = beforeArgs[1]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is "beforeOk"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the early exit flag in fir.iterate_while
. (the example in document use the name "ok", so I just used it)
It is now renamed to earlyExitInBefore
. thanks 🤝
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Location loc = iterWhileOp.getLoc(); | ||
Value lowerBound = iterWhileOp.getLowerBound(); | ||
Value upperBound = iterWhileOp.getUpperBound(); | ||
Value step = iterWhileOp.getStep(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a good reason to not follow the same style than all files in flang/lib/Lower? We usually have expanded namespace. Not specific to this PR but it would be nice to follow that same style.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, there was not a good reason. The code has been modified to use expanded namespace. Thanks ! 🤝
ffcdedb
to
7c4ba34
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
7c4ba34
to
104e678
Compare
104e678
to
e9efb92
Compare
This commmit is a supplement for #140374.
RFC:https://discourse.llvm.org/t/rfc-add-fir-affine-optimization-fir-pass-pipeline/86190/6