@@ -466,3 +466,102 @@ def matmul_as_contract_op(
466466 )
467467
468468 print (module )
469+
# CHECK-LABEL: TEST: testBatchMatmulOp
@run
def testBatchMatmulOp():
    """Build and print a module exercising ``linalg.batch_matmul``.

    A single function ``batch_matmul_op`` is created whose body emits the op
    eight times: for tensor operands and then memref operands, each with the
    default indexing maps and with explicit maps for a transposed right-hand
    side, and each through both the ``linalg.BatchMatmulOp`` class and the
    ``linalg.batch_matmul`` helper. The printed module is what the ``CHECK``
    directives below are matched against, so their order must follow the
    order in which the ops are created.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            lhs_shape = (2, 4, 8)
            rhs_shape = (2, 8, 12)
            rhs_transposed_shape = (2, 12, 8)
            out_shape = (2, 4, 12)

            # Four dimension expressions d0..d3, named (batch, m, n, k).
            d_batch, d_m, d_n, d_k = (
                ir.AffineDimExpr.get(position) for position in range(4)
            )

            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

            def map_over(*results):
                # Every map here has 4 dims and 0 symbols.
                return ir.AffineMap.get(4, 0, list(results))

            # Operand order: A, transposed B, C.
            transposed_rhs_maps = [
                map_over(d_batch, d_m, d_k),
                map_over(d_batch, d_n, d_k),
                map_over(d_batch, d_m, d_n),
            ]

            # CHECK: func.func @batch_matmul_op(
            arg_types = (
                # CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>,
                RankedTensorType.get(lhs_shape, f32),
                # CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>,
                MemRefType.get(lhs_shape, f32),
                # CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>,
                RankedTensorType.get(rhs_shape, f32),
                # CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>,
                MemRefType.get(rhs_shape, f32),
                # CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>,
                RankedTensorType.get(rhs_transposed_shape, f32),
                # CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>,
                MemRefType.get(rhs_transposed_shape, f32),
                # CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>,
                RankedTensorType.get(out_shape, f32),
                # CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>)
                MemRefType.get(out_shape, f32),
            )

            @func.FuncOp.from_py_func(*arg_types)
            def batch_matmul_op(A, Amem, B, Bmem, BTrans, BTransmem, C, Cmem):
                # --- Tensor operands, default indexing maps. ---
                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
                op = linalg.BatchMatmulOp(
                    inputs=(A, B),
                    outputs=(C,),
                    result_tensors=(C.type,),
                )
                linalg.fill_builtin_region(op.operation)
                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
                linalg.batch_matmul(A, B, outs=(C,))

                # --- Tensor operands, explicit maps for the transposed B. ---
                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
                op = linalg.BatchMatmulOp(
                    inputs=(A, BTrans),
                    outputs=(C,),
                    result_tensors=(C.type,),
                    indexing_maps=transposed_rhs_maps,
                )
                linalg.fill_builtin_region(op.operation)
                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
                linalg.batch_matmul(
                    A,
                    BTrans,
                    outs=(C,),
                    indexing_maps=transposed_rhs_maps,
                )

                # --- Memref operands, default indexing maps. ---
                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
                op = linalg.BatchMatmulOp(
                    inputs=(Amem, Bmem),
                    outputs=(Cmem,),
                    result_tensors=[],
                )
                linalg.fill_builtin_region(op.operation)
                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
                linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))

                # --- Memref operands, explicit maps for the transposed B. ---
                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
                op = linalg.BatchMatmulOp(
                    inputs=(Amem, BTransmem),
                    outputs=(Cmem,),
                    result_tensors=[],
                    indexing_maps=transposed_rhs_maps,
                )
                linalg.fill_builtin_region(op.operation)
                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
                linalg.batch_matmul(
                    Amem,
                    BTransmem,
                    outs=(Cmem,),
                    indexing_maps=transposed_rhs_maps,
                )

        print(module)
0 commit comments