Skip to content

Commit e88d0d4

Browse files
authored
Revert "[CodeGenC] Handle GlobalVar callee as internal function call" (#15725)
Revert "[CodeGenC] Handle GlobalVar callee as internal function call (#15103)". This reverts commit 9ff71f4, a recent change that breaks the Metal backend.
1 parent e3055c1 commit e88d0d4

27 files changed

+299
-591
lines changed

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _body():
5555
ib = tvm.tir.ir_builder.create()
5656
ib.emit(
5757
tvm.tir.call_extern(
58-
"int32",
58+
cc.dtype,
5959
f"{func_prefix}_{width}_{uniq_id}",
6060
aa.access_ptr("r"),
6161
cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def _body():
6868
def _reduce_reset():
6969
ib = tvm.tir.ir_builder.create()
7070
ib.emit(
71-
tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
71+
tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
7272
)
7373
return ib.get()
7474

@@ -113,8 +113,8 @@ def sum_impl(N, uniq_id):
113113
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
114114
int16_t *arr,
115115
int16_t *res16,
116-
int32_t arr_offset,
117-
int32_t reset) {{
116+
long arr_offset,
117+
int reset) {{
118118
int n;
119119
int32_t *p32;
120120
int32_t res = reset ? 0 : *res16;

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py

Lines changed: 17 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
156156
extern "C"
157157
#endif
158158
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
159-
int32_t K_arg,
159+
int K,
160160
int8_t *aa, int8_t *bb, int32_t *cc,
161-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
162-
int K = K_arg;
163-
int A_stride = A_stride_arg;
164-
int B_stride = B_stride_arg;
165-
int C_stride = C_stride_arg;
166-
161+
int A_stride, int B_stride, int C_stride) {{
167162
int k_base = (K / 4) * 4;
168163
switch ( K % 4 ) {{
169164
case 1:
@@ -205,12 +200,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
205200
#endif
206201
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
207202
int8_t *aa, int8_t *bb, int32_t *cc,
208-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
209-
int A_stride = A_stride_arg;
210-
int B_stride = B_stride_arg;
211-
int C_stride = C_stride_arg;
212-
213-
203+
int A_stride, int B_stride, int C_stride) {{
214204
for (int i = 0; i < {M}; i++) {{
215205
for (int j = 0; j < {N}; j++) {{
216206
int32_t sum = 0;
@@ -231,11 +221,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
231221
#endif
232222
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
233223
int8_t *aa, int8_t *bb, int32_t *cc,
234-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
235-
int A_stride = A_stride_arg;
236-
int B_stride = B_stride_arg;
237-
int C_stride = C_stride_arg;
238-
224+
int A_stride, int B_stride, int C_stride) {{
239225
int16_t bb_pad[{bb_pad_size}];
240226
int32_t retcode = 0;
241227
@@ -279,14 +265,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
279265
extern "C"
280266
#endif
281267
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
282-
int32_t K_arg,
268+
int K,
283269
int8_t *aa, int8_t *bb, int32_t *cc,
284-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
285-
int K = K_arg;
286-
int A_stride = A_stride_arg;
287-
int B_stride = B_stride_arg;
288-
int C_stride = C_stride_arg;
289-
270+
int A_stride, int B_stride, int C_stride) {{
290271
int k_base = (K / 4) * 4;
291272
switch ( K % 4 ) {{
292273
case 1:
@@ -328,11 +309,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
328309
#endif
329310
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
330311
int8_t *aa, int8_t *bb, int32_t *cc,
331-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
332-
int A_stride = A_stride_arg;
333-
int B_stride = B_stride_arg;
334-
int C_stride = C_stride_arg;
335-
312+
int A_stride, int B_stride, int C_stride) {{
336313
for (int i = 0; i < {M}; i++) {{
337314
for (int j = 0; j < {N}; j++) {{
338315
int32_t sum = 0;
@@ -350,11 +327,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
350327
#endif
351328
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
352329
int8_t *aa, int8_t *bb, int32_t *cc,
353-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
354-
int A_stride = A_stride_arg;
355-
int B_stride = B_stride_arg;
356-
int C_stride = C_stride_arg;
357-
330+
int A_stride, int B_stride, int C_stride) {{
358331
int16_t bb_pad[{bb_pad_size}];
359332
int32_t retcode = 0;
360333
@@ -395,14 +368,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
395368
extern "C"
396369
#endif
397370
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
398-
int32_t K_arg,
371+
int K,
399372
int16_t *aa, int16_t *bb, int32_t *cc,
400-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
401-
int K = K_arg;
402-
int A_stride = A_stride_arg;
403-
int B_stride = B_stride_arg;
404-
int C_stride = C_stride_arg;
405-
373+
int A_stride, int B_stride, int C_stride) {{
406374
int k_base = (K / 2) * 2;
407375
for (int i = 0; i < {M}; i++) {{
408376
for (int j = 0; j < {N}; j++) {{
@@ -419,11 +387,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
419387
#endif
420388
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
421389
int16_t *aa, int16_t *bb, int32_t *cc,
422-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
423-
int A_stride = A_stride_arg;
424-
int B_stride = B_stride_arg;
425-
int C_stride = C_stride_arg;
426-
390+
int A_stride, int B_stride, int C_stride) {{
427391
for (int i = 0; i < {M}; i++) {{
428392
for (int j = 0; j < {N}; j++) {{
429393
int32_t sum = 0;
@@ -444,11 +408,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
444408
#endif
445409
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
446410
int16_t *aa, int16_t *bb, int32_t *cc,
447-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
448-
int A_stride = A_stride_arg;
449-
int B_stride = B_stride_arg;
450-
int C_stride = C_stride_arg;
451-
411+
int A_stride, int B_stride, int C_stride) {{
452412
int32_t retcode = 0;
453413
454414
if ( {M} < 2 && {N} < 2 ) {{
@@ -490,14 +450,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
490450
extern "C"
491451
#endif
492452
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
493-
int32_t K_arg,
453+
int K,
494454
int16_t *aa, int16_t *bb, int32_t *cc,
495-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
496-
int K = K_arg;
497-
int A_stride = A_stride_arg;
498-
int B_stride = B_stride_arg;
499-
int C_stride = C_stride_arg;
500-
455+
int A_stride, int B_stride, int C_stride) {{
501456
int k_base = (K / 2) * 2;
502457
for (int i = 0; i < {M}; i++) {{
503458
for (int j = 0; j < {N}; j++) {{
@@ -514,11 +469,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
514469
#endif
515470
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
516471
int16_t *aa, int16_t *bb, int32_t *cc,
517-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
518-
int A_stride = A_stride_arg;
519-
int B_stride = B_stride_arg;
520-
int C_stride = C_stride_arg;
521-
472+
int A_stride, int B_stride, int C_stride) {{
522473
for (int i = 0; i < {M}; i++) {{
523474
for (int j = 0; j < {N}; j++) {{
524475
int32_t sum = 0;
@@ -536,11 +487,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
536487
#endif
537488
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
538489
int16_t *aa, int16_t *bb, int32_t *cc,
539-
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
540-
int A_stride = A_stride_arg;
541-
int B_stride = B_stride_arg;
542-
int C_stride = C_stride_arg;
543-
490+
int A_stride, int B_stride, int C_stride) {{
544491
int32_t retcode = 0;
545492
546493
if ( {M} < 2 && {N} < 2 ) {{
@@ -573,7 +520,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
573520
#ifdef __cplusplus
574521
extern "C"
575522
#endif
576-
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
523+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
577524
for (int i = 0; i < {M}; i++) {{
578525
for (int j = 0; j < {N}; j++) {{
579526
cc[i*C_stride + j] = 0;

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _body():
4646
ib = tvm.tir.ir_builder.create()
4747
ib.emit(
4848
tvm.tir.call_extern(
49-
"int32",
49+
cc.dtype,
5050
f"{func_prefix}_{uniq_id}",
5151
aa.access_ptr("r"),
5252
cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def _reduce_reset():
5959
ib = tvm.tir.ir_builder.create()
6060
ib.emit(
6161
tvm.tir.call_extern(
62-
"int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
62+
cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
6363
)
6464
)
6565
return ib.get()
@@ -96,7 +96,7 @@ def max_impl(uniq_id):
9696
#endif
9797
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
9898
int8_t *res,
99-
int32_t N) {{
99+
int N) {{
100100
memset(res, (int8_t)-128, N * sizeof(*res));
101101
return 0;
102102
}}
@@ -107,9 +107,7 @@ def max_impl(uniq_id):
107107
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
108108
int8_t *arg,
109109
int8_t *res,
110-
int32_t N_arg) {{
111-
int N = N_arg;
112-
110+
int N) {{
113111
for ( int i = 0; i < N; ++ i )
114112
if ( arg[i] > res[i] )
115113
res[i] = arg[i];
@@ -122,8 +120,7 @@ def max_impl(uniq_id):
122120
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
123121
int8_t *arg,
124122
int8_t *res,
125-
int32_t N_arg) {{
126-
int N = N_arg;
123+
int N) {{
127124
int32_t *parg32, *pres32;
128125
int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
129126
int32_t retcode = 0;

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -390,13 +390,8 @@ def insert_lines(lines):
390390
#define {function_name.upper()}_EXISTS
391391
#include <arm_acle.h>
392392
__attribute__((always_inline)) static inline int32_t {function_name}(
393-
int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
394-
int32_t *bias, int32_t *scale
393+
int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
395394
) {{
396-
int32_t *output = output_arg;
397-
int32_t *tensor = tensor_arg;
398-
int32_t *kernel = kernel_arg;
399-
400395
{_init_biased_accumulators(num_outputs)}
401396
402397
{insert_lines(load_tensor_lines)}

src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
4646
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
4747
}
4848

49+
/*!
50+
* \brief Emit code that offloads a subgraph to the Cortex-M
51+
*
52+
* \return string of code that offloads a subgraph to the Cortex-M
53+
*/
54+
void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
55+
4956
private:
5057
/*! * \brief Enable storing the last error */
5158
bool debug_last_error;
@@ -568,11 +575,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
568575
bool emit_fwd_func_decl = false;
569576
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
570577
CodeGenCMSISNN codegen;
578+
Array<String> function_names;
571579
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
572-
573-
std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
574-
for (auto [gvar, base_func] : mod->functions) {
575-
funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
580+
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
581+
for (auto kv : mod->functions) {
582+
funcs.push_back(kv);
576583
}
577584

578585
std::sort(funcs.begin(), funcs.end(),
@@ -587,16 +594,13 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
587594
return name_hint_a < name_hint_b;
588595
});
589596

590-
for (auto [gvar, prim_func] : funcs) {
591-
codegen.AddFunction(gvar, prim_func);
597+
for (auto kv : funcs) {
598+
auto prim_func = Downcast<PrimFunc>(kv.second);
599+
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
600+
function_names.push_back(global_symbol.value());
601+
codegen.AddFunction(prim_func);
592602
}
593603
std::string code = codegen.Finish();
594-
595-
Array<String> function_names;
596-
for (auto [gvar, prim_func] : funcs) {
597-
function_names.push_back(codegen.GetFunctionName(gvar));
598-
}
599-
600604
return codegen::CSourceModuleCreate(code, "c", function_names);
601605
}
602606

src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
4949
bool emit_asserts = false;
5050
bool emit_fwd_func_decl = false;
5151
CodeGenExampleTargetHook codegen;
52-
52+
Array<String> function_names;
5353
std::unordered_set<std::string> devices;
5454
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
55-
56-
Map<GlobalVar, PrimFunc> functions;
57-
for (auto [gvar, base_func] : mod->functions) {
58-
auto prim_func = Downcast<PrimFunc>(base_func);
59-
functions.Set(gvar, prim_func);
60-
}
61-
62-
for (auto [gvar, prim_func] : functions) {
63-
codegen.DeclareFunction(gvar, prim_func);
64-
}
65-
for (auto [gvar, prim_func] : functions) {
66-
codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
55+
for (auto kv : mod->functions) {
56+
auto prim_func = Downcast<PrimFunc>(kv.second);
57+
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
58+
function_names.push_back(global_symbol.value());
59+
codegen.AddFunction(prim_func);
6760
}
68-
6961
std::string code = codegen.Finish();
70-
71-
Array<String> function_names;
72-
for (auto [gvar, prim_func] : functions) {
73-
function_names.push_back(codegen.GetFunctionName(gvar));
74-
}
75-
7662
return codegen::CSourceModuleCreate(code, "c", function_names);
7763
}
7864

0 commit comments

Comments
 (0)