Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/CIR/CIRToCIRPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ mlir::LogicalResult runCIRToCIRPasses(
llvm::StringRef lifetimeOpts, bool enableIdiomRecognizer,
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
llvm::StringRef libOptOpts, std::string &passOptParsingFailure,
bool enableCIRSimplify, bool flattenCIR, bool emitMLIR,
bool enableCIRSimplify, bool flattenCIR, bool throughMLIR,
bool enableCallConvLowering, bool enableMem2reg);

} // namespace cir
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mlir::LogicalResult runCIRToCIRPasses(
llvm::StringRef lifetimeOpts, bool enableIdiomRecognizer,
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
llvm::StringRef libOptOpts, std::string &passOptParsingFailure,
bool enableCIRSimplify, bool flattenCIR, bool emitCore,
bool enableCIRSimplify, bool flattenCIR, bool throughMLIR,
bool enableCallConvLowering, bool enableMem2Reg) {

llvm::TimeTraceScope scope("CIR To CIR Passes");
Expand Down Expand Up @@ -81,7 +81,7 @@ mlir::LogicalResult runCIRToCIRPasses(
if (enableMem2Reg)
pm.addPass(mlir::createMem2Reg());

if (emitCore)
if (throughMLIR)
pm.addPass(mlir::createSCFPreparePass());

// FIXME: once CIRCodenAction fixes emission other than CIR we
Expand Down
8 changes: 3 additions & 5 deletions clang/lib/CIR/FrontendAction/CIRGenAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,16 @@ class CIRGenConsumer : public clang::ASTConsumer {
action == CIRGenAction::OutputType::EmitMLIR &&
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CIR_FLAT;

bool emitCore = action == CIRGenAction::OutputType::EmitMLIR &&
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE;

// Setup and run CIR pipeline.
std::string passOptParsingFailure;
if (runCIRToCIRPasses(
mlirMod, mlirCtx.get(), C, !feOptions.ClangIRDisableCIRVerifier,
feOptions.ClangIRLifetimeCheck, lifetimeOpts,
feOptions.ClangIRIdiomRecognizer, idiomRecognizerOpts,
feOptions.ClangIRLibOpt, libOptOpts, passOptParsingFailure,
codeGenOptions.OptimizationLevel > 0, flattenCIR, emitCore,
enableCCLowering, feOptions.ClangIREnableMem2Reg)
codeGenOptions.OptimizationLevel > 0, flattenCIR,
!feOptions.ClangIRDirectLowering, enableCCLowering,
feOptions.ClangIREnableMem2Reg)
.failed()) {
if (!passOptParsingFailure.empty())
diagnosticsEngine.Report(diag::err_drv_cir_pass_opt_parsing)
Expand Down
33 changes: 26 additions & 7 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToMLIR.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Module.h"

using namespace cir;
using namespace llvm;
Expand Down Expand Up @@ -252,6 +255,14 @@ void SCFLoop::analysis() {
if (!canonical)
return;

// If the IV is defined before the forOp (i.e. outside the surrounding
// cir.scope) this is not a canonical loop as the IV would not have the
// correct value after the forOp
if (ivAddr.getDefiningOp()->getBlock() != forOp->getBlock()) {
canonical = false;
return;
}

cmpOp = findCmpOp();
if (!cmpOp) {
canonical = false;
Expand Down Expand Up @@ -303,16 +314,24 @@ void SCFLoop::transferToSCFForOp() {
"Not support lowering loop with break, continue or if yet");
// Replace the IV usage to scf loop induction variable.
if (isIVLoad(op, ivAddr)) {
// Replace CIR IV load with arith.addi scf.IV, 0.
// The replacement makes the SCF IV can be automatically propogated
// by OpAdaptor for individual IV user lowering.
// The redundant arith.addi can be removed by later MLIR passes.
rewriter->setInsertionPoint(op);
auto newIV = plusConstant(scfForOp.getInductionVar(), loc, 0);
rewriter->replaceOp(op, newIV.getDefiningOp());
// Replace CIR IV load with scf.IV
// (i.e. remove the load op and replace the uses of the result of the CIR
// IV load with the scf.IV)
rewriter->replaceOp(op, scfForOp.getInductionVar());
}
return mlir::WalkResult::advance();
});

// All uses have been replaced by the scf.IV and we can remove the alloca +
// initial store operations

// The operations before the loop have been transferred to MLIR.
// So we need to go through getRemappedValue to find the operations.
auto remapAddr = rewriter->getRemappedValue(ivAddr);

// Since this is a canonical loop we can remove the alloca + initial store op
rewriter->eraseOp(remapAddr.getDefiningOp());
rewriter->eraseOp(*remapAddr.user_begin());
}

void SCFLoop::transformToSCFWhileOp() {
Expand Down
54 changes: 40 additions & 14 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
Expand All @@ -36,6 +37,7 @@
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
Expand All @@ -51,6 +53,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/TimeProfiler.h"

using namespace cir;
Expand Down Expand Up @@ -204,18 +207,41 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
return true;
}

// For memref.reinterpret_cast has multiple users, erasing the operation
// after the last load or store been generated.
// If the memref.reinterpret_cast has multiple users (i.e the original
// cir.ptr_stride op has multiple users), only erase the operation after the
// last load or store has been generated.
static void eraseIfSafe(mlir::Value oldAddr, mlir::Value newAddr,
SmallVector<mlir::Operation *> &eraseList,
mlir::ConversionPatternRewriter &rewriter) {

unsigned oldUsedNum =
std::distance(oldAddr.getUses().begin(), oldAddr.getUses().end());
unsigned newUsedNum = 0;
// Count the uses of the newAddr (the result of the original base alloca) in
// load/store ops using a forwarded offset from the current
// memref.reinterpret_cast op
for (auto *user : newAddr.getUsers()) {
if (isa<mlir::memref::LoadOp>(*user) || isa<mlir::memref::StoreOp>(*user))
++newUsedNum;
if (auto loadOpUser = mlir::dyn_cast_or_null<mlir::memref::LoadOp>(*user)) {
if (!loadOpUser.getIndices().empty()) {
auto strideVal = loadOpUser.getIndices()[0];
if (strideVal ==
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
.getOffsets()[0])
++newUsedNum;
}
} else if (auto storeOpUser =
mlir::dyn_cast_or_null<mlir::memref::StoreOp>(*user)) {
if (!storeOpUser.getIndices().empty()) {
auto strideVal = storeOpUser.getIndices()[0];
if (strideVal ==
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
.getOffsets()[0])
++newUsedNum;
}
}
}
// If every use of the old address is matched by a load/store op using the
// forwarded offset from the current memref.reinterpret_cast op, erase the
// memref.reinterpret_cast ops
if (oldUsedNum == newUsedNum) {
for (auto op : eraseList)
rewriter.eraseOp(op);
Expand All @@ -235,13 +261,13 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
mlir::memref::LoadOp newLoad;
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
newLoad =
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), base, indices);
// rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
newLoad = rewriter.create<mlir::memref::LoadOp>(
op.getLoc(), base, indices, op.getIsNontemporal());
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
newLoad =
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), adaptor.getAddr());
newLoad = rewriter.create<mlir::memref::LoadOp>(
op.getLoc(), adaptor.getAddr(), mlir::ValueRange{},
op.getIsNontemporal());

