
Commit acdc164

[Target] Support CUDA device function calls (#18055)
[TIR][Target] Support device call compilation

This PR introduces support for device call compilation in TVM by enhancing the BindTarget pass to properly handle functions called from both host and device contexts. The key improvement is the ability to automatically create host-specific duplicates of functions that are called from both host and device code, ensuring proper target binding for heterogeneous compilation.

- **Function Classification**: Analyzes call patterns to identify functions called from host vs. device contexts
- **Smart Target Binding**: Automatically binds appropriate targets based on calling context:
  - Functions called only from host → host target
  - Functions called only from device → device target
  - Functions called from both → device target + host duplicate
- **Call Site Updates**: Updates call sites in externally exposed functions to use the appropriate duplicates
- Improved device function extraction and kernel generation
- Better handling of error propagation for different device types
- Enhanced buffer declaration and parameter management
- Support for `__device__` function calls in CUDA kernels
- Proper function signature generation for device functions
- Enhanced calling convention handling
- Updated build pipeline to handle device call compilation
- Improved target-specific compilation logic

The following example demonstrates how the BindTarget pass handles functions called from both host and device contexts:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
```

After applying `BindTarget(cuda, host="llvm")`, the pass automatically:

1. Creates a device version of `add` with the CUDA target
2. Creates a host duplicate `add_host` with the LLVM target
3. Updates the main function to call `add_host` from host context and `add` from device context

This enables seamless compilation of mixed host/device code while maintaining proper target-specific optimizations and code generation.

- **Automatic Target Binding**: No manual target annotation required for most use cases
- **Heterogeneous Compilation**: Proper support for functions called from multiple contexts
- **Code Reuse**: Shared functions can be called from both host and device without duplication
- **Performance**: Maintains target-specific optimizations for each context
- **Developer Experience**: Simplifies writing mixed host/device code

The implementation is backward compatible and integrates seamlessly with existing TVM compilation pipelines.
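As a quick illustration of the workflow described above, here is a minimal sketch (not part of the commit) of applying the pass; it assumes the TVMScript `Module` from the example above is in scope and that this TVM build provides the `BindTarget` pass shown in the diff below.

```python
import tvm

# Bind a CUDA device target with an LLVM host, as the description above does.
target = tvm.target.Target("cuda", host="llvm")
mod = tvm.tir.transform.BindTarget(target)(Module)

# After the pass, the module is expected to contain both `add` (CUDA target) and
# a host duplicate such as `add_host` (LLVM target); printing the module shows
# the per-function "target" attributes.
mod.show()
```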
1 parent 458b0ab commit acdc164

File tree

10 files changed: +597 −70 lines changed


python/tvm/tir/build.py

Lines changed: 78 additions & 28 deletions
```diff
@@ -17,8 +17,7 @@
 
 # pylint: disable=invalid-name
 """The build utils in python."""
-from typing import Union, Optional, Dict
-import enum
+from typing import Union, Optional, Dict, Tuple
 
 import tvm
 from tvm import ir
@@ -28,44 +27,95 @@
 from tvm.target import Target
 
 
-def split_host_device_mods(mod):
+def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
     """Split an IRModule into host and device modules.
 
+    This function takes an IRModule containing functions with different target attributes
+    and separates them into host (CPU) and device (GPU/accelerator) modules. Functions
+    are categorized based on their target attribute in func_attr.
+
     Parameters
     ----------
     mod : tvm.IRModule
-        The input module to split
+        The input module to split.
+        The module should contain functions with target attributes in their func_attr.
+        Functions with "cpu" in their target string are considered host functions,
+        while others are considered device functions.
 
     Returns
     -------
     host_mod : tvm.IRModule
-        The module containing host functions
+        The module containing host functions (CPU-targeted functions)
     device_mod_dict : Dict[Target, tvm.IRModule]
-        A dict mapping targets to device modules
+        A dict mapping targets to device modules. Each device module contains
+        functions targeting the same device (e.g., CUDA GPU, OpenCL, etc.)
+
+    Examples
+    --------
+    Given an IRModule with the following functions:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @T.prim_func(private=True)
+            def add(a: T.int32, b: T.int32) -> T.int32:
+                T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
+                                                 "kind": "cuda", "max_num_threads": 1024}))
+                return a + b
+
+            @T.prim_func(private=True)
+            def add_host(a: T.int32, b: T.int32) -> T.int32:
+                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}))
+                return a + b
+
+            @T.prim_func
+            def main_kernel(A: T.handle, B: T.handle, C: T.handle, length: T.int32):
+                T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
+                                                 "kind": "cuda"}),
+                             "calling_conv": 2,  # kDeviceKernelLaunch for device kernels
+                             "tir.is_global_func": True})
+                # ... kernel implementation
+
+            @T.prim_func
+            def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle):
+                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}),
+                             "calling_conv": 1,  # kCPackedFunc for entry functions
+                             "tir.is_entry_func": True})
+                # ... main function implementation
+
+    The function will return:
+    - host_mod: Contains `add_host` and `main` functions (CPU targets)
+    - device_mod_dict: Contains a CUDA module with `add` and `main_kernel` functions
+
+    Notes
+    -----
+    - Functions are categorized based on string matching of their target attribute
+    - Functions with "cpu" in the target string are considered host functions
+    - Device functions are grouped by their target to create separate modules
+    - The function uses string-based target matching due to target hash limitations
+    - All functions must have a `calling_conv` attribute in their func_attr:
+      - Private helper functions (private=True): use `calling_conv: 0` (kDefault, by default)
+      - Public entry functions: use `calling_conv: 1` (kCPackedFunc)
+      - Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch)
     """
 
-    class CallConv(enum.IntEnum):
-        """Enum representing different calling conventions.
-        Corresponds to the C++ tvm::ir::CallingConv enum.
-        """
-
-        kDefault = 0
-        kCPackedFunc = 1
-        kDeviceKernelLaunch = 2
-
-    host_mod = tvm.tir.transform.Filter(
-        lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
-        != int(CallConv.kDeviceKernelLaunch)
-    )(mod)
-    device_mod = tvm.tir.transform.Filter(
-        lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
-        == int(CallConv.kDeviceKernelLaunch)
-    )(mod)
-    device_mod_dict = {}
+    host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod)
+    device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))(
+        mod
+    )
+    # TODO(syfeng): Here we use str as key since target hash is not correct
+    target_str2target = {}
+    device_func_dict = {}
+    device_mod_dict: Dict[Target, IRModule] = {}
     for gv, func in device_mod.functions.items():
-        device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func})
-    for target, funcs in device_mod_dict.items():
-        device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
+        target = func.attrs.get("target", None)
+        target_str = str(target) if target is not None else ""
+        target_str2target[target_str] = target  # This might be overridden by the last one
+        device_func_dict.setdefault(target_str, dict()).update({gv: func})
+    for target_str in target_str2target.keys():
+        target = target_str2target[target_str]
+        device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
     return host_mod, device_mod_dict
 
 
@@ -162,7 +212,7 @@ def build(
     # Step 3: Bind the target to the input module
     mod = tvm.tir.transform.BindTarget(target_to_bind)(mod)
 
-    # Step 4: Apply the tir pipeline
+    # Step 4: Apply the tir pipeline
     if pipeline is not None:
        # custom pipeline
        if isinstance(pipeline, str):
```
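The core of the new `split_host_device_mods` logic is the string-keyed grouping of device functions. The following is a standalone sketch of that grouping idea, using plain dicts and target strings as illustrative stand-ins for `IRModule` functions and `Target` objects (the helper name is hypothetical, not TVM API):

```python
# Standalone sketch of the string-keyed grouping used above: device functions are
# bucketed by str(target) because Target objects may not hash consistently.
def group_by_target_str(device_funcs):
    """device_funcs: plain dict of {function_name: target_string} (illustrative only)."""
    target_str2target = {}  # str(target) -> representative target
    device_func_dict = {}   # str(target) -> [function names]
    for name, target in device_funcs.items():
        target_str = str(target) if target is not None else ""
        target_str2target[target_str] = target  # last writer wins, as in the pass
        device_func_dict.setdefault(target_str, []).append(name)
    return {target_str2target[s]: funcs for s, funcs in device_func_dict.items()}

print(group_by_target_str({"add": "cuda -arch=sm_90", "main_kernel": "cuda -arch=sm_90"}))
# -> {'cuda -arch=sm_90': ['add', 'main_kernel']}: both land in the same device module.
```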

src/target/build_common.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -66,7 +66,9 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
       }
     }
     auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    fmap[static_cast<std::string>(global_symbol.value())] = info;
+    if (global_symbol) {
+      fmap[static_cast<std::string>(global_symbol.value())] = info;
+    }
   }
   return fmap;
 }
```

src/target/opt/build_cuda_on.cc

Lines changed: 6 additions & 3 deletions
```diff
@@ -134,9 +134,12 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
   for (auto [gvar, base_func] : mod->functions) {
     ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
     auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+    auto calling_conv =
+        prim_func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch ||
+           calling_conv == CallingConv::kDefault)
+        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or "
+           "CallingConv::kDefault";
     functions.Set(gvar, prim_func);
   }
 
```
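For reference, the calling-convention values checked here mirror the integers that appear as `calling_conv` in `func_attr` (and that the removed Python-side enum encoded). A small Python sketch of that mapping and of the broader check `BuildCUDA` now applies; the helper and enum names here are illustrative, not part of the commit:

```python
from enum import IntEnum

# Mirrors tvm::CallingConv on the C++ side (see also the removed Python enum above).
class CallingConv(IntEnum):
    kDefault = 0             # plain functions and __device__ helper functions
    kCPackedFunc = 1         # packed-ABI host entry functions
    kDeviceKernelLaunch = 2  # device kernels launched from the host (__global__)

def accepted_by_cuda_codegen(calling_conv: int) -> bool:
    # BuildCUDA now accepts both kernels and default-convention device helpers.
    return calling_conv in (CallingConv.kDeviceKernelLaunch, CallingConv.kDefault)

assert accepted_by_cuda_codegen(CallingConv.kDefault)
assert accepted_by_cuda_codegen(CallingConv.kDeviceKernelLaunch)
assert not accepted_by_cuda_codegen(CallingConv.kCPackedFunc)
```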

src/target/source/codegen_cuda.cc

Lines changed: 13 additions & 1 deletion
```diff
@@ -140,7 +140,19 @@ void CodeGenCUDA::Init(bool output_ssa) {
   ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
 }
 
-void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
+void CodeGenCUDA::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                         std::ostream& os) {
+  auto calling_conv =
+      func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
+  if (calling_conv == CallingConv::kDeviceKernelLaunch) {
+    os << "extern \"C\" __global__ ";
+  } else if (calling_conv == CallingConv::kDefault) {
+    os << "extern \"C\" __device__ ";
+  } else {
+    LOG(FATAL) << "Unsupported calling convention for cuda codegen: " << calling_conv;
+  }
+  CodeGenC::PrintFunctionSignature(function_name, func, os);
+}
 
 class ThreadIdxExtractor : public tir::StmtVisitor {
  private:
```
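With this change the CUDA signature printer chooses the prefix from the calling convention: `__global__` for launched kernels and `__device__` for default-convention helpers. A hedged inspection sketch (assuming the TVMScript `Module` from the commit message is in scope, a CUDA-enabled TVM build, and that `tvm.tir.build` from the updated `python/tvm/tir/build.py` is the build entry point):

```python
import tvm

# Build the example module for CUDA with an LLVM host, then look at the emitted
# CUDA source held by the imported device module (get_source() is assumed here).
lib = tvm.tir.build(Module, target=tvm.target.Target("cuda", host="llvm"))
cuda_src = lib.imported_modules[0].get_source()

# The kernel (kDeviceKernelLaunch) should start with 'extern "C" __global__',
# while the called helper (kDefault) should start with 'extern "C" __device__'.
print(cuda_src)
```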

src/target/source/codegen_cuda.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -46,7 +46,8 @@ class CodeGenCUDA final : public CodeGenC {
             enable_fp4_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
-  void PrintFuncPrefix(std::ostream& os) final;
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                              std::ostream& os) final;
   void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;  // NOLINT(*)
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
```
