Skip to content

[mlir][tosa] Don't fold mul with zero lhs/rhs if resulting type is dynamic #153420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 13, 2025

Conversation

sahas3
Copy link
Member

@sahas3 sahas3 commented Aug 13, 2025

Canonicalizing the following IR:

func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
  return %2 : tensor<?x17xf32>
}

resulted in a crash

#0 0x000056513187e8db backtrace (./build-release/bin/mlir-opt+0x9d698db)                                                                                                                                                                                                                                                                                                                   
 #1 0x0000565131b17737 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:838:8                                                                                                                                                                                                                
 #2 0x0000565131b187f3 PrintStackTraceSignalHandler(void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:918:1                                                                                                                                                                                                                                
 #3 0x0000565131b18c30 llvm::sys::RunSignalHandlers() /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Signals.cpp:105:18                                                                                                                                                                                                                                         
 #4 0x0000565131b18c30 SignalHandler(int, siginfo_t*, void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:409:3                                                                                                                                                                                                                              
 #5 0x00007f2e4165b050 (/lib/x86_64-linux-gnu/libc.so.6+0x3c050)                                                                                                                                                                                                                                                                                                                            
 #6 0x00007f2e416a9eec __pthread_kill_implementation ./nptl/pthread_kill.c:44:76                                                                                                                                                                                                                                                                                                            
 #7 0x00007f2e4165afb2 raise ./signal/../sysdeps/posix/raise.c:27:6                                                                                                                                                                                                                                                                                                                         
 #8 0x00007f2e41645472 abort ./stdlib/abort.c:81:7                                                                                                                                                                                                                                                                                                                                          
 #9 0x00007f2e41645395 _nl_load_domain ./intl/loadmsgcat.c:1177:9                                                                                                                                                                                                                                                                                                                           
#10 0x00007f2e41653ec2 (/lib/x86_64-linux-gnu/libc.so.6+0x34ec2)                                                                                                                                                                                                                                                                                                                            
#11 0x00005651443ec4ba mlir::DenseIntOrFPElementsAttr::getRaw(mlir::ShapedType, llvm::ArrayRef<char>) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:1361:3                                                                                                                                                                                    
#12 0x00005651443f1209 mlir::DenseElementsAttr::resizeSplat(mlir::ShapedType) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:0:10                                                                                                                                                                                                              
#13 0x000056513f76f2b6 mlir::tosa::MulOp::fold(mlir::tosa::MulOpGenericAdaptor<llvm::ArrayRef<mlir::Attribute>>) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp:0:0

from the folder for tosa::mul since the zero value was being reshaped to ?x17 size which isn't supported. AFAIK, tosa.const requires all dimensions to be static. So in this case, the fix is to not to fold the op.

@sahas3 sahas3 marked this pull request as ready for review August 13, 2025 14:56
@sahas3 sahas3 requested a review from sjarus August 13, 2025 14:56
@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-mlir

Author: Sayan Saha (sahas3)

Changes

Canonicalizing the following IR:

func.func @<!-- -->mul_zero_dynamic_nofold(%arg0: tensor&lt;?x17xf32&gt;) -&gt; tensor&lt;?x17xf32&gt; {
  %0 = "tosa.const"() &lt;{values = dense&lt;0.000000e+00&gt; : tensor&lt;1x1xf32&gt;}&gt; : () -&gt; tensor&lt;1x1xf32&gt;
  %1 = "tosa.const"() &lt;{values = dense&lt;0&gt; : tensor&lt;1xi8&gt;}&gt; : () -&gt; tensor&lt;1xi8&gt;
  %2 = tosa.mul %arg0, %0, %1 : (tensor&lt;?x17xf32&gt;, tensor&lt;1x1xf32&gt;, tensor&lt;1xi8&gt;) -&gt; tensor&lt;?x17xf32&gt;
  return %2 : tensor&lt;?x17xf32&gt;
}

resulted in a crash

