
Commit 9ff71f4

[CodeGenC] Handle GlobalVar callee as internal function call (apache#15103)
Analogous to apache#14901, treat GlobalVar callees as internal function calls in CodeGenC. This specific PR doesn't provide new end-to-end functionality, as the target="c" backend isn't compiled. It does lead into allowing subroutines in any target whose codegen derives from CodeGenC, which will depend on the single-module lowering flow in apache#14985.

* [CodeGenC] Added unit tests for desired behavior

* [CodeGenC] Handle GlobalVar callee as internal function call

* Update CodeGenC subclasses for updated interface

  - Call `DeclareFunction` for each `PrimFunc`, prior to any `AddFunction` calls
  - Provide both `GlobalVar` and `PrimFunc` to `AddFunction` calls.

* Updated CRT test to expect forward declaration

* Provide forward declarations for call_extern in cmsis

* Avoid duplicate forward declaration

  C's automatic pointer cast (e.g. `void*` to `int*`) means that use of the arguments to infer the function signature may be incorrect. If a `call_extern` refers to a function within the same module, only output a single forward declaration based on the PrimFunc's parameters, not based on the CallNode's arguments.

* Updated expected ptx cuda

* Cast the AOT pools to the arg type

* Improved tvm::GetType for tvm_access_ptr and address_of

  These `Call` instances can return a `PointerType(PrimType(pointee_dtype))` rather than a `PrimType(DataType::Handle())`.

* [ARM][Topi] Update micro kernels to use same argument type as caller

  Previously, the micro kernels for gemm, avg_pool, max_pool, and tensordot relied on C's implicit type conversions for the arguments, when the caller's argument types differ from the signature's parameter types. This works, except when the codegen has auto-generated a forward declaration based on the caller's argument types, such as during AOT, which then causes a conflicting definition.

  Since the codegen cannot determine the function names from the `"pragma_import_c"` in order to suppress these forward declarations, this conflict can be more easily resolved by updating the micro kernel signatures. The three types of mismatches are below.

  - Use of `int` or `long` parameters, whose width may vary by compiler, instead of fixed-width types.
  - TIR expecting the data array's integer type to also be used as an error code's return type, rather than the micro kernels' `int32_t` error code.
  - Pointer conversion done during argument conversion.

  Type conversions are done at the start of each micro kernel, to avoid changing types that are used within the computational sections of each micro kernel.

* Updated unit tests with private=True

  Required for internal functions after PR apache#15214

* Docstring updates from review
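The "Avoid duplicate forward declaration" rule above can be illustrated with a toy model. This is a minimal sketch in plain Python, not TVM code: `forward_decl`, `collect_decls`, and the function names `subroutine` and `external_fn` are hypothetical and exist only for illustration. It emits one declaration per function name and prefers the callee's own parameter types over types inferred from a call site.

```python
# Toy model of the rule described above; none of these names are TVM APIs.

def forward_decl(ret_type, name, types):
    """Render a C forward declaration from a return type, name, and type list."""
    return f"{ret_type} {name}({', '.join(types)});"


def collect_decls(module_funcs, extern_calls):
    """Emit one forward declaration per function name.

    module_funcs: name -> (return type, parameter types) for functions
                  defined in the same module.
    extern_calls: list of (return type, name, argument types) seen at
                  call sites.
    """
    decls = {}
    # Functions defined in the module are declared from their own parameters.
    for name, (ret, params) in module_funcs.items():
        decls[name] = forward_decl(ret, name, params)
    # A call site only contributes a declaration if the name is still unknown,
    # so argument types never override the callee's signature.
    for ret, name, args in extern_calls:
        decls.setdefault(name, forward_decl(ret, name, args))
    return list(decls.values())


# Hypothetical module: `subroutine` is defined locally and is called with a
# `void*` argument that C would convert implicitly to `int32_t*`.
module_funcs = {"subroutine": ("int32_t", ["int32_t*"])}
extern_calls = [
    ("int32_t", "subroutine", ["void*"]),
    ("int32_t", "external_fn", ["float"]),
]
decls = collect_decls(module_funcs, extern_calls)
```

In this sketch `subroutine` is declared once, from its parameters, while `external_fn` still falls back to the call-site types because nothing better is known about it.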
1 parent 34cacb0 commit 9ff71f4

27 files changed: +591 -297 lines changed

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py

Lines changed: 4 additions & 4 deletions
@@ -55,7 +55,7 @@ def _body():
  ib = tvm.tir.ir_builder.create()
  ib.emit(
  tvm.tir.call_extern(
- cc.dtype,
+ "int32",
  f"{func_prefix}_{width}_{uniq_id}",
  aa.access_ptr("r"),
  cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def _body():
  def _reduce_reset():
  ib = tvm.tir.ir_builder.create()
  ib.emit(
- tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
+ tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
  )
  return ib.get()

@@ -113,8 +113,8 @@ def sum_impl(N, uniq_id):
  __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
  int16_t *arr,
  int16_t *res16,
- long arr_offset,
- int reset) {{
+ int32_t arr_offset,
+ int32_t reset) {{
  int n;
  int32_t *p32;
  int32_t res = reset ? 0 : *res16;

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py

Lines changed: 70 additions & 17 deletions
@@ -156,9 +156,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  extern "C"
  #endif
  __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
  int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int k_base = (K / 4) * 4;
  switch ( K % 4 ) {{
  case 1:
@@ -200,7 +205,12 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
  int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
+
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
  int32_t sum = 0;
@@ -221,7 +231,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
  int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int16_t bb_pad[{bb_pad_size}];
  int32_t retcode = 0;

@@ -265,9 +279,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  extern "C"
  #endif
  __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
  int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int k_base = (K / 4) * 4;
  switch ( K % 4 ) {{
  case 1:
@@ -309,7 +328,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
  int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
  int32_t sum = 0;
@@ -327,7 +350,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
  int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int16_t bb_pad[{bb_pad_size}];
  int32_t retcode = 0;

@@ -368,9 +395,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  extern "C"
  #endif
  __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
  int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int k_base = (K / 2) * 2;
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
@@ -387,7 +419,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
  int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
  int32_t sum = 0;
@@ -408,7 +444,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
  int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int32_t retcode = 0;

  if ( {M} < 2 && {N} < 2 ) {{
@@ -450,9 +490,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  extern "C"
  #endif
  __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
  int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int k_base = (K / 2) * 2;
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
@@ -469,7 +514,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
  int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
  int32_t sum = 0;
@@ -487,7 +536,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
  int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
  int32_t retcode = 0;

  if ( {M} < 2 && {N} < 2 ) {{
@@ -520,7 +573,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
  #ifdef __cplusplus
  extern "C"
  #endif
- __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
+ __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
  for (int i = 0; i < {M}; i++) {{
  for (int j = 0; j < {N}; j++) {{
  cc[i*C_stride + j] = 0;
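The same edit repeats across every gemm kernel in this file: each `int` parameter becomes an `int32_t` parameter with an `_arg` suffix, and a local `int` of the original name is initialised from it so the kernel body is untouched. As a rough illustration only, the pattern could be generated by a helper like the one below. `shim_params` is a hypothetical name and is not part of this commit, which writes the lines out by hand in each C template.

```python
# Hypothetical helper, not part of the commit: builds the fixed-width
# parameter list and the local conversions used by the `_arg` pattern.

def shim_params(names, param_type="int32_t", local_type="int"):
    """Return (C parameter list, C prologue) for the given variable names."""
    # Signature side: fixed-width types that match the caller's arguments.
    params = ", ".join(f"{param_type} {name}_arg" for name in names)
    # Body side: convert once at the top, keep the original names below.
    prologue = "\n".join(f"{local_type} {name} = {name}_arg;" for name in names)
    return params, prologue


params, prologue = shim_params(["A_stride", "B_stride", "C_stride"])
print(params)
print(prologue)
```

Converting once at the top of the function keeps the computational part of each kernel on its original types, which matches the rationale given in the commit message.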

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py

Lines changed: 8 additions & 5 deletions
@@ -46,7 +46,7 @@ def _body():
  ib = tvm.tir.ir_builder.create()
  ib.emit(
  tvm.tir.call_extern(
- cc.dtype,
+ "int32",
  f"{func_prefix}_{uniq_id}",
  aa.access_ptr("r"),
  cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def _reduce_reset():
  ib = tvm.tir.ir_builder.create()
  ib.emit(
  tvm.tir.call_extern(
- cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
+ "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
  )
  )
  return ib.get()
@@ -96,7 +96,7 @@ def max_impl(uniq_id):
  #endif
  __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
  int8_t *res,
- int N) {{
+ int32_t N) {{
  memset(res, (int8_t)-128, N * sizeof(*res));
  return 0;
  }}
@@ -107,7 +107,9 @@ def max_impl(uniq_id):
  __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
  int8_t *arg,
  int8_t *res,
- int N) {{
+ int32_t N_arg) {{
+ int N = N_arg;
+
  for ( int i = 0; i < N; ++ i )
  if ( arg[i] > res[i] )
  res[i] = arg[i];
@@ -120,7 +122,8 @@ def max_impl(uniq_id):
  __attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
  int8_t *arg,
  int8_t *res,
- int N) {{
+ int32_t N_arg) {{
+ int N = N_arg;
  int32_t *parg32, *pres32;
  int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
  int32_t retcode = 0;

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py

Lines changed: 6 additions & 1 deletion
@@ -390,8 +390,13 @@ def insert_lines(lines):
  #define {function_name.upper()}_EXISTS
  #include <arm_acle.h>
  __attribute__((always_inline)) static inline int32_t {function_name}(
- int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+ int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+ int32_t *bias, int32_t *scale
  ) {{
+ int32_t *output = output_arg;
+ int32_t *tensor = tensor_arg;
+ int32_t *kernel = kernel_arg;
+
  {_init_biased_accumulators(num_outputs)}

  {insert_lines(load_tensor_lines)}

src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc

Lines changed: 12 additions & 16 deletions
@@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
  CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
  }

- /*!
- * \brief Emit code that offloads a subgraph to the Cortex-M
- *
- * \return string of code that offloads a subgraph to the Cortex-M
- */
- void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
-
  private:
  /*! * \brief Enable storing the last error */
  bool debug_last_error;
@@ -519,11 +512,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
  bool emit_fwd_func_decl = false;
  bool debug_last_error = GetCompilerAttrs()->debug_last_error;
  CodeGenCMSISNN codegen;
- Array<String> function_names;
  codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
- std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
- for (auto kv : mod->functions) {
- funcs.push_back(kv);
+
+ std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
+ for (auto [gvar, base_func] : mod->functions) {
+ funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
  }

  std::sort(funcs.begin(), funcs.end(),
@@ -538,13 +531,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
  return name_hint_a < name_hint_b;
  });

- for (auto kv : funcs) {
- auto prim_func = Downcast<PrimFunc>(kv.second);
- auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- function_names.push_back(global_symbol.value());
- codegen.AddFunction(prim_func);
+ for (auto [gvar, prim_func] : funcs) {
+ codegen.AddFunction(gvar, prim_func);
  }
  std::string code = codegen.Finish();
+
+ Array<String> function_names;
+ for (auto [gvar, prim_func] : funcs) {
+ function_names.push_back(codegen.GetFunctionName(gvar));
+ }
+
  return codegen::CSourceModuleCreate(code, "c", function_names);
  }


src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc

Lines changed: 20 additions & 6 deletions
@@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
  bool emit_asserts = false;
  bool emit_fwd_func_decl = false;
  CodeGenExampleTargetHook codegen;
- Array<String> function_names;
+
  std::unordered_set<std::string> devices;
  codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
- for (auto kv : mod->functions) {
- auto prim_func = Downcast<PrimFunc>(kv.second);
- auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- function_names.push_back(global_symbol.value());
- codegen.AddFunction(prim_func);
+
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ functions.Set(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ codegen.DeclareFunction(gvar, prim_func);
+ }
+ for (auto [gvar, prim_func] : functions) {
+ codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
  }
+
  std::string code = codegen.Finish();
+
+ Array<String> function_names;
+ for (auto [gvar, prim_func] : functions) {
+ function_names.push_back(codegen.GetFunctionName(gvar));
+ }
+
  return codegen::CSourceModuleCreate(code, "c", function_names);
  }

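The hunk above follows the updated interface from the commit message: every function is declared before any function body is emitted, and names are read back from the codegen afterwards. The two-pass order can be sketched with a toy model in plain Python. `ToyCodeGen`, its methods, and the `gv_main` / `gv_sub` handles are hypothetical stand-ins, not the real `CodeGenC` API, and the emitted C is deliberately trivial.

```python
# Toy stand-in for a C code generator; it only mirrors the call order
# DeclareFunction -> AddFunction -> Finish -> GetFunctionName.

class ToyCodeGen:
    def __init__(self):
        self.names = {}   # handle -> emitted C function name
        self.decls = []   # forward declarations
        self.defs = []    # function definitions

    def declare_function(self, gvar, func):
        """First pass: record the name and emit a forward declaration."""
        self.names[gvar] = func["name"]
        self.decls.append(f"int32_t {func['name']}(void);")

    def add_function(self, gvar, func):
        """Second pass: emit the body; callees must already be declared."""
        calls = [f"  {self.names[callee]}();" for callee in func["calls"]]
        header = f"int32_t {self.names[gvar]}(void) {{"
        self.defs.append("\n".join([header] + calls + ["  return 0;", "}"]))

    def get_function_name(self, gvar):
        return self.names[gvar]

    def finish(self):
        return "\n".join(self.decls + self.defs)


def generate(functions):
    codegen = ToyCodeGen()
    for gvar, func in functions.items():
        codegen.declare_function(gvar, func)
    for gvar, func in functions.items():
        codegen.add_function(gvar, func)
    return codegen, codegen.finish()


# `gv_main` calls `gv_sub`, which appears later in the module.
functions = {
    "gv_main": {"name": "main_func", "calls": ["gv_sub"]},
    "gv_sub": {"name": "subroutine", "calls": []},
}
codegen, code = generate(functions)
function_names = [codegen.get_function_name(gvar) for gvar in functions]
```

In this toy, emitting `main_func` in a single pass would fail with a `KeyError`, because `gv_sub` has no recorded name yet. Declaring everything first removes the dependence on iteration order.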
