Skip to content

Commit 60f5568

Browse files
authored
[CODEGEN][REFACTOR] tir.call_llvm_intrin to remove nargs (#18206)
This PR refactors tir.call_llvm_intrin to omit the leading nargs argument. nargs was originally introduced because prefetch needed a number of signature arguments that differed from the number of call arguments. That reason no longer holds, and attaching nargs to call_llvm_intrin is unintuitive, since the count is already implied by the number of arguments passed. After this update, callers pass the arguments to tir.call_llvm_intrin directly, as they are.
1 parent dbbcf90 commit 60f5568

25 files changed

+55
-128
lines changed

include/tvm/tir/builtin.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ TVM_DLL const Op& call_spirv_pure_glsl450();
225225
// TODO(tvm-team) revisit the builtins below
226226
// some of them can simply become ops with special codegen attr.
227227
/*!
228-
* \brief Prefetch a cacheline
228+
* \brief same signature as llvm.prefetch
229229
*/
230230
TVM_DLL const Op& prefetch();
231231

include/tvm/tir/stmt.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,11 +1106,6 @@ constexpr const char* pragma_import_c = "pragma_import_c";
11061106
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
11071107
/*! \brief Try to modify the AST to support Tensor Core */
11081108
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1109-
/*!
1110-
* \brief Mark of prefetch scope, value=offset,
1111-
* run prefetch of Tensor on the current loop scope
1112-
*/
1113-
constexpr const char* prefetch_scope = "prefetch_scope";
11141109
/*!
11151110
* \brief Marks the layout transforms to be used for a tensor.
11161111
*

python/tvm/tir/tensor_intrin/arm_cpu.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,13 @@ def neon_4x4_i8i8i32_impl(
7474

7575
multiply_low = T.call_llvm_pure_intrin(
7676
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
77-
T.uint32(2),
7877
vec_a,
7978
vec_b_low,
8079
dtype="int16x8",
8180
)
8281

8382
pairwise_reduction_low = T.call_llvm_pure_intrin(
8483
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
85-
T.uint32(1),
8684
multiply_low,
8785
dtype="int32x4",
8886
)
@@ -91,22 +89,19 @@ def neon_4x4_i8i8i32_impl(
9189

9290
multiply_high = T.call_llvm_pure_intrin(
9391
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
94-
T.uint32(2),
9592
vec_a,
9693
vec_b_high,
9794
dtype="int16x8",
9895
)
9996

10097
pairwise_reduction_high = T.call_llvm_pure_intrin(
10198
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
102-
T.uint32(1),
10399
multiply_high,
104100
dtype="int32x4",
105101
)
106102

107103
C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
108104
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
109-
T.uint32(2),
110105
pairwise_reduction_low,
111106
pairwise_reduction_high,
112107
dtype="int32x4",
@@ -159,7 +154,6 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
159154

160155
C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
161156
T.llvm_lookup_intrinsic_id(f"llvm.aarch64.neon.{instr}"),
162-
T.uint32(3),
163157
vec_c,
164158
vec_a,
165159
vec_b,
@@ -311,7 +305,6 @@ def impl():
311305
T.call_llvm_intrin(
312306
"void",
313307
"llvm.aarch64.sme.ld1w.horiz",
314-
T.uint32(4),
315308
predicate,
316309
input_ptr,
317310
sub_tile,
@@ -335,7 +328,6 @@ def impl():
335328
T.call_llvm_intrin(
336329
"void",
337330
"llvm.aarch64.sme.st1w.vert",
338-
T.uint32(4),
339331
predicate,
340332
output_ptr,
341333
sub_tile,
@@ -438,7 +430,6 @@ def impl():
438430
T.call_llvm_intrin(
439431
"void",
440432
"llvm.aarch64.sme.ld1h.horiz",
441-
T.uint32(4),
442433
ptrue_fp16,
443434
input_ptr,
444435
sub_tile_idx,
@@ -450,7 +441,6 @@ def impl():
450441
T.call_llvm_intrin(
451442
"void",
452443
"llvm.aarch64.sme.ld1h.horiz",
453-
T.uint32(4),
454444
ptrue_fp16,
455445
input_ptr,
456446
sub_tile_idx,
@@ -467,7 +457,6 @@ def impl():
467457
T.call_llvm_intrin(
468458
"void",
469459
"llvm.aarch64.sme.st1w.vert",
470-
T.uint32(4),
471460
ptrue_fp32,
472461
output_ptr,
473462
sub_tile_idx,
@@ -479,7 +468,6 @@ def impl():
479468
T.call_llvm_intrin(
480469
"void",
481470
"llvm.aarch64.sme.st1w.vert",
482-
T.uint32(4),
483471
ptrue_fp32,
484472
output_ptr,
485473
sub_tile_idx + 2,
@@ -692,7 +680,6 @@ def impl():
692680
T.call_llvm_intrin(
693681
"void",
694682
fmopa_intrin,
695-
T.uint32(5),
696683
sub_tile,
697684
input_1[1],
698685
input_2[1],
@@ -713,7 +700,6 @@ def impl():
713700
T.call_llvm_intrin(
714701
"void",
715702
"llvm.aarch64.sme.st1w.horiz",
716-
T.uint32(4),
717703
_create_active_lane_mask(
718704
C, (vert_offset + slice_idx, horiz_offset), M
719705
),
@@ -752,9 +738,7 @@ def impl(c: T.handle) -> None:
752738
T.reads()
753739
T.writes(C[0:SVF2, 0:SVF2])
754740
clear_all_tiles = T.int32(255)
755-
T.evaluate(
756-
T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", T.uint32(1), clear_all_tiles)
757-
)
741+
T.evaluate(T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", clear_all_tiles))
758742

759743
return desc, impl
760744

python/tvm/tir/tensor_intrin/hexagon.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non
107107

108108
C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
109109
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
110-
T.uint32(3),
111110
C[T.ramp(T.int32(0), 1, 32)],
112111
B_i32x32,
113112
A_i32,
@@ -149,7 +148,6 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non
149148

150149
C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
151150
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
152-
T.uint32(3),
153151
C[T.ramp(T.int32(0), 1, 32)],
154152
T.broadcast(A_i32, 32),
155153
B_i32x32,
@@ -191,7 +189,6 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N
191189

192190
C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
193191
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vdmpyhvsat.acc.128B"),
194-
T.uint32(3),
195192
C[T.ramp(T.int32(0), 1, 32)],
196193
T.Broadcast(A_i32, 32),
197194
B_i32x32,

python/tvm/tir/tensor_intrin/rocm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def sdot4(
3939

4040
C[0] += T.call_llvm_pure_intrin(
4141
T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
42-
T.uint32(4),
4342
T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
4443
T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
4544
T.int32(0),
@@ -337,7 +336,6 @@ def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None:
337336
T.launch_thread(tx, WARP_SIZE)
338337
C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
339338
T.llvm_lookup_intrinsic_id(mfma_intrin),
340-
T.uint32(6),
341339
A[tx, 0:local_size],
342340
B[tx, 0:local_size],
343341
C[tx, 0:local_size_out],
@@ -365,7 +363,6 @@ def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None:
365363

366364
C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
367365
T.llvm_lookup_intrinsic_id(mfma_intrin),
368-
T.uint32(6),
369366
T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
370367
T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
371368
C[tx, 0:local_size_out],

python/tvm/tir/tensor_intrin/x86.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def dot_product_16x4_u8i8i32_vnni(
5959

6060
C[T.ramp(T.int32(0), 1, 16)] = T.call_llvm_pure_intrin(
6161
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
62-
T.uint32(3),
6362
C_i32x16,
6463
T.broadcast(A_i32, 16),
6564
B_i32x16,
@@ -86,15 +85,13 @@ def dot_product_16x4_u8i8i32_avx512(
8685

8786
Red = T.call_llvm_pure_intrin(
8887
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
89-
T.uint32(2),
9088
A_u8x64,
9189
B_i8x64,
9290
dtype="int16x32",
9391
)
9492

9593
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
9694
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
97-
T.uint32(2),
9895
Red,
9996
T.int16x32(1),
10097
dtype="int32x16",

src/target/llvm/codegen_arm.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
6767

6868
PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
6969
using namespace tir;
70-
const PrimExpr& e = call->args[2];
70+
const PrimExpr& e = call->args[1];
7171
llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop;
7272
llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu;
7373

@@ -77,7 +77,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
7777
(total_size != 128 && total_size != 64)) {
7878
Array<PrimExpr> vcnt_args;
7979
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
80-
vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
8180
vcnt_args.push_back(e);
8281
return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args);
8382
}
@@ -101,14 +100,12 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
101100
ICHECK(c0 != nullptr);
102101
Array<PrimExpr> vcnt8_args;
103102
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
104-
vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
105103
vcnt8_args.push_back(input8);
106104
PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args);
107105

108106
// Accumulation 8->16bit
109107
Array<PrimExpr> vcnt16_args;
110108
vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
111-
vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
112109
vcnt16_args.push_back(vcnt8);
113110
PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args);
114111
if (call->dtype.bits() == 16) {
@@ -118,7 +115,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
118115
// Accumulation 16->32bit
119116
Array<PrimExpr> vcnt32_args;
120117
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
121-
vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
122118
vcnt32_args.push_back(vcnt16);
123119
PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args);
124120
if (call->dtype.bits() == 32) {
@@ -128,7 +124,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
128124
// Accumulation 32->64bit
129125
Array<PrimExpr> vcnt64_args;
130126
vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
131-
vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
132127
vcnt64_args.push_back(vcnt32);
133128
return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args);
134129
}

src/target/llvm/codegen_llvm.cc

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,34 +1359,18 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
13591359

13601360
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
13611361
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
1362-
ICHECK_GE(op->args.size(), 2U);
1362+
ICHECK_GE(op->args.size(), 1U);
13631363
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
1364-
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
13651364
std::vector<llvm::Value*> arg_value;
13661365
std::vector<llvm::Type*> arg_type;
1367-
for (size_t i = 2; i < op->args.size(); ++i) {
1366+
for (size_t i = 1; i < op->args.size(); ++i) {
13681367
arg_value.push_back(MakeValue(op->args[i]));
1369-
if (i - 2 < static_cast<size_t>(num_signature)) {
1370-
arg_type.push_back(arg_value.back()->getType());
1371-
}
1368+
arg_type.push_back(arg_value.back()->getType());
13721369
}
1373-
// LLVM's prefetch intrinsic returns "void", while TVM's prefetch
1374-
// returns int32. This causes problems because prefetch is one of
1375-
// those intrinsics that is generated automatically via the
1376-
// tvm.intrin.rule mechanism. Any other intrinsic with a type
1377-
// mismatch will have to be treated specially here.
1378-
// TODO(kparzysz-quic): fix this once TVM prefetch uses the same
1379-
// type as LLVM.
1380-
llvm::Type* return_type =
1381-
(id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op)) : t_void_;
1370+
llvm::Type* return_type = GetLLVMType(GetRef<PrimExpr>(op));
13821371
llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
13831372
ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
1384-
#if TVM_LLVM_VERSION >= 130
1385-
<< llvm::Intrinsic::getBaseName(id).str();
1386-
#else
1387-
<< llvm::Intrinsic::getName(id, {});
1388-
#endif
1389-
1373+
<< llvmGetIntrinName(id);
13901374
// In earlier versions of LLVM's, the prefetch intrinsic is not
13911375
// overloaded, and always takes the first argument as i8*. If
13921376
// this is the case, this argument should insert a cast to i8*.
@@ -1399,7 +1383,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
13991383
builder_->CreatePointerCast(arg_value[0], llvmGetPointerTo(t_char_, addrspace));
14001384
}
14011385
}
1402-
14031386
return builder_->CreateCall(f, arg_value);
14041387
} else if (op->op.same_as(builtin::bitwise_and())) {
14051388
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimEx
266266
ICHECK_EQ(call->args.size(), 1);
267267
Array<PrimExpr> cargs;
268268
cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
269-
cargs.push_back(IntImm(DataType::UInt(32), 2));
270269
cargs.push_back(call->args[0]);
271270
cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
272271
// LLVM requires that the return type must match the first argument type

src/target/llvm/intrin_rule_llvm.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@
2626

2727
#ifdef TVM_LLVM_VERSION
2828

29+
#include <llvm/IR/Intrinsics.h>
2930
#include <tvm/ffi/function.h>
3031
#include <tvm/target/codegen.h>
3132
#include <tvm/tir/builtin.h>
3233
#include <tvm/tir/expr.h>
3334

35+
#include "llvm_instance.h"
36+
3437
namespace tvm {
3538
namespace codegen {
3639
// num_signature means number of arguments used to query signature
@@ -41,7 +44,9 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) {
4144
Array<PrimExpr> cargs;
4245
// intrin id.
4346
cargs.push_back(IntImm(DataType::UInt(32), id));
44-
cargs.push_back(IntImm(DataType::UInt(32), num_signature));
47+
ICHECK_EQ(call->args.size(), num_signature)
48+
<< "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature
49+
<< " arguments, but got " << call->args.size();
4550

4651
for (PrimExpr arg : call->args) {
4752
cargs.push_back(arg);
@@ -56,7 +61,9 @@ inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) {
5661
Array<PrimExpr> cargs;
5762
// intrin id.
5863
cargs.push_back(IntImm(DataType::UInt(32), id));
59-
cargs.push_back(IntImm(DataType::UInt(32), num_signature));
64+
ICHECK_EQ(call->args.size(), num_signature)
65+
<< "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature
66+
<< " arguments, but got " << call->args.size();
6067
for (PrimExpr arg : call->args) {
6168
cargs.push_back(arg);
6269
}

0 commit comments

Comments
 (0)