Skip to content

Commit cc34ad9

Browse files
authored
[MLIR][OpenMP] Add cleanup region to omp.declare_reduction (#87377)
Currently, by-ref reductions will allocate the per-thread reduction variable in the initialization region. Adding a cleanup region allows that allocation to be undone. This will allow flang to support reduction of arrays stored on the heap. This conflation of allocation and initialization in the initialization should be fixed in the future to better match the OpenMP standard, but that is beyond the scope of this patch.
1 parent dbd6eb6 commit cc34ad9

File tree

7 files changed

+299
-24
lines changed

7 files changed

+299
-24
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,8 +2135,8 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [Symbol,
21352135
let summary = "declares a reduction kind";
21362136

21372137
let description = [{
2138-
Declares an OpenMP reduction kind. This requires two mandatory and one
2139-
optional region.
2138+
Declares an OpenMP reduction kind. This requires two mandatory and two
2139+
optional regions.
21402140

21412141
1. The initializer region specifies how to initialize the thread-local
21422142
reduction value. This is usually the neutral element of the reduction.
@@ -2149,6 +2149,10 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [Symbol,
21492149
3. The atomic reduction region is optional and specifies how two values
21502150
can be combined atomically given local accumulator variables. It is
21512151
expected to store the combined value in the first accumulator variable.
2152+
4. The cleanup region is optional and specifies how to clean up any memory
2153+
allocated by the initializer region. The region has an argument that
2154+
contains the value of the thread-local reduction accumulator. This will
2155+
be executed after the reduction has completed.
21522156

21532157
Note that the MLIR type system does not allow for type-polymorphic
21542158
reductions. Separate reduction declarations should be created for different
@@ -2163,12 +2167,14 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [Symbol,
21632167

21642168
let regions = (region AnyRegion:$initializerRegion,
21652169
AnyRegion:$reductionRegion,
2166-
AnyRegion:$atomicReductionRegion);
2170+
AnyRegion:$atomicReductionRegion,
2171+
AnyRegion:$cleanupRegion);
21672172

21682173
let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword "
21692174
"`init` $initializerRegion "
21702175
"`combiner` $reductionRegion "
2171-
"custom<AtomicReductionRegion>($atomicReductionRegion)";
2176+
"custom<AtomicReductionRegion>($atomicReductionRegion) "
2177+
"custom<CleanupReductionRegion>($cleanupRegion)";
21722178

21732179
let extraClassDeclaration = [{
21742180
PointerLikeType getAccumulatorType() {

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,21 @@ static void printAtomicReductionRegion(OpAsmPrinter &printer,
15381538
printer.printRegion(region);
15391539
}
15401540

1541+
static ParseResult parseCleanupReductionRegion(OpAsmParser &parser,
1542+
Region &region) {
1543+
if (parser.parseOptionalKeyword("cleanup"))
1544+
return success();
1545+
return parser.parseRegion(region);
1546+
}
1547+
1548+
static void printCleanupReductionRegion(OpAsmPrinter &printer,
1549+
DeclareReductionOp op, Region &region) {
1550+
if (region.empty())
1551+
return;
1552+
printer << "cleanup ";
1553+
printer.printRegion(region);
1554+
}
1555+
15411556
LogicalResult DeclareReductionOp::verifyRegions() {
15421557
if (getInitializerRegion().empty())
15431558
return emitOpError() << "expects non-empty initializer region";
@@ -1571,21 +1586,29 @@ LogicalResult DeclareReductionOp::verifyRegions() {
15711586
"of the reduction type";
15721587
}
15731588

1574-
if (getAtomicReductionRegion().empty())
1589+
if (!getAtomicReductionRegion().empty()) {
1590+
Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
1591+
if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1592+
atomicReductionEntryBlock.getArgumentTypes()[0] !=
1593+
atomicReductionEntryBlock.getArgumentTypes()[1])
1594+
return emitOpError() << "expects atomic reduction region with two "
1595+
"arguments of the same type";
1596+
auto ptrType = llvm::dyn_cast<PointerLikeType>(
1597+
atomicReductionEntryBlock.getArgumentTypes()[0]);
1598+
if (!ptrType ||
1599+
(ptrType.getElementType() && ptrType.getElementType() != getType()))
1600+
return emitOpError() << "expects atomic reduction region arguments to "
1601+
"be accumulators containing the reduction type";
1602+
}
1603+
1604+
if (getCleanupRegion().empty())
15751605
return success();
1606+
Block &cleanupEntryBlock = getCleanupRegion().front();
1607+
if (cleanupEntryBlock.getNumArguments() != 1 ||
1608+
cleanupEntryBlock.getArgument(0).getType() != getType())
1609+
return emitOpError() << "expects cleanup region with one argument "
1610+
"of the reduction type";
15761611

1577-
Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
1578-
if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1579-
atomicReductionEntryBlock.getArgumentTypes()[0] !=
1580-
atomicReductionEntryBlock.getArgumentTypes()[1])
1581-
return emitOpError() << "expects atomic reduction region with two "
1582-
"arguments of the same type";
1583-
auto ptrType = llvm::dyn_cast<PointerLikeType>(
1584-
atomicReductionEntryBlock.getArgumentTypes()[0]);
1585-
if (!ptrType ||
1586-
(ptrType.getElementType() && ptrType.getElementType() != getType()))
1587-
return emitOpError() << "expects atomic reduction region arguments to "
1588-
"be accumulators containing the reduction type";
15891612
return success();
15901613
}
15911614

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,32 @@ static void collectReductionInfo(
877877
}
878878
}
879879

880+
/// handling of DeclareReductionOp's cleanup region
881+
static LogicalResult inlineReductionCleanup(
882+
llvm::SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
883+
llvm::ArrayRef<llvm::Value *> privateReductionVariables,
884+
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) {
885+
for (auto [i, reductionDecl] : llvm::enumerate(reductionDecls)) {
886+
Region &cleanupRegion = reductionDecl.getCleanupRegion();
887+
if (cleanupRegion.empty())
888+
continue;
889+
890+
// map the argument to the cleanup region
891+
Block &entry = cleanupRegion.front();
892+
moduleTranslation.mapValue(entry.getArgument(0),
893+
privateReductionVariables[i]);
894+
895+
if (failed(inlineConvertOmpRegions(cleanupRegion, "omp.reduction.cleanup",
896+
builder, moduleTranslation)))
897+
return failure();
898+
899+
// clear block argument mapping in case it needs to be re-created with a
900+
// different source for another use of the same reduction decl
901+
moduleTranslation.forgetMapping(cleanupRegion);
902+
}
903+
return success();
904+
}
905+
880906
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
881907
static LogicalResult
882908
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1072,7 +1098,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
10721098
tempTerminator->eraseFromParent();
10731099
builder.restoreIP(nextInsertionPoint);
10741100

1075-
return success();
1101+
// after the workshare loop, deallocate private reduction variables
1102+
return inlineReductionCleanup(reductionDecls, privateReductionVariables,
1103+
moduleTranslation, builder);
10761104
}
10771105

10781106
/// A RAII class that on construction replaces the region arguments of the
@@ -1125,13 +1153,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11251153
LogicalResult bodyGenStatus = success();
11261154
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
11271155

1128-
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1129-
// Collect reduction declarations
1130-
SmallVector<omp::DeclareReductionOp> reductionDecls;
1131-
collectReductionDecls(opInst, reductionDecls);
1156+
// Collect reduction declarations
1157+
SmallVector<omp::DeclareReductionOp> reductionDecls;
1158+
collectReductionDecls(opInst, reductionDecls);
1159+
SmallVector<llvm::Value *> privateReductionVariables;
11321160

1161+
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
11331162
// Allocate reduction vars
1134-
SmallVector<llvm::Value *> privateReductionVariables;
11351163
DenseMap<Value, llvm::Value *> reductionVariableMap;
11361164
if (!isByRef) {
11371165
allocByValReductionVars(opInst, builder, moduleTranslation, allocaIP,
@@ -1331,7 +1359,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13311359

13321360
// TODO: Perform finalization actions for variables. This has to be
13331361
// called for variables which have destructors/finalizers.
1334-
auto finiCB = [&](InsertPointTy codeGenIP) {};
1362+
auto finiCB = [&](InsertPointTy codeGenIP) {
1363+
InsertPointTy oldIP = builder.saveIP();
1364+
builder.restoreIP(codeGenIP);
1365+
1366+
// if the reduction has a cleanup region, inline it here to finalize the
1367+
// reduction variables
1368+
if (failed(inlineReductionCleanup(reductionDecls, privateReductionVariables,
1369+
moduleTranslation, builder)))
1370+
bodyGenStatus = failure();
1371+
1372+
builder.restoreIP(oldIP);
1373+
};
13351374

13361375
llvm::Value *ifCond = nullptr;
13371376
if (auto ifExprVar = opInst.getIfExprVar())

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,25 @@ atomic {
436436

437437
// -----
438438

439+
// expected-error @below {{op expects cleanup region with one argument of the reduction type}}
440+
omp.declare_reduction @add_f32 : f32
441+
init {
442+
^bb0(%arg: f32):
443+
%0 = arith.constant 0.0 : f32
444+
omp.yield (%0 : f32)
445+
}
446+
combiner {
447+
^bb1(%arg0: f32, %arg1: f32):
448+
%1 = arith.addf %arg0, %arg1 : f32
449+
omp.yield (%1 : f32)
450+
}
451+
cleanup {
452+
^bb0(%arg: f64):
453+
omp.yield
454+
}
455+
456+
// -----
457+
439458
func.func @foo(%lb : index, %ub : index, %step : index) {
440459
%c1 = arith.constant 1 : i32
441460
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,8 @@ func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : i32)
603603
// CHECK: atomic
604604
// CHECK: ^{{.+}}(%{{.+}}: !llvm.ptr, %{{.+}}: !llvm.ptr):
605605
// CHECK: omp.yield
606+
// CHECK: cleanup
607+
// CHECK: omp.yield
606608
omp.declare_reduction @add_f32 : f32
607609
init {
608610
^bb0(%arg: f32):
@@ -620,6 +622,10 @@ atomic {
620622
llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
621623
omp.yield
622624
}
625+
cleanup {
626+
^bb0(%arg: f32):
627+
omp.yield
628+
}
623629

624630
// CHECK-LABEL: func @wsloop_reduction
625631
func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
@@ -789,6 +795,7 @@ combiner {
789795
omp.yield (%1 : f32)
790796
}
791797
// CHECK-NOT: atomic
798+
// CHECK-NOT: cleanup
792799

793800
// CHECK-LABEL: func @wsloop_reduction2
794801
func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
@@ -2088,6 +2095,7 @@ func.func @opaque_pointers_atomic_rwu(%v: !llvm.ptr, %x: !llvm.ptr) {
20882095
// CHECK-LABEL: @opaque_pointers_reduction
20892096
// CHECK: atomic {
20902097
// CHECK-NEXT: ^{{[[:alnum:]]+}}(%{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr):
2098+
// CHECK-NOT: cleanup
20912099
omp.declare_reduction @opaque_pointers_reduction : f32
20922100
init {
20932101
^bb0(%arg: f32):
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
// test a parallel reduction with a cleanup region
4+
5+
omp.declare_reduction @add_reduction_i_32 : !llvm.ptr init {
6+
^bb0(%arg0: !llvm.ptr):
7+
%0 = llvm.mlir.constant(0 : i32) : i32
8+
%c4 = llvm.mlir.constant(4 : i64) : i64
9+
%2 = llvm.call @malloc(%c4) : (i64) -> !llvm.ptr
10+
llvm.store %0, %2 : i32, !llvm.ptr
11+
omp.yield(%2 : !llvm.ptr)
12+
} combiner {
13+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
14+
%0 = llvm.load %arg0 : !llvm.ptr -> i32
15+
%1 = llvm.load %arg1 : !llvm.ptr -> i32
16+
%2 = llvm.add %0, %1 : i32
17+
llvm.store %2, %arg0 : i32, !llvm.ptr
18+
omp.yield(%arg0 : !llvm.ptr)
19+
} cleanup {
20+
^bb0(%arg0: !llvm.ptr):
21+
llvm.call @free(%arg0) : (!llvm.ptr) -> ()
22+
omp.yield
23+
}
24+
25+
// CHECK-LABEL: @main
26+
llvm.func @main() {
27+
%0 = llvm.mlir.constant(-1 : i32) : i32
28+
%1 = llvm.mlir.addressof @i : !llvm.ptr
29+
%2 = llvm.mlir.addressof @j : !llvm.ptr
30+
omp.parallel byref reduction(@add_reduction_i_32 %1 -> %arg0 : !llvm.ptr, @add_reduction_i_32 %2 -> %arg1 : !llvm.ptr) {
31+
llvm.store %0, %arg0 : i32, !llvm.ptr
32+
llvm.store %0, %arg1 : i32, !llvm.ptr
33+
omp.terminator
34+
}
35+
llvm.return
36+
}
37+
llvm.mlir.global internal @i() {addr_space = 0 : i32} : i32 {
38+
%0 = llvm.mlir.constant(0 : i32) : i32
39+
llvm.return %0 : i32
40+
}
41+
llvm.mlir.global internal @j() {addr_space = 0 : i32} : i32 {
42+
%0 = llvm.mlir.constant(0 : i32) : i32
43+
llvm.return %0 : i32
44+
}
45+
llvm.func @malloc(%arg0 : i64) -> !llvm.ptr
46+
llvm.func @free(%arg0 : !llvm.ptr) -> ()
47+
48+
// CHECK: %{{.+}} =
49+
// Call to the outlined function.
50+
// CHECK: call void {{.*}} @__kmpc_fork_call
51+
// CHECK-SAME: @[[OUTLINED:[A-Za-z_.][A-Za-z0-9_.]*]]
52+
53+
// Outlined function.
54+
// CHECK: define internal void @[[OUTLINED]]
55+
56+
// Private reduction variable and its initialization.
57+
// CHECK: %tid.addr.local = alloca i32
58+
// CHECK: %[[MALLOC_I:.+]] = call ptr @malloc(i64 4)
59+
// CHECK: %[[PRIV_PTR_I:.+]] = alloca ptr
60+
// CHECK: store ptr %[[MALLOC_I]], ptr %[[PRIV_PTR_I]]
61+
// CHECK: %[[MALLOC_J:.+]] = call ptr @malloc(i64 4)
62+
// CHECK: %[[PRIV_PTR_J:.+]] = alloca ptr
63+
// CHECK: store ptr %[[MALLOC_J]], ptr %[[PRIV_PTR_J]]
64+
65+
// Call to the reduction function.
66+
// CHECK: call i32 @__kmpc_reduce
67+
// CHECK-SAME: @[[REDFUNC:[A-Za-z_.][A-Za-z0-9_.]*]]
68+
69+
70+
// Non-atomic reduction:
71+
// CHECK: %[[PRIV_VAL_PTR_I:.+]] = load ptr, ptr %[[PRIV_PTR_I]]
72+
// CHECK: %[[LOAD_I:.+]] = load i32, ptr @i
73+
// CHECK: %[[PRIV_VAL_I:.+]] = load i32, ptr %[[PRIV_VAL_PTR_I]]
74+
// CHECK: %[[SUM_I:.+]] = add i32 %[[LOAD_I]], %[[PRIV_VAL_I]]
75+
// CHECK: store i32 %[[SUM_I]], ptr @i
76+
// CHECK: %[[PRIV_VAL_PTR_J:.+]] = load ptr, ptr %[[PRIV_PTR_J]]
77+
// CHECK: %[[LOAD_J:.+]] = load i32, ptr @j
78+
// CHECK: %[[PRIV_VAL_J:.+]] = load i32, ptr %[[PRIV_VAL_PTR_J]]
79+
// CHECK: %[[SUM_J:.+]] = add i32 %[[LOAD_J]], %[[PRIV_VAL_J]]
80+
// CHECK: store i32 %[[SUM_J]], ptr @j
81+
// CHECK: call void @__kmpc_end_reduce
82+
// CHECK: br label %[[FINALIZE:.+]]
83+
84+
// CHECK: [[FINALIZE]]:
85+
// CHECK: br label %[[OMP_FINALIZE:.+]]
86+
87+
// Cleanup region:
88+
// CHECK: [[OMP_FINALIZE]]:
89+
// CHECK: call void @free(ptr %[[PRIV_PTR_I]])
90+
// CHECK: call void @free(ptr %[[PRIV_PTR_J]])
91+
92+
// Reduction function.
93+
// CHECK: define internal void @[[REDFUNC]]
94+
// CHECK: add i32

0 commit comments

Comments
 (0)