// Convert adapted result to its original type if needed.
mlir::Value result = emitFromMemory(rewriter, op, newLoad.getResult());
Expand All @@ -265,12 +291,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
mlir::Value value = emitToMemory(rewriter, op, adaptor.getValue());
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value, base,
indices);
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(
op, value, base, indices, op.getIsNontemporal());
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value,
adaptor.getAddr());
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(
op, value, adaptor.getAddr(), mlir::ValueRange{},
op.getIsNontemporal());
return mlir::LogicalResult::success();
}
};
Expand Down Expand Up @@ -1451,7 +1478,6 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
if (!result)
report_fatal_error(
"The pass manager failed to lower CIR to MLIR standard dialects!");

// Now that we ran all the lowering passes, verify the final output.
if (theModule.verify().failed())
report_fatal_error(
Expand Down
29 changes: 29 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/array.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,32 @@ int test_array2() {
int a[3][4];
return a[1][2];
}

// Exercises a one-dimensional local array whose elements are both read and
// written: `a[0] += a[2]` loads from and stores to the same element address
// (the PTRSTRIDE2 load/store pair below), and `a[1]` is read for the return.
// The lowered output is expected to index the array memref directly.
int test_array3() {
// CIR-LABEL: cir.func {{.*}} @test_array3()
// CIR: %[[ARRAY:.*]] = cir.alloca !cir.array<!s32i x 3>, !cir.ptr<!cir.array<!s32i x 3>>, ["a"] {alignment = 4 : i64}
// CIR: %[[PTRDECAY1:.*]] = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
// CIR: %[[PTRSTRIDE1:.*]] = cir.ptr_stride(%[[PTRDECAY1]] : !cir.ptr<!s32i>, {{.*}} : !s32i), !cir.ptr<!s32i>
// CIR: {{.*}} = cir.load align(4) %[[PTRSTRIDE1]] : !cir.ptr<!s32i>, !s32i
// CIR: %[[PTRDECAY2:.*]] = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
// CIR: %[[PTRSTRIDE2:.*]] = cir.ptr_stride(%[[PTRDECAY2]] : !cir.ptr<!s32i>, {{.*}} : !s32i), !cir.ptr<!s32i>
// CIR: %{{.*}} = cir.load align(4) %[[PTRSTRIDE2]] : !cir.ptr<!s32i>, !s32i
// CIR: cir.store align(4) {{.*}}, %[[PTRSTRIDE2]] : !s32i, !cir.ptr<!s32i>
// CIR: %[[PTRDECAY3:.*]] = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
// CIR: %[[PTRSTRIDE3:.*]] = cir.ptr_stride(%[[PTRDECAY3]] : !cir.ptr<!s32i>, {{.*}} : !s32i), !cir.ptr<!s32i>
// CIR: %{{.*}} = cir.load align(4) %[[PTRSTRIDE3]] : !cir.ptr<!s32i>, !s32i

// MLIR-LABEL: func @test_array3
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 4 : i64} : memref<3xi32>
// MLIR: %[[IDX1:.*]] = arith.index_cast %{{.*}} : i32 to index
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX1]]] : memref<3xi32>
// MLIR: %[[IDX2:.*]] = arith.index_cast %{{.*}} : i32 to index
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX2]]] : memref<3xi32>
// MLIR: memref.store %{{.*}}, %[[ARRAY]][%[[IDX2]]] : memref<3xi32>
// MLIR: %[[IDX3:.*]] = arith.index_cast %{{.*}} : i32 to index
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX3]]] : memref<3xi32>
int a[3];
a[0] += a[2];
return a[1];
}
24 changes: 0 additions & 24 deletions clang/test/CIR/Lowering/ThroughMLIR/for-reject-1.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp

