Skip to content

[flang][fir]: Add conversion of fir.iterate_while to scf.while. #152439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 14, 2025

Conversation

terapines-osc-mlir
Copy link
Contributor

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Aug 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Terapines MLIR (terapines-osc-mlir)

Changes

This commmit is a supplement for #140374.
RFC:https://discourse.llvm.org/t/rfc-add-fir-affine-optimization-fir-pass-pipeline/86190/6


Full diff: https://github.com/llvm/llvm-project/pull/152439.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/FIRToSCF.cpp (+88-2)
  • (added) flang/test/Fir/FirToSCF/iter-while.fir (+99)
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 1902757e83bf3..b779a21089549 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -88,6 +88,91 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
   }
 };
 
+struct IterWhileConversion : public OpRewritePattern<fir::IterWhileOp> {
+  using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(fir::IterWhileOp iterWhileOp,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = iterWhileOp.getLoc();
+    Value lowerBound = iterWhileOp.getLowerBound();
+    Value upperBound = iterWhileOp.getUpperBound();
+    Value step = iterWhileOp.getStep();
+
+    Value okInit = iterWhileOp.getIterateIn();
+    ValueRange iterArgs = iterWhileOp.getInitArgs();
+
+    SmallVector<Value> initVals;
+    initVals.push_back(lowerBound);
+    initVals.push_back(okInit);
+    initVals.append(iterArgs.begin(), iterArgs.end());
+
+    SmallVector<Type> loopTypes;
+    loopTypes.push_back(lowerBound.getType());
+    loopTypes.push_back(okInit.getType());
+    for (auto val : iterArgs)
+      loopTypes.push_back(val.getType());
+
+    auto scfWhileOp = scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
+    rewriter.createBlock(&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(),
+                         loopTypes,
+                         SmallVector<Location>(loopTypes.size(), loc));
+
+    rewriter.createBlock(&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(),
+                         loopTypes,
+                         SmallVector<Location>(loopTypes.size(), loc));
+
+    {
+      rewriter.setInsertionPointToStart(&scfWhileOp.getBefore().front());
+      auto args = scfWhileOp.getBefore().getArguments();
+      auto iv = args[0];
+      auto ok = args[1];
+
+      Value inductionCmp = mlir::arith::CmpIOp::create(
+          rewriter, loc, mlir::arith::CmpIPredicate::sle, iv, upperBound);
+      Value cmp = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, ok);
+
+      mlir::scf::ConditionOp::create(rewriter, loc, cmp, args);
+    }
+
+    {
+      rewriter.setInsertionPointToStart(&scfWhileOp.getAfter().front());
+      auto args = scfWhileOp.getAfter().getArguments();
+      auto iv = args[0];
+
+      mlir::IRMapping mapping;
+      for (auto [oldArg, newVal] :
+           llvm::zip(iterWhileOp.getBody()->getArguments(), args))
+        mapping.map(oldArg, newVal);
+
+      for (auto &op : iterWhileOp.getBody()->without_terminator())
+        rewriter.clone(op, mapping);
+
+      auto resultOp =
+          cast<fir::ResultOp>(iterWhileOp.getBody()->getTerminator());
+      auto results = resultOp.getResults();
+
+      SmallVector<Value> yieldedVals;
+
+      Value nextIv = mlir::arith::AddIOp::create(rewriter, loc, iv, step);
+      yieldedVals.push_back(nextIv);
+
+      for (auto val : results.drop_front()) {
+        if (mapping.contains(val)) {
+          yieldedVals.push_back(mapping.lookup(val));
+        } else {
+          yieldedVals.push_back(val);
+        }
+      }
+
+      mlir::scf::YieldOp::create(rewriter, loc, yieldedVals);
+    }
+
+    rewriter.replaceOp(iterWhileOp, scfWhileOp);
+    return success();
+  }
+};
+
 void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
                                  Block &dstBlock) {
   Operation *srcTerminator = srcBlock.getTerminator();
@@ -130,9 +215,10 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
 
 void FIRToSCFPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
+  patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
+      patterns.getContext());
   ConversionTarget target(getContext());
