|
17 | 17 |
|
18 | 18 | # pylint: disable=invalid-name |
19 | 19 | """The build utils in python.""" |
20 | | -from typing import Union, Optional, Dict |
21 | | -import enum |
| 20 | +from typing import Union, Optional, Dict, Tuple |
22 | 21 |
|
23 | 22 | import tvm |
24 | 23 | from tvm import ir |
|
28 | 27 | from tvm.target import Target |
29 | 28 |
|
30 | 29 |
|
31 | | -def split_host_device_mods(mod): |
| 30 | +def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]: |
32 | 31 | """Split an IRModule into host and device modules. |
33 | 32 |
|
| 33 | + This function takes an IRModule containing functions with different target attributes |
| 34 | + and separates them into host (CPU) and device (GPU/accelerator) modules. Functions |
| 35 | + are categorized based on their target attribute in func_attr. |
| 36 | +
|
34 | 37 | Parameters |
35 | 38 | ---------- |
36 | 39 | mod : tvm.IRModule |
37 | | - The input module to split |
| 40 | + The input module to split. |
| 41 | + The module should contain functions with target attributes in their func_attr. |
| 42 | + Functions with "cpu" in their target string are considered host functions, |
| 43 | + while others are considered device functions. |
38 | 44 |
|
39 | 45 | Returns |
40 | 46 | ------- |
41 | 47 | host_mod : tvm.IRModule |
42 | | - The module containing host functions |
| 48 | + The module containing host functions (CPU-targeted functions) |
43 | 49 | device_mod_dict : Dict[Target, tvm.IRModule] |
44 | | - A dict mapping targets to device modules |
| 50 | + A dict mapping targets to device modules. Each device module contains |
| 51 | + functions targeting the same device (e.g., CUDA GPU, OpenCL, etc.) |
| 52 | +
|
| 53 | + Examples |
| 54 | + -------- |
| 55 | + Given an IRModule with the following functions: |
| 56 | +
|
| 57 | + .. code-block:: python |
| 58 | +
|
| 59 | + @I.ir_module |
| 60 | + class Module: |
| 61 | + @T.prim_func(private=True) |
| 62 | + def add(a: T.int32, b: T.int32) -> T.int32: |
| 63 | + T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], |
| 64 | + "kind": "cuda", "max_num_threads": 1024})}) |
| 65 | + return a + b |
| 66 | +
|
| 67 | + @T.prim_func(private=True) |
| 68 | + def add_host(a: T.int32, b: T.int32) -> T.int32: |
| 69 | + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"})}) |
| 70 | + return a + b |
| 71 | +
|
| 72 | + @T.prim_func |
| 73 | + def main_kernel(A: T.handle, B: T.handle, C: T.handle, length: T.int32): |
| 74 | + T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], |
| 75 | + "kind": "cuda"}), |
| 76 | + "calling_conv": 2, # kDeviceKernelLaunch for device kernels |
| 77 | + "tir.is_global_func": True}) |
| 78 | + # ... kernel implementation |
| 79 | +
|
| 80 | + @T.prim_func |
| 81 | + def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle): |
| 82 | + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}), |
| 83 | + "calling_conv": 1, # kCPackedFunc for entry functions |
| 84 | + "tir.is_entry_func": True}) |
| 85 | + # ... main function implementation |
| 86 | +
|
| 87 | + The function will return: |
| 88 | + - host_mod: Contains `add_host` and `main` functions (CPU targets) |
| 89 | + - device_mod_dict: Contains a CUDA module with `add` and `main_kernel` functions |
| 90 | +
|
| 91 | + Notes |
| 92 | + ----- |
| 93 | + - Functions are categorized based on string matching of their target attribute |
| 94 | + - Functions with "cpu" in the target string are considered host functions |
| 95 | + - Device functions are grouped by their target to create separate modules |
| 96 | + - The function uses string-based target matching due to target hash limitations |
| 97 | + - All functions must have a `calling_conv` attribute in their func_attr: |
| 98 | + - Private helper functions (private=True): use `calling_conv: 0` (kDefault, the default) |
| 99 | + - Public entry functions: use `calling_conv: 1` (kCPackedFunc) |
| 100 | + - Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch) |
45 | 101 | """ |
46 | 102 |
|
47 | | - class CallConv(enum.IntEnum): |
48 | | - """Enum representing different calling conventions. |
49 | | - Corresponds to the C++ tvm::ir::CallingConv enum. |
50 | | - """ |
51 | | - |
52 | | - kDefault = 0 |
53 | | - kCPackedFunc = 1 |
54 | | - kDeviceKernelLaunch = 2 |
55 | | - |
56 | | - host_mod = tvm.tir.transform.Filter( |
57 | | - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) |
58 | | - != int(CallConv.kDeviceKernelLaunch) |
59 | | - )(mod) |
60 | | - device_mod = tvm.tir.transform.Filter( |
61 | | - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) |
62 | | - == int(CallConv.kDeviceKernelLaunch) |
63 | | - )(mod) |
64 | | - device_mod_dict = {} |
| 103 | + host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod) |
| 104 | + device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))( |
| 105 | + mod |
| 106 | + ) |
| 107 | + # TODO(syfeng): Here we use str as key since target hash is not correct |
| 108 | + target_str2target = {} |
| 109 | + device_func_dict = {} |
| 110 | + device_mod_dict: Dict[Target, IRModule] = {} |
65 | 111 | for gv, func in device_mod.functions.items(): |
66 | | - device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func}) |
67 | | - for target, funcs in device_mod_dict.items(): |
68 | | - device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs) |
| 112 | + target = func.attrs.get("target", None) |
| 113 | + target_str = str(target) if target is not None else "" |
| 114 | + target_str2target[target_str] = target # This might be overridden by the last one |
| 115 | + device_func_dict.setdefault(target_str, dict()).update({gv: func}) |
| 116 | + for target_str in target_str2target.keys(): |
| 117 | + target = target_str2target[target_str] |
| 118 | + device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs) |
69 | 119 | return host_mod, device_mod_dict |
70 | 120 |
|
71 | 121 |
|
@@ -162,7 +212,7 @@ def build( |
162 | 212 | # Step 3: Bind the target to the input module |
163 | 213 | mod = tvm.tir.transform.BindTarget(target_to_bind)(mod) |
164 | 214 |
|
165 | | - # Step 4: Apply the tir pipeline |
| 215 | + # Step 4: Apply the tir pipeline |
166 | 216 | if pipeline is not None: |
167 | 217 | # custom pipeline |
168 | 218 | if isinstance(pipeline, str): |
|
0 commit comments