diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index e3cba38871909..fce61f27ca3ea 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { } if (rhsTy == resultTy) { - if (isSplatZero(resultETy, lhsAttr)) + if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape()) + // constant values can only be resized if resulting type is static return lhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, lhsAttr, shift)) return rhs; } if (lhsTy == resultTy) { - if (isSplatZero(resultETy, rhsAttr)) + if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape()) return rhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, rhsAttr, shift)) return lhs; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5150ee36e9e5e..930bb9fe96811 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso // ----- +// CHECK-LABEL: @mul_zero_dynamic_nofold +// CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: return %[[MUL]] +func.func @mul_zero_dynamic_nofold(%arg0: tensor) -> tensor { + %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> + %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = tosa.mul %arg0, %0, %1 : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor + return %2 : tensor +} + +// ----- + +// CHECK-LABEL: @mul_one_dynamic_fold +// CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { +// CHECK: return %[[ARG0]] +func.func @mul_one_dynamic_fold(%arg0: tensor) -> tensor { + %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> + %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = tosa.mul %arg0, %0, %1 : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor + return %2 : tensor +} + +// ----- + // CHECK-LABEL: @select_same_value func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>