@@ -41,8 +41,7 @@ namespace {
4141
4242void propagateShapesInRegion (Region ®ion);
4343
44- void propagateShapesToTosaIf (
45- Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
44+ void propagateShapesToTosaIf (Operation &op) {
4645 IfOp ifOp = dyn_cast<IfOp>(op);
4746 if (!ifOp)
4847 return ;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
5352 return ;
5453
5554 for (unsigned int i = 1 , s = op.getNumOperands (); i < s; i++) {
56- auto inferredTy = shapesStorage[ op.getOperand (i)] ;
55+ auto inferredTy = cast<ShapedType>( op.getOperand (i). getType ()) ;
5756 auto blockArg = frontBlock.getArgument (i - 1 );
5857 auto oldType = cast<ShapedType>(blockArg.getType ());
5958
6059 if (inferredTy.hasRank ()) {
61- Type newType = oldType.clone (inferredTy.getDims ());
60+ Type newType = oldType.clone (inferredTy.getShape ());
6261 blockArg.setType (newType);
6362 }
6463 }
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
7978 }
8079}
8180
82- void propagateShapesToTosaWhile (
83- Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
81+ void propagateShapesToTosaWhile (Operation &op) {
8482 WhileOp whileOp = dyn_cast<WhileOp>(op);
8583 if (!whileOp)
8684 return ;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
9189 llvm::SmallVector<Type> argTypes;
9290 for (auto operand : op.getOperands ()) {
9391 auto operandTy = cast<ShapedType>(operand.getType ());
94- auto shapedTypeComponent = shapesStorage[operand];
95- if (shapedTypeComponent.hasRank ()) {
96- auto newTy = operandTy.clone (shapedTypeComponent.getDims ());
92+ if (operandTy.hasRank ()) {
93+ auto newTy = operandTy.clone (operandTy.getShape ());
9794 argTypes.push_back (newTy);
9895 } else {
9996 argTypes.push_back (operand.getType ());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
187184}
188185
189186void propagateShapesInRegion (Region ®ion) {
190- DenseMap<Value, ShapedTypeComponents> shapesStorage;
191- auto setShapes = [&](Value val, Type t) {
192- if (auto st = dyn_cast<ShapedType>(t))
193- shapesStorage[val] = st;
194- else
195- shapesStorage[val] = t;
196- };
197- auto operandShape = [&](Value val) -> ShapeAdaptor {
198- // Query the WIP mapping rather than the type if set.
199- auto it = shapesStorage.find (val);
200- if (it == shapesStorage.end ())
201- return nullptr ;
202- return it->second ;
203- };
204-
205187 // Check whether this use case is replaceable. We define an op as
206188 // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
207189 // type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region ®ion) {
217199 if (op.getDialect ()->getNamespace () != TosaDialect::getDialectNamespace ())
218200 continue ;
219201
220- propagateShapesToTosaIf (op, shapesStorage );
221- propagateShapesToTosaWhile (op, shapesStorage );
202+ propagateShapesToTosaIf (op);
203+ propagateShapesToTosaWhile (op);
222204
223205 InferShapedTypeOpInterface shapeInterface =
224206 dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region ®ion) {
227209
228210 SmallVector<ShapedTypeComponents> returnedShapes;
229211
230- ValueShapeRange range (op.getOperands (), operandShape);
231212 if (shapeInterface
232- .inferReturnTypeComponents (op. getContext (), op. getLoc (), range,
233- op.getDiscardableAttrDictionary (),
234- op.getPropertiesStorage (),
235- op.getRegions (), returnedShapes)
213+ .inferReturnTypeComponents (
214+ op. getContext (), op. getLoc (), op.getOperands (),
215+ op. getDiscardableAttrDictionary (), op.getPropertiesStorage (),
216+ op.getRegions (), returnedShapes)
236217 .succeeded ()) {
237218 for (auto it : llvm::zip (op.getResults (), returnedShapes)) {
238219 Value result = std::get<0 >(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region ®ion) {
262243 ValueKnowledge::join (currentKnowledge, inferredKnowledge);
263244 if (!newKnowledge)
264245 continue ;
265- setShapes (result, newKnowledge.getType ());
246+
247+ // Set new type
248+ result.setType (newKnowledge.getType ());
266249 }
267250 }
268251 }
269252 }
270-
271- // Actually update types with updated shape knowledge.
272- for (auto it : shapesStorage) {
273- auto result = it.second ;
274- if (result.hasRank ()) {
275- Type t = cast<ShapedType>(it.first .getType ()).clone (result.getDims ());
276- it.first .setType (t);
277- }
278- }
279253}
280254
281255// / Pass that performs shape propagation across TOSA operations. This includes
0 commit comments