diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index 9565b305b564..b011a8851681 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -8,11 +8,15 @@ #include "PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/Passes.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace cir; @@ -107,6 +111,85 @@ struct RemoveTrivialTry : public OpRewritePattern { } }; +/// Simplify suitable ternary operations into select operations. +/// +/// For now we only simplify those ternary operations whose true and false +/// branches directly yield a value or a constant. That is, both of the true and +/// the false branch must either contain a cir.yield operation as the only +/// operation in the branch, or contain a cir.const operation followed by a +/// cir.yield operation that yields the constant value. +/// +/// For example, we will simplify the following ternary operation: +/// +/// %0 = cir.ternary (%condition, true { +/// %1 = cir.const ... +/// cir.yield %1 +/// } false { +/// cir.yield %2 +/// }) +/// +/// into the following sequence of operations: +/// +/// %1 = cir.const ... +/// %0 = cir.select if %condition then %1 else %2 +struct SimplifyTernary final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TernaryOp op, + PatternRewriter &rewriter) const override { + if (op->getNumResults() != 1) + return mlir::failure(); + + if (!isSimpleTernaryBranch(op.getTrueRegion()) || + !isSimpleTernaryBranch(op.getFalseRegion())) + return mlir::failure(); + + mlir::cir::YieldOp trueBranchYieldOp = mlir::cast( + op.getTrueRegion().front().getTerminator()); + mlir::cir::YieldOp falseBranchYieldOp = mlir::cast( + op.getFalseRegion().front().getTerminator()); + auto trueValue = trueBranchYieldOp.getArgs()[0]; + auto falseValue = falseBranchYieldOp.getArgs()[0]; + + rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op); + rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op); + rewriter.eraseOp(trueBranchYieldOp); + rewriter.eraseOp(falseBranchYieldOp); + rewriter.replaceOpWithNewOp(op, op.getCond(), + trueValue, falseValue); + + return mlir::success(); + } + +private: + bool isSimpleTernaryBranch(mlir::Region ®ion) const { + if (!region.hasOneBlock()) + return false; + + mlir::Block &onlyBlock = region.front(); + auto &ops = onlyBlock.getOperations(); + + // The region/block could only contain at most 2 operations. + if (ops.size() > 2) + return false; + + if (ops.size() == 1) { + // The region/block only contain a cir.yield operation. + return true; + } + + // Check whether the region/block contains a cir.const followed by a + // cir.yield that yields the value. + auto yieldOp = mlir::cast(onlyBlock.getTerminator()); + auto yieldValueDefOp = mlir::dyn_cast_if_present( + yieldOp.getArgs()[0].getDefiningOp()); + if (!yieldValueDefOp || yieldValueDefOp->getBlock() != &onlyBlock) + return false; + + return true; + } +}; + struct SimplifySelect : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -171,6 +254,7 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) { RemoveEmptyScope, RemoveEmptySwitch, RemoveTrivialTry, + SimplifyTernary, SimplifySelect >(patterns.getContext()); // clang-format on @@ -186,8 +270,9 @@ void CIRSimplifyPass::runOnOperation() { getOperation()->walk([&](Operation *op) { // CastOp here is to perform a manual `fold` in // applyOpPatternsAndFold - if (isa(op)) + if (isa( + op)) ops.push_back(op); }); diff --git a/clang/test/CIR/CodeGen/binop.cpp b/clang/test/CIR/CodeGen/binop.cpp index 29f6e89282b0..045e78ccf021 100644 --- a/clang/test/CIR/CodeGen/binop.cpp +++ b/clang/test/CIR/CodeGen/binop.cpp @@ -32,13 +32,7 @@ void b1(bool a, bool b) { // CHECK: cir.ternary(%3, true // CHECK-NEXT: %7 = cir.load %1 -// CHECK-NEXT: cir.ternary(%7, true -// CHECK-NEXT: cir.const #true -// CHECK-NEXT: cir.yield -// CHECK-NEXT: false { -// CHECK-NEXT: cir.const #false -// CHECK-NEXT: cir.yield -// CHECK: cir.yield +// CHECK-NEXT: cir.yield %7 // CHECK-NEXT: false { // CHECK-NEXT: cir.const #false // CHECK-NEXT: cir.yield @@ -48,11 +42,6 @@ void b1(bool a, bool b) { // CHECK-NEXT: cir.yield // CHECK-NEXT: false { // CHECK-NEXT: %7 = cir.load %1 -// CHECK-NEXT: cir.ternary(%7, true -// CHECK-NEXT: cir.const #true -// CHECK-NEXT: cir.yield -// CHECK-NEXT: false { -// CHECK-NEXT: cir.const #false // CHECK-NEXT: cir.yield void b2(bool a) { @@ -90,16 +79,10 @@ void b3(int a, int b, int c, int d) { // CHECK-NEXT: %13 = cir.load %2 // CHECK-NEXT: %14 = cir.load %3 // CHECK-NEXT: %15 = cir.cmp(eq, %13, %14) -// CHECK-NEXT: cir.ternary(%15, true -// CHECK: %9 = cir.load %0 -// CHECK-NEXT: %10 = cir.load %1 -// CHECK-NEXT: %11 = cir.cmp(eq, %9, %10) -// CHECK-NEXT: %12 = cir.ternary(%11, true { -// CHECK: }, false { -// CHECK-NEXT: %13 = cir.load %2 -// CHECK-NEXT: %14 = cir.load %3 -// CHECK-NEXT: %15 = cir.cmp(eq, %13, %14) -// CHECK-NEXT: %16 = cir.ternary(%15, true +// CHECK-NEXT: cir.yield %15 +// CHECK-NEXT: }, false { +// CHECK-NEXT: %13 = cir.const #false +// CHECK-NEXT: cir.yield %13 void testFloatingPointBinOps(float a, float b) { a * b; diff --git a/clang/test/CIR/CodeGen/ternary.cpp b/clang/test/CIR/CodeGen/ternary.cpp index 6475add8e2b4..5c17ef5d1a74 100644 --- a/clang/test/CIR/CodeGen/ternary.cpp +++ b/clang/test/CIR/CodeGen/ternary.cpp @@ -12,16 +12,12 @@ int x(int y) { // CHECK: %2 = cir.load %0 : !cir.ptr, !s32i // CHECK: %3 = cir.const #cir.int<0> : !s32i // CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool -// CHECK: %5 = cir.ternary(%4, true { -// CHECK: %7 = cir.const #cir.int<3> : !s32i -// CHECK: cir.yield %7 : !s32i -// CHECK: }, false { -// CHECK: %7 = cir.const #cir.int<5> : !s32i -// CHECK: cir.yield %7 : !s32i -// CHECK: }) : (!cir.bool) -> !s32i -// CHECK: cir.store %5, %1 : !s32i, !cir.ptr -// CHECK: %6 = cir.load %1 : !cir.ptr, !s32i -// CHECK: cir.return %6 : !s32i +// CHECK: %5 = cir.const #cir.int<3> : !s32i +// CHECK: %6 = cir.const #cir.int<5> : !s32i +// CHECK: %7 = cir.select if %4 then %5 else %6 : (!cir.bool, !s32i, !s32i) -> !s32i +// CHECK: cir.store %7, %1 : !s32i, !cir.ptr +// CHECK: %8 = cir.load %1 : !cir.ptr, !s32i +// CHECK: cir.return %8 : !s32i // CHECK: } typedef enum { diff --git a/clang/test/CIR/Transforms/ternary-fold.cir b/clang/test/CIR/Transforms/ternary-fold.cir new file mode 100644 index 000000000000..6778d4744a32 --- /dev/null +++ b/clang/test/CIR/Transforms/ternary-fold.cir @@ -0,0 +1,60 @@ +// RUN: cir-opt -cir-simplify -o %t.cir %s +// RUN: FileCheck --input-file=%t.cir %s + +!s32i = !cir.int + +module { + cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.ternary (%0, true { + cir.yield %arg0 : !s32i + }, false { + cir.yield %arg1 : !s32i + }) : (!cir.bool) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: cir.return %[[ARG]] : !s32i + // CHECK-NEXT: } + + cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i { + %0 = cir.ternary (%arg0, true { + %1 = cir.const #cir.int<42> : !s32i + cir.yield %1 : !s32i + }, false { + cir.yield %arg1 : !s32i + }) : (!cir.bool) -> !s32i + cir.return %0 : !s32i + } + + // CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i + // CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i + // CHECK-NEXT: cir.return %[[#B]] : !s32i + // CHECK-NEXT: } + + cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr, ["a", init] + %1 = cir.ternary (%arg0, true { + %2 = cir.const #cir.int<42> : !s32i + cir.yield %2 : !s32i + }, false { + %3 = cir.load %0 : !cir.ptr, !s32i + cir.yield %3 : !s32i + }) : (!cir.bool) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i { + // CHECK-NEXT: %[[#A:]] = cir.alloca !s32i, !cir.ptr, ["a", init] + // CHECK-NEXT: %[[#B:]] = cir.ternary(%[[ARG0]], true { + // CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i + // CHECK-NEXT: cir.yield %[[#C]] : !s32i + // CHECK-NEXT: }, false { + // CHECK-NEXT: %[[#D:]] = cir.load %[[#A]] : !cir.ptr, !s32i + // CHECK-NEXT: cir.yield %[[#D]] : !s32i + // CHECK-NEXT: }) : (!cir.bool) -> !s32i + // CHECK-NEXT: cir.return %[[#B]] : !s32i + // CHECK-NEXT: } +} diff --git a/clang/test/CIR/Transforms/ternary-fold.cpp b/clang/test/CIR/Transforms/ternary-fold.cpp new file mode 100644 index 000000000000..5f37a8a36b95 --- /dev/null +++ b/clang/test/CIR/Transforms/ternary-fold.cpp @@ -0,0 +1,56 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-simplify %s -o %t1.cir 2>&1 | FileCheck -check-prefix=CIR-BEFORE %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-simplify %s -o %t2.cir 2>&1 | FileCheck -check-prefix=CIR-AFTER %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll +// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s + +int test(bool x) { + return x ? 1 : 2; +} + +// CIR-BEFORE: cir.func @_Z4testb +// CIR-BEFORE: %{{.+}} = cir.ternary(%{{.+}}, true { +// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i +// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i +// CIR-BEFORE-NEXT: }, false { +// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i +// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i +// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i +// CIR-BEFORE: } + +// CIR-AFTER: cir.func @_Z4testb +// CIR-AFTER: %[[#A:]] = cir.const #cir.int<1> : !s32i +// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i +// CIR-AFTER-NEXT: %{{.+}} = cir.select if %{{.+}} then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i +// CIR-AFTER: } + +// LLVM: define dso_local i32 @_Z4testb +// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2 +// LLVM: } + +int test2(bool cond) { + constexpr int x = 1; + constexpr int y = 2; + return cond ? x : y; +} + +// CIR-BEFORE: cir.func @_Z5test2b +// CIR-BEFORE: %[[#COND:]] = cir.load %{{.+}} : !cir.ptr, !cir.bool +// CIR-BEFORE-NEXT: %{{.+}} = cir.ternary(%[[#COND]], true { +// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i +// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i +// CIR-BEFORE-NEXT: }, false { +// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i +// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i +// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i +// CIR-BEFORE: } + +// CIR-AFTER: cir.func @_Z5test2b +// CIR-AFTER: %[[#COND:]] = cir.load %{{.+}} : !cir.ptr, !cir.bool +// CIR-AFTER-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i +// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i +// CIR-AFTER-NEXT: %{{.+}} = cir.select if %[[#COND]] then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i +// CIR-AFTER: } + +// LLVM: define dso_local i32 @_Z5test2b +// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2 +// LLVM: }