#<!-- -->0 0x000056513187e8db backtrace (./build-release/bin/mlir-opt+0x9d698db)                                                                                                                                                                                                                                                                                                                   
 #<!-- -->1 0x0000565131b17737 llvm::sys::PrintStackTrace(llvm::raw_ostream&amp;, int) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:838:8                                                                                                                                                                                                                
 #<!-- -->2 0x0000565131b187f3 PrintStackTraceSignalHandler(void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:918:1                                                                                                                                                                                                                                
 #<!-- -->3 0x0000565131b18c30 llvm::sys::RunSignalHandlers() /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Signals.cpp:105:18                                                                                                                                                                                                                                         
 #<!-- -->4 0x0000565131b18c30 SignalHandler(int, siginfo_t*, void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:409:3                                                                                                                                                                                                                              
 #<!-- -->5 0x00007f2e4165b050 (/lib/x86_64-linux-gnu/libc.so.6+0x3c050)                                                                                                                                                                                                                                                                                                                            
 #<!-- -->6 0x00007f2e416a9eec __pthread_kill_implementation ./nptl/pthread_kill.c:44:76                                                                                                                                                                                                                                                                                                            
 #<!-- -->7 0x00007f2e4165afb2 raise ./signal/../sysdeps/posix/raise.c:27:6                                                                                                                                                                                                                                                                                                                         
 #<!-- -->8 0x00007f2e41645472 abort ./stdlib/abort.c:81:7                                                                                                                                                                                                                                                                                                                                          
 #<!-- -->9 0x00007f2e41645395 _nl_load_domain ./intl/loadmsgcat.c:1177:9                                                                                                                                                                                                                                                                                                                           
#<!-- -->10 0x00007f2e41653ec2 (/lib/x86_64-linux-gnu/libc.so.6+0x34ec2)                                                                                                                                                                                                                                                                                                                            
#<!-- -->11 0x00005651443ec4ba mlir::DenseIntOrFPElementsAttr::getRaw(mlir::ShapedType, llvm::ArrayRef&lt;char&gt;) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:1361:3                                                                                                                                                                                    
#<!-- -->12 0x00005651443f1209 mlir::DenseElementsAttr::resizeSplat(mlir::ShapedType) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:0:10                                                                                                                                                                                                              
#<!-- -->13 0x000056513f76f2b6 mlir::tosa::MulOp::fold(mlir::tosa::MulOpGenericAdaptor&lt;llvm::ArrayRef&lt;mlir::Attribute&gt;&gt;) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp:0:0

from the folder for tosa::mul since the zero value was being reshaped to ?x17 size which isn't supported. AFAIK, tosa.const requires all dimensions to be static. So in this case, the fix is to not to fold the op.


Full diff: https://github.com/llvm/llvm-project/pull/153420.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+3-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+27)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..fce61f27ca3ea 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   }
 
   if (rhsTy == resultTy) {
-    if (isSplatZero(resultETy, lhsAttr))
+    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+      // constant values can only be resized if resulting type is static
       return lhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, lhsAttr, shift))
       return rhs;
   }
   if (lhsTy == resultTy) {
-    if (isSplatZero(resultETy, rhsAttr))
+    if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
       return rhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, rhsAttr, shift))
       return lhs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5150ee36e9e5e..930bb9fe96811 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
 
 // -----
 
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK:           %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK:           %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+// CHECK:           return %[[MUL]]
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+  return %2 : tensor<?x17xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_dynamic_fold
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK:           return %[[ARG0]]
+func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+  return %2 : tensor<?x17xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Sayan Saha (sahas3)

Changes

Canonicalizing the following IR:

func.func @<!-- -->mul_zero_dynamic_nofold(%arg0: tensor&lt;?x17xf32&gt;) -&gt; tensor&lt;?x17xf32&gt; {
  %0 = "tosa.const"() &lt;{values = dense&lt;0.000000e+00&gt; : tensor&lt;1x1xf32&gt;}&gt; : () -&gt; tensor&lt;1x1xf32&gt;
  %1 = "tosa.const"() &lt;{values = dense&lt;0&gt; : tensor&lt;1xi8&gt;}&gt; : () -&gt; tensor&lt;1xi8&gt;
  %2 = tosa.mul %arg0, %0, %1 : (tensor&lt;?x17xf32&gt;, tensor&lt;1x1xf32&gt;, tensor&lt;1xi8&gt;) -&gt; tensor&lt;?x17xf32&gt;
  return %2 : tensor&lt;?x17xf32&gt;
}

