Commit e9cb582
Author: maxbartel
[mlir][TOSA] Fix shape inference when operand was inferred (#66906)
Commit 057fc8e introduced a bug in the `TosaInferShapesPass` when an operand type had already been inferred: the wrapper at https://github.com/llvm/llvm-project/blob/f7bfa583b7a5ff0e9954d2810006b7a71123be88/mlir/include/mlir/Interfaces/InferTypeOpInterface.td#L248 interprets the `ValueShapeRange` as a plain `ValueRange` and loses the information gathered by the inference. This PR changes the shape-inference logic: instead of collecting the inferred types in a `DenseMap` and updating the types only after the whole analysis of a region, the pass now updates the types directly in each iteration. That way the operands always carry the inferred type.
Parent: 2d27bf2
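To make the mechanism concrete, here is a minimal, self-contained C++ sketch of the two update strategies; `Value`, `queryTypeOldScheme`, and `updateInPlace` are illustrative stand-ins, not MLIR API:

#include <cassert>
#include <map>
#include <string>

// Illustrative stand-in for an SSA value carrying a type (not MLIR API).
struct Value {
  std::string type;
};

// Old scheme: inferred types were parked in a side map and written back only
// after the whole region had been analyzed. Any consumer that reads the type
// directly off the value mid-analysis (as the Op-Adaptor-based wrapper does)
// sees the stale, un-inferred type.
std::string queryTypeOldScheme(const Value &v,
                               const std::map<const Value *, std::string> &) {
  return v.type; // bypasses the side map; stale until the deferred write-back
}

// New scheme: the inferred type is written back in the same iteration, so
// every later query on the value observes the refined type.
void updateInPlace(Value &v, const std::string &inferredType) {
  v.type = inferredType;
}

int main() {
  Value conv{"tensor<?x32x32x16xf32>"};
  std::map<const Value *, std::string> sideMap{
      {&conv, "tensor<1x32x32x16xf32>"}};

  // Old scheme: the refined shape exists in the map, but a direct type
  // query misses it.
  assert(queryTypeOldScheme(conv, sideMap) == "tensor<?x32x32x16xf32>");

  // New scheme: update immediately; later queries see the refined type.
  updateInPlace(conv, sideMap.at(&conv));
  assert(conv.type == "tensor<1x32x32x16xf32>");
}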

File tree: 3 files changed, +29 −42 lines

mlir/include/mlir/Interfaces/InferTypeOpInterface.td
Lines changed: 1 addition & 1 deletion

@@ -223,7 +223,7 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
 >;
 
 // Convenient trait to define a wrapper to inferReturnTypeComponents that passes
-// in the Op Adaptor directly
+// in the Op Adaptor directly. Only uses the current types of the operands.
 class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
   [
     // Op implements infer type op interface.
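
For context on the wrapper the commit message references: a paraphrased C++ sketch of what the trait's generated `inferReturnTypeComponents` does (assumption: simplified from the TableGen at the link above; `ConcreteOp` stands in for the templated op class):

// Paraphrased sketch, not verbatim from InferTypeOpInterface.td.
LogicalResult inferReturnTypeComponents(
    MLIRContext *context, std::optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // The Adaptor is constructed from the operands as a plain ValueRange, so
  // any shape information carried by the ValueShapeRange (but not yet
  // reflected in the operand types) is dropped here. With the in-place type
  // updates from this PR, the operand types are already refined, so nothing
  // is lost.
  ConcreteOp::Adaptor adaptor(operands, attributes, properties, regions);
  return ConcreteOp::inferReturnTypeComponents(context, location, adaptor,
                                               inferredReturnShapes);
}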

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
Lines changed: 15 additions & 41 deletions

@@ -41,8 +41,7 @@ namespace {
 
 void propagateShapesInRegion(Region &region);
 
-void propagateShapesToTosaIf(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaIf(Operation &op) {
   IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
       return;
 
     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = shapesStorage[op.getOperand(i)];
+      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
       auto blockArg = frontBlock.getArgument(i - 1);
       auto oldType = cast<ShapedType>(blockArg.getType());
 
       if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getDims());
+        Type newType = oldType.clone(inferredTy.getShape());
         blockArg.setType(newType);
       }
     }
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
   }
 }
 
-void propagateShapesToTosaWhile(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaWhile(Operation &op) {
   WhileOp whileOp = dyn_cast<WhileOp>(op);
   if (!whileOp)
     return;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
   llvm::SmallVector<Type> argTypes;
   for (auto operand : op.getOperands()) {
     auto operandTy = cast<ShapedType>(operand.getType());
-    auto shapedTypeComponent = shapesStorage[operand];
-    if (shapedTypeComponent.hasRank()) {
-      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+    if (operandTy.hasRank()) {
+      auto newTy = operandTy.clone(operandTy.getShape());
       argTypes.push_back(newTy);
     } else {
       argTypes.push_back(operand.getType());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
 }
 
 void propagateShapesInRegion(Region &region) {
-  DenseMap<Value, ShapedTypeComponents> shapesStorage;
-  auto setShapes = [&](Value val, Type t) {
-    if (auto st = dyn_cast<ShapedType>(t))
-      shapesStorage[val] = st;
-    else
-      shapesStorage[val] = t;
-  };
-  auto operandShape = [&](Value val) -> ShapeAdaptor {
-    // Query the WIP mapping rather than the type if set.
-    auto it = shapesStorage.find(val);
-    if (it == shapesStorage.end())
-      return nullptr;
-    return it->second;
-  };
-
   // Check whether this use case is replaceable. We define an op as
   // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
   // type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region &region) {
      if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
        continue;
 
-      propagateShapesToTosaIf(op, shapesStorage);
-      propagateShapesToTosaWhile(op, shapesStorage);
+      propagateShapesToTosaIf(op);
+      propagateShapesToTosaWhile(op);
 
      InferShapedTypeOpInterface shapeInterface =
          dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region &region) {
 
      SmallVector<ShapedTypeComponents> returnedShapes;
 
-      ValueShapeRange range(op.getOperands(), operandShape);
      if (shapeInterface
-              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
-                                         op.getDiscardableAttrDictionary(),
-                                         op.getPropertiesStorage(),
-                                         op.getRegions(), returnedShapes)
+              .inferReturnTypeComponents(
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
+                  op.getRegions(), returnedShapes)
              .succeeded()) {
        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
          Value result = std::get<0>(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region &region) {
              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
          if (!newKnowledge)
            continue;
-          setShapes(result, newKnowledge.getType());
+
+          // Set new type
+          result.setType(newKnowledge.getType());
        }
      }
    }
  }
-
-  // Actually update types with updated shape knowledge.
-  for (auto it : shapesStorage) {
-    auto result = it.second;
-    if (result.hasRank()) {
-      Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
-      it.first.setType(t);
-    }
-  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Lines changed: 13 additions & 0 deletions

@@ -1259,3 +1259,16 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
   %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: test_tosa_use_def_chain
+func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
+  // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
+  // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+  // CHECK: tosa.max_pool2d [[CONV]]
+  // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
+  %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
+  return %1 : tensor<?x16x16x16xf32>
+}
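
A note on the new test: `tosa.conv2d` is written with a deliberately unrefined result type (`tensor<?x32x32x16xf32>`). Shape inference refines it to `tensor<1x32x32x16xf32>`, and the CHECK lines on `tosa.max_pool2d` verify that the next inference step consumes the refined type rather than the stale one, which is exactly the use-def propagation the bug had broken.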