-  target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
+  target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir
new file mode 100644
index 0000000000000..a5de48f2ba848
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/iter-while.fir
@@ -0,0 +1,99 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 11 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 22 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant true
+// CHECK:           %[[VAL_4:.*]] = arith.constant 123 : i16
+// CHECK:           %[[VAL_5:.*]] = arith.constant 456 : i32
+// CHECK:           %[[VAL_6:.*]]:4 = scf.while (%[[VAL_7:.*]] = %[[VAL_0]], %[[VAL_8:.*]] = %[[VAL_3]], %[[VAL_9:.*]] = %[[VAL_4]], %[[VAL_10:.*]] = %[[VAL_5]]) : (index, i1, i16, i32) -> (index, i1, i16, i32) {
+// CHECK:             %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_12:.*]] = arith.andi %[[VAL_11]], %[[VAL_8]] : i1
+// CHECK:             scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32):
+// CHECK:             %[[VAL_17:.*]] = arith.constant true
+// CHECK:             %[[VAL_18:.*]] = arith.constant 22 : i16
+// CHECK:             %[[VAL_19:.*]] = arith.constant 33 : i32
+// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_20]], %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : index, i1, i16, i32
+// CHECK:           }
+// CHECK:           return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32
+// CHECK:         }
+func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+  %lo = arith.constant 11 : index
+  %up = arith.constant 22 : index
+  %step = arith.constant 2 : index
+  %ok = arith.constant 1 : i1
+  %val1 = arith.constant 123 : i16
+  %val2 = arith.constant 456 : i32
+
+  %res:4 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%v1 = %val1, %v2 = %val2) -> (index, i1, i16, i32) {
+    %new_c = arith.constant 1 : i1
+    %new_v1 = arith.constant 22 : i16
+    %new_v2 = arith.constant 33 : i32
+    fir.result %i, %new_c, %new_v1, %new_v2 : index, i1, i16, i32
+  }
+
+  return %res#0, %res#1, %res#2, %res#3 : index, i1, i16, i32
+}
+
+// CHECK-LABEL:   func.func @test_simple_iterate_while_2(
+// CHECK-SAME:        %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: i32) -> (index, i1, i32) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]]:3 = scf.while (%[[VAL_2:.*]] = %[[ARG0]], %[[VAL_3:.*]] = %[[ARG2]], %[[VAL_4:.*]] = %[[ARG3]]) : (index, i1, i32) -> (index, i1, i32) {
+// CHECK:             %[[VAL_5:.*]] = arith.cmpi sle, %[[VAL_2]], %[[ARG1]] : index
+// CHECK:             %[[VAL_6:.*]] = arith.andi %[[VAL_5]], %[[VAL_3]] : i1
+// CHECK:             scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32):
+// CHECK:             %[[VAL_10:.*]] = arith.constant 123 : i32
+// CHECK:             %[[VAL_11:.*]] = arith.constant true
+// CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
+// CHECK:             scf.yield %[[VAL_12]], %[[VAL_11]], %[[VAL_10]] : index, i1, i32
+// CHECK:           }
+// CHECK:           return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32
+// CHECK:         }
+func.func @test_simple_iterate_while_2(%start: index, %stop: index, %cond: i1, %val: i32) -> (index, i1, i32) {
+  %step = arith.constant 1 : index
+
+  %res:3 = fir.iterate_while (%i = %start to %stop step %step) and (%ok = %cond) iter_args(%x = %val) -> (index, i1, i32) {
+    %new_x = arith.constant 123 : i32
+    %new_ok = arith.constant 1 : i1
+    fir.result %i, %new_ok, %new_x : index, i1, i32
+  }
+
+  return %res#0, %res#1, %res#2 : index, i1, i32
+}
+
+// CHECK-LABEL:   func.func @test_zero_iterations() -> (index, i1, i8) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant true
+// CHECK:           %[[VAL_4:.*]] = arith.constant 42 : i8
+// CHECK:           %[[VAL_5:.*]]:3 = scf.while (%[[VAL_6:.*]] = %[[VAL_0]], %[[VAL_7:.*]] = %[[VAL_3]], %[[VAL_8:.*]] = %[[VAL_4]]) : (index, i1, i8) -> (index, i1, i8) {
+// CHECK:             %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_10:.*]] = arith.andi %[[VAL_9]], %[[VAL_7]] : i1
+// CHECK:             scf.condition(%[[VAL_10]]) %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : index, i1, i8
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i8):
+// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index, i1, i8
+// CHECK:           }
+// CHECK:           return %[[VAL_15:.*]]#0, %[[VAL_15]]#1, %[[VAL_15]]#2 : index, i1, i8
+// CHECK:         }
+func.func @test_zero_iterations() -> (index, i1, i8) {
+  %lo = arith.constant 10 : index
+  %up = arith.constant 5 : index
+  %step = arith.constant 1 : index
+  %ok = arith.constant 1 : i1
+  %x = arith.constant 42 : i8
+
+  %res:3 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%xv = %x) -> (index, i1, i8) {
+    fir.result %i, %c, %xv : index, i1, i8
+  }
+
+  return %res#0, %res#1, %res#2 : index, i1, i8
+}

@clementval clementval requested a review from rscottmanley August 7, 2025 05:43

auto &afterBlock = *rewriter.createBlock(
&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(), loopTypes,
SmallVector<Location>(loopTypes.size(), loc));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the original block meets the relevant requirements, I think we can use scfWhileOp.getAfter().takeBody instead of creating a new block.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@@ -10,6 +10,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include <mlir/Support/LLVM.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use #include "mlir/Support/LLVM.h" and sort the headers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This include is introduced by IDE, got it removed now, thanks 🤝

Copy link
Contributor

@c8ef c8ef left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG.

&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes,
SmallVector<Location>(loopTypes.size(), loc));

auto beforeArgs = scfWhileOp.getBefore().getArguments();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont know what type this is... is it a vector of values? range of block arguments? -- the autos below would be okay if this one was specified.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auto is replaced by explicit type now.


auto beforeArgs = scfWhileOp.getBefore().getArguments();
auto beforeIv = beforeArgs[0];
auto beforeOk = beforeArgs[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is "beforeOk"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the early exit flag in fir.iterate_while. (the example in document use the name "ok", so I just used it)
It is now renamed to earlyExitInBefore. thanks 🤝

Copy link
Contributor

@NexMing NexMing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines 97 to 100
Location loc = iterWhileOp.getLoc();
Value lowerBound = iterWhileOp.getLowerBound();
Value upperBound = iterWhileOp.getUpperBound();
Value step = iterWhileOp.getStep();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a good reason to not follow the same style than all files in flang/lib/Lower? We usually have expanded namespace. Not specific to this PR but it would be nice to follow that same style.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there was not a good reason. The code has been modified to use expanded namespace. Thanks ! 🤝

Copy link

github-actions bot commented Aug 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@NexMing NexMing merged commit c164e63 into llvm:main Aug 14, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants