@@ -363,11 +363,13 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK: arith.extsi
   // CHECK: arith.addi
@@ -383,11 +385,13 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK: arith.addf
   // CHECK: linalg.yield
@@ -404,11 +408,13 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[ADD:.+]] = arith.addf
   // CHECK: linalg.yield %[[ADD]] : f32
@@ -462,11 +468,13 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[ADD:.+]] = arith.addf
   // CHECK: linalg.yield %[[ADD]] : f32
@@ -481,7 +489,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: tensor.yield %[[C0]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc
+  // CHECK: linalg.conv_2d_nhwc_hwcf
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -493,7 +501,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK: %[[C22:.+]] = arith.constant -22
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: tensor.yield %[[C22]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc_q
+  // CHECK: linalg.conv_2d_nhwc_hwcf_q
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }
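
For reference, a minimal sketch of the IR shape the updated @conv2d_f32 checks expect: the TOSA kernel in FHWC layout is transposed to the HWCF layout consumed by linalg.conv_2d_nhwc_hwcf before the convolution runs. The value names and the i64 element type of the permutation constant are assumptions for illustration; the CHECK lines above only match the constant's value, not its type.

  // Illustrative sketch only, mirroring the @conv2d_f32 CHECK lines above.
  // %arg0 : tensor<1x49x42x27xf32> (NHWC input), %arg1 : tensor<28x3x3x27xf32> (FHWC kernel)
  %perm = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>   // F,H,W,C -> H,W,C,F
  %w = tosa.transpose %arg1, %perm
         : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
  %empty = tensor.empty() : tensor<1x45x40x28xf32>
  %zero = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
  %conv = linalg.conv_2d_nhwc_hwcf
            {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
            ins(%arg0, %w : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>)
            outs(%acc : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>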