@@ -363,11 +363,13 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
363363
364364// CHECK-LABEL: @conv2d_i8 
365365func.func  @conv2d_i8 (%input:  tensor <1 x49 x42 x27 xi8 >, %weights:  tensor <28 x1 x1 x27 xi8 >, %bias:  tensor <28 xi8 >) -> () {
366+   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
367+   // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] 
366368  // CHECK: %[[M_IN:.+]] = tensor.empty() 
367369  // CHECK: %[[CST:.+]] = arith.constant 0 
368370  // CHECK: %[[FILL:.+]] = linalg.fill 
369371  // CHECK: %[[B_IN:.+]] = tensor.empty() 
370-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> 
372+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> 
371373  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>) 
372374  // CHECK:   arith.extsi 
373375  // CHECK:   arith.addi 
@@ -383,11 +385,13 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
383385
384386// CHECK-LABEL: @conv2d_f32 
385387func.func  @conv2d_f32 (%input:  tensor <1 x49 x42 x27 xf32 >, %weights:  tensor <28 x3 x3 x27 xf32 >, %bias:  tensor <28 xf32 >) -> () {
388+   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
389+   // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] 
386390  // CHECK: %[[M_IN:.+]] = tensor.empty() 
387391  // CHECK: %[[CST:.+]] = arith.constant 0 
388392  // CHECK: %[[FILL:.+]] = linalg.fill 
389393  // CHECK: %[[B_IN:.+]] = tensor.empty() 
390-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1  : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>) 
394+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]]  : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>) 
391395  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>) 
392396  // CHECK:   arith.addf 
393397  // CHECK:   linalg.yield 
@@ -404,11 +408,13 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
404408func.func  @conv2d_dyn (%input:  tensor <?x49 x42 x27 xf32 >, %weights:  tensor <28 x3 x3 x27 xf32 >, %bias:  tensor <28 xf32 >) -> () {
405409  // CHECK: %[[C0:.+]] = arith.constant 0 
406410  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] 
411+   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
412+   // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] 
407413  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]]) 
408414  // CHECK: %[[CST:.+]] = arith.constant 0 
409415  // CHECK: %[[FILL:.+]] = linalg.fill 
410416  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]]) 
411-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1  : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>) 
417+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]]  : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>) 
412418  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>) 
413419  // CHECK:   %[[ADD:.+]] = arith.addf 
414420  // CHECK:   linalg.yield %[[ADD]] : f32 
@@ -462,11 +468,13 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
462468  // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index 
463469
464470  // Running convolution 
471+   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
472+   // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]] 
465473  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) 
466474  // CHECK: %[[CST:.+]] = arith.constant 0 
467475  // CHECK: %[[FILL:.+]] = linalg.fill 
468476  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) 
469-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1  : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>) 
477+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]]  : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>) 
470478  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>) 
471479  // CHECK:   %[[ADD:.+]] = arith.addf 
472480  // CHECK:   linalg.yield %[[ADD]] : f32 
@@ -481,7 +489,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
481489  // CHECK: %[[C0:.+]] = arith.constant 0 
482490  // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
483491  // CHECK:   tensor.yield %[[C0]] 
484-   // CHECK: linalg.conv_2d_nhwc_fhwc  
492+   // CHECK: linalg.conv_2d_nhwc_hwcf  
485493  %0  = tosa.conv2d  %input , %weights , %bias  {pad  = array<i64 : 1 , 1 , 1 , 1 >, stride  = array<i64 : 1 , 1 >, dilation  = array<i64 : 2 , 1 >} : (tensor <1 x47 x40 x28 xf32 >, tensor <28 x3 x3 x28 xf32 >, tensor <28 xf32 >) -> tensor <1 x45 x40 x28 xf32 >
486494  return 
487495}
@@ -493,7 +501,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
493501  // CHECK:   %[[C22:.+]] = arith.constant -22 
494502  // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
495503  // CHECK:   tensor.yield %[[C22]] 
496-   // CHECK: linalg.conv_2d_nhwc_fhwc_q  
504+   // CHECK: linalg.conv_2d_nhwc_hwcf_q  
497505  %0  = tosa.conv2d  %arg0 , %arg1 , %arg2  {dilation  = array<i64 : 1 , 1 >, pad  = array<i64 : 1 , 1 , 1 , 1 >, quantization_info  = #tosa.conv_quant <input_zp  = -22 , weight_zp  = 42 >, stride  = array<i64 : 1 , 1 >} : (tensor <1 x12 x12 x1 xi8 >, tensor <1024 x3 x3 x1 xi8 >, tensor <1024 xi32 >) -> tensor <1 x12 x12 x1024 xi32 >
498506  return 
499507}
0 commit comments