resulted in a crash

#<!-- -->0 0x000056513187e8db backtrace (./build-release/bin/mlir-opt+0x9d698db)                                                                                                                                                                                                                                                                                                                   
 #<!-- -->1 0x0000565131b17737 llvm::sys::PrintStackTrace(llvm::raw_ostream&amp;, int) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:838:8                                                                                                                                                                                                                
 #<!-- -->2 0x0000565131b187f3 PrintStackTraceSignalHandler(void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:918:1                                                                                                                                                                                                                                
 #<!-- -->3 0x0000565131b18c30 llvm::sys::RunSignalHandlers() /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Signals.cpp:105:18                                                                                                                                                                                                                                         
 #<!-- -->4 0x0000565131b18c30 SignalHandler(int, siginfo_t*, void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:409:3                                                                                                                                                                                                                              
 #<!-- -->5 0x00007f2e4165b050 (/lib/x86_64-linux-gnu/libc.so.6+0x3c050)                                                                                                                                                                                                                                                                                                                            
 #<!-- -->6 0x00007f2e416a9eec __pthread_kill_implementation ./nptl/pthread_kill.c:44:76                                                                                                                                                                                                                                                                                                            
 #<!-- -->7 0x00007f2e4165afb2 raise ./signal/../sysdeps/posix/raise.c:27:6                                                                                                                                                                                                                                                                                                                         
 #<!-- -->8 0x00007f2e41645472 abort ./stdlib/abort.c:81:7                                                                                                                                                                                                                                                                                                                                          
 #<!-- -->9 0x00007f2e41645395 _nl_load_domain ./intl/loadmsgcat.c:1177:9                                                                                                                                                                                                                                                                                                                           
#<!-- -->10 0x00007f2e41653ec2 (/lib/x86_64-linux-gnu/libc.so.6+0x34ec2)                                                                                                                                                                                                                                                                                                                            
#<!-- -->11 0x00005651443ec4ba mlir::DenseIntOrFPElementsAttr::getRaw(mlir::ShapedType, llvm::ArrayRef&lt;char&gt;) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:1361:3                                                                                                                                                                                    
#<!-- -->12 0x00005651443f1209 mlir::DenseElementsAttr::resizeSplat(mlir::ShapedType) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:0:10                                                                                                                                                                                                              
#<!-- -->13 0x000056513f76f2b6 mlir::tosa::MulOp::fold(mlir::tosa::MulOpGenericAdaptor&lt;llvm::ArrayRef&lt;mlir::Attribute&gt;&gt;) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp:0:0

from the folder for tosa::mul since the zero value was being reshaped to ?x17 size which isn't supported. AFAIK, tosa.const requires all dimensions to be static. So in this case, the fix is to not to fold the op.


Full diff: https://github.com/llvm/llvm-project/pull/153420.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+3-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+27)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..fce61f27ca3ea 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   }
 
   if (rhsTy == resultTy) {
-    if (isSplatZero(resultETy, lhsAttr))
+    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+      // constant values can only be resized if resulting type is static
       return lhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, lhsAttr, shift))
       return rhs;
   }
   if (lhsTy == resultTy) {
-    if (isSplatZero(resultETy, rhsAttr))
+    if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
       return rhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, rhsAttr, shift))
       return lhs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5150ee36e9e5e..930bb9fe96811 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
 
 // -----
 
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK:           %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK:           %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+// CHECK:           return %[[MUL]]
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+  return %2 : tensor<?x17xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_dynamic_fold
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK:           return %[[ARG0]]
+func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+  return %2 : tensor<?x17xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @sahas3!

@lhutton1 lhutton1 changed the title [tosa] : Don't fold mul with zero lhs/rhs if resulting type is dynamic. [mlir][tosa] Don't fold mul with zero lhs/rhs if resulting type is dynamic Aug 13, 2025
Copy link
Contributor

@sjarus sjarus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dynamic shape handling is nice, thanks!

@sahas3 sahas3 merged commit 8432f24 into llvm:main Aug 13, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants