@@ -363,13 +363,11 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
363363
364364// CHECK-LABEL: @conv2d_i8 
365365func.func  @conv2d_i8 (%input:  tensor <1 x49 x42 x27 xi8 >, %weights:  tensor <28 x1 x1 x27 xi8 >, %bias:  tensor <28 xi8 >) -> () {
366-   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
367-   // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] 
368366  // CHECK: %[[M_IN:.+]] = tensor.empty() 
369367  // CHECK: %[[CST:.+]] = arith.constant 0 
370368  // CHECK: %[[FILL:.+]] = linalg.fill 
371369  // CHECK: %[[B_IN:.+]] = tensor.empty() 
372-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> 
370+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> 
373371  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>) 
374372  // CHECK:   arith.extsi 
375373  // CHECK:   arith.addi 
@@ -385,13 +383,11 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
385383
386384// CHECK-LABEL: @conv2d_f32 
387385func.func  @conv2d_f32 (%input:  tensor <1 x49 x42 x27 xf32 >, %weights:  tensor <28 x3 x3 x27 xf32 >, %bias:  tensor <28 xf32 >) -> () {
388-   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
389-   // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] 
390386  // CHECK: %[[M_IN:.+]] = tensor.empty() 
391387  // CHECK: %[[CST:.+]] = arith.constant 0 
392388  // CHECK: %[[FILL:.+]] = linalg.fill 
393389  // CHECK: %[[B_IN:.+]] = tensor.empty() 
394-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]]  : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>) 
390+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1  : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>) 
395391  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>) 
396392  // CHECK:   arith.addf 
397393  // CHECK:   linalg.yield 
@@ -408,13 +404,11 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
408404func.func  @conv2d_dyn (%input:  tensor <?x49 x42 x27 xf32 >, %weights:  tensor <28 x3 x3 x27 xf32 >, %bias:  tensor <28 xf32 >) -> () {
409405  // CHECK: %[[C0:.+]] = arith.constant 0 
410406  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] 
411-   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
412-   // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] 
413407  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]]) 
414408  // CHECK: %[[CST:.+]] = arith.constant 0 
415409  // CHECK: %[[FILL:.+]] = linalg.fill 
416410  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]]) 
417-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]]  : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>) 
411+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1  : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>) 
418412  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>) 
419413  // CHECK:   %[[ADD:.+]] = arith.addf 
420414  // CHECK:   linalg.yield %[[ADD]] : f32 
@@ -468,13 +462,11 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
468462  // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index 
469463
470464  // Running convolution 
471-   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> 
472-   // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]] 
473465  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) 
474466  // CHECK: %[[CST:.+]] = arith.constant 0 
475467  // CHECK: %[[FILL:.+]] = linalg.fill 
476468  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) 
477-   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]]  : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>) 
469+   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc  {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1  : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>) 
478470  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>) 
479471  // CHECK:   %[[ADD:.+]] = arith.addf 
480472  // CHECK:   linalg.yield %[[ADD]] : f32 
@@ -489,7 +481,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
489481  // CHECK: %[[C0:.+]] = arith.constant 0 
490482  // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
491483  // CHECK:   tensor.yield %[[C0]] 
492-   // CHECK: linalg.conv_2d_nhwc_hwcf  
484+   // CHECK: linalg.conv_2d_nhwc_fhwc  
493485  %0  = tosa.conv2d  %input , %weights , %bias  {pad  = array<i64 : 1 , 1 , 1 , 1 >, stride  = array<i64 : 1 , 1 >, dilation  = array<i64 : 2 , 1 >} : (tensor <1 x47 x40 x28 xf32 >, tensor <28 x3 x3 x28 xf32 >, tensor <28 xf32 >) -> tensor <1 x45 x40 x28 xf32 >
494486  return 
495487}
@@ -501,7 +493,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
501493  // CHECK:   %[[C22:.+]] = arith.constant -22 
502494  // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
503495  // CHECK:   tensor.yield %[[C22]] 
504-   // CHECK: linalg.conv_2d_nhwc_hwcf_q  
496+   // CHECK: linalg.conv_2d_nhwc_fhwc_q  
505497  %0  = tosa.conv2d  %arg0 , %arg1 , %arg2  {dilation  = array<i64 : 1 , 1 >, pad  = array<i64 : 1 , 1 , 1 , 1 >, quantization_info  = #tosa.conv_quant <input_zp  = -22 , weight_zp  = 42 >, stride  = array<i64 : 1 , 1 >} : (tensor <1 x12 x12 x1 xi8 >, tensor <1024 x3 x3 x1 xi8 >, tensor <1024 xi32 >) -> tensor <1 x12 x12 x1024 xi32 >
506498  return 
507499}
0 commit comments