This file was deleted.

74 changes: 74 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/for-reject.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

// Empty callee; reject_test1 calls it from its loop step expression, and the
// lowered output is expected to contain a call to its mangled name.
void f() {}

// The loop step `i++, f()` does more than increment the induction variable,
// so the loop is expected to be lowered to scf.while with the induction
// variable kept in its alloca (verified by the directives below).
void reject_test1() {
for (int i = 0; i < 100; i++, f());
// CHECK: %[[ALLOCA:.+]] = memref.alloca
// CHECK: %[[ZERO:.+]] = arith.constant 0
// CHECK: memref.store %[[ZERO]], %[[ALLOCA]]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: scf.while : () -> () {
// CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]]
// CHECK: scf.condition(%[[TMP1]])
// CHECK: } do {
// CHECK: %[[TMP2:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]]
// CHECK: memref.store %[[TMP3]], %[[ALLOCA]]
// CHECK: func.call @_Z1fv()
// CHECK: scf.yield
// CHECK: }
}

// The loop step `i++, i++` updates the induction variable twice per
// iteration; the directives below expect an scf.while loop whose body holds
// two separate load/add/store sequences on the induction variable's alloca.
void reject_test2() {
for (int i = 0; i < 100; i++, i++);
// CHECK: %[[ALLOCA:.+]] = memref.alloca
// CHECK: %[[ZERO:.+]] = arith.constant 0
// CHECK: memref.store %[[ZERO]], %[[ALLOCA]]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: scf.while : () -> () {
// CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[TMP2:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]]
// CHECK: scf.condition(%[[TMP2]])
// CHECK: } do {
// CHECK: %[[TMP3:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[ADD:.+]] = arith.addi %[[TMP3]], %[[ONE]]
// CHECK: memref.store %[[ADD]], %[[ALLOCA]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[ONE2:.+]] = arith.constant 1
// CHECK: %[[ADD2:.+]] = arith.addi %[[LOAD]], %[[ONE2]]
// CHECK: memref.store %[[ADD2]], %[[ALLOCA]]
// CHECK: scf.yield
// CHECK: }
}

// The induction variable `i` is declared before the loop and is used again
// after it (`i += 10`), so its final value must stay observable. The
// directives below expect the alloca to sit outside the memref.alloca_scope,
// an scf.while loop inside the scope, and a load/add/store on the same alloca
// after the scope.
void reject_test3() {
int i;
for (i = 0; i < 100; i++);
i += 10;
// CHECK: %[[ALLOCA:.+]] = memref.alloca()
// CHECK: memref.alloca_scope {
// CHECK: %[[ZERO:.+]] = arith.constant 0
// CHECK: memref.store %[[ZERO]], %[[ALLOCA]]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: scf.while : () -> () {
// CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[TMP2:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]]
// CHECK: scf.condition(%[[TMP2]])
// CHECK: } do {
// CHECK: %[[TMP3:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[ADD:.+]] = arith.addi %[[TMP3]], %[[ONE]]
// CHECK: memref.store %[[ADD]], %[[ALLOCA]]
// CHECK: scf.yield
// CHECK: }
// CHECK: }
// CHECK: %[[TEN:.+]] = arith.constant 10
// CHECK: %[[TMP4:.+]] = memref.load %[[ALLOCA]]
// CHECK: %[[TMP5:.+]] = arith.addi %[[TMP4]], %[[TEN]]
// CHECK: memref.store %[[TMP5]], %[[ALLOCA]]
}
Loading
Loading