[Target] Support CUDA device function calls #18055
Conversation
Super helpful enhancement, thanks!
```python
        return a + b

    @T.prim_func
    def main(
```
I think right now it's quite implicit which function is the kernel function and which is the device function. It might be clearer if we could explicitly mark device functions with @T.prim_func(kind="device"), as sketched below.
Moreover, we could add a test case where the functions are not wrapped in a Module and, instead of compiling the Module, we compile the kernel function directly.
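A minimal sketch of the proposed marking; note that `kind="device"` and `kind="kernel"` are hypothetical arguments taken from this suggestion, not existing `T.prim_func` parameters, so the snippet is illustrative rather than runnable today:

```python
from tvm.script import ir as I, tir as T

@I.ir_module
class Module:
    # Hypothetical: `kind="device"` is proposed syntax, not an existing argument.
    @T.prim_func(kind="device", private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    # Hypothetical: the kernel entry could likewise be marked explicitly.
    @T.prim_func(kind="kernel")
    def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
        for tx in T.thread_binding(1024, "threadIdx.x"):
            B[tx] = Module.add(A[tx], T.float32(1))
```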
This is an interesting point; I added a follow-up comment on #18055 (comment) about the distinction before/after SplitHostDevice. It may be fine before SplitHostDevice (as in this case), but it would be good to clarify that in a comment.
```python
device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target  # This might be overridden by the last one
```
We want to make sure in which cases different Target objects might end up with the same string representation target_str.
For now str uniquely identifies a target, so it is OK here, but it would be good to document that invariant; as @Hzfengsy commented, we can fix this once target hashing is supported.
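A small, hedged illustration of the invariant being relied on (the target configuration strings below are assumptions chosen for the example):

```python
import tvm

# Mirrors target_str2target in the quoted code: the table is keyed by str(target),
# so if two distinct Target objects ever shared the same string form, the later
# insertion would silently override the earlier one.
t1 = tvm.target.Target("cuda -arch=sm_70")
t2 = tvm.target.Target("cuda -arch=sm_80")
table = {str(t): t for t in (t1, t2)}
assert len(table) == 2  # distinct configurations still stringify differently today
```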
After reading the comments so far on the host/device function split and the compiler phases:

Once we enable the compiler to handle device functions, one thing we first need to pin down is the behavior after S1; it would be useful to clarify this in the PR with comments. Summarizing the logic so far, here is an example of such a case:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        # temp var bound on the host side
        temp_var = T.float32()
        with T.LetStmt(
            Module.add(T.float32(1), T.float32(2)),
            var=temp_var,
        ):
            for bx in T.thread_binding(1024, "blockIdx.x"):
                for tx in T.thread_binding(1024, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) + temp_var
```

Because of the implicitness, we may need to cross-check the current behavior of SplitHostDevice for rare cases where, say, both host and device call the same function; in such cases we could go either way. In both cases, it would be good to enhance the SplitHostDevice test cases to ensure the target field is clear after S1.
Thanks @tqchen and @Kathryn-cat for the valuable comments; I will refactor the PR to enhance SplitHostDevice systematically.
@tqchen @Kathryn-cat I've updated the PR to detect whether a function is called from the host side or the device side at an early stage (in the BindTarget pass). Here is an example with mixed calls to the same function:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
```

Please review again when you have time.
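For reference, a hedged sketch of how this could be checked by running the pass directly on the module above (assuming `tvm.tir.transform.BindTarget` is usable as a standalone pass):

```python
import tvm

# Bind a CUDA device target with an LLVM host, then run BindTarget in isolation.
target = tvm.target.Target("cuda", host="llvm")
mod_after = tvm.tir.transform.BindTarget(target)(Module)

# Expected per this PR: `add` is bound to the cuda target for the device-side call,
# and a host-bound duplicate is introduced for the call that computes `length`.
mod_after.show()
```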
This PR introduces support for device call compilation in TVM by enhancing the BindTarget pass to properly handle functions called from both host and device contexts. The key improvement is the ability to automatically create host-specific duplicates of functions that are called from both host and device code, ensuring proper target binding for heterogeneous compilation.
- **Function Classification**: Analyzes call patterns to identify functions called from host vs device contexts
- **Smart Target Binding**: Automatically binds appropriate targets based on calling context:
- Functions called only from host → host target
- Functions called only from device → device target
- Functions called from both → device target + host duplicate
- **Call Site Updates**: Updates call sites in externally exposed functions to use appropriate duplicates
- Improved device function extraction and kernel generation
- Better handling of error propagation for different device types
- Enhanced buffer declaration and parameter management
- Support for `__device__` function calls in CUDA kernels
- Proper function signature generation for device functions
- Enhanced calling convention handling
- Updated build pipeline to handle device call compilation
- Improved target-specific compilation logic
The following example demonstrates how the BindTarget pass handles functions called from both host and device contexts:
```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
```
After applying `BindTarget(cuda, host="llvm")`, the pass automatically:
1. Creates a device version of `add` with CUDA target
2. Creates a host duplicate `add_host` with LLVM target
3. Updates the main function to call `add_host` from host context and `add` from device context
This enables seamless compilation of mixed host/device code while maintaining proper target-specific optimizations and code generation.
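For illustration, a hedged sketch of roughly how the module could look after the pass; the attribute spelling and the duplicate's name follow the description above rather than the verbatim printer output:

```python
from tvm.script import ir as I, tir as T

@I.ir_module
class ModuleAfterBindTarget:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        T.func_attr({"target": T.target("cuda")})  # device version
        return a + b

    @T.prim_func(private=True)
    def add_host(a: T.int32, b: T.int32) -> T.int32:
        T.func_attr({"target": T.target("llvm")})  # host duplicate
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        # The host part of the cuda target is omitted here for brevity.
        T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
        length: T.int32 = ModuleAfterBindTarget.add_host(64, 64)  # host call rewritten
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = ModuleAfterBindTarget.add(A[bx, tx], B[bx, tx])  # device call
```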
- **Automatic Target Binding**: No manual target annotation required for most use cases
- **Heterogeneous Compilation**: Proper support for functions called from multiple contexts
- **Code Reuse**: Shared functions can be called from both host and device without duplication
- **Performance**: Maintains target-specific optimizations for each context
- **Developer Experience**: Simplifies writing mixed host/device code
The implementation is backward compatible and integrates seamlessly with existing TVM compilation pipelines.
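As a hedged end-to-end usage sketch: the entry-point name is an assumption (`tvm.compile` on recent TVM; older versions use `tvm.build` for TIR modules), and a CUDA device is assumed to be available:

```python
import numpy as np
import tvm

# Compile the example module for CUDA with an LLVM host.
target = tvm.target.Target("cuda", host="llvm")
lib = tvm.compile(Module, target=target)  # assumption: tvm.compile (or tvm.build on older versions)

dev = tvm.cuda(0)
a = tvm.nd.array(np.random.randint(0, 64, (128, 128)).astype("int32"), dev)
b = tvm.nd.array(np.random.randint(0, 64, (128, 128)).astype("int32"), dev)
c = tvm.nd.empty((128, 128), "int32", dev)
lib["main"](a, b, c)
np.testing.assert_array_equal(c.numpy(), a.numpy() + b.numpy())
```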
```cpp
      new_mod->Update(gvar,
                      WithAttr(std::move(prim_func), tvm::attr::kTarget, target_without_host));
    } else {
      // Rule 4.4: Not called by any context
```
It would be good to document the behavior of rule 4.4 ("the function is not called from any host or device context"): currently it defaults to target_without_host. It would be helpful to document the desired behavior and the likely issues that might need further checking.
Thanks @Kathryn-cat for pointing it out. Actually, I'm not sure why the current behaviour works this way, and there is an existing test case for it, so to keep the project working I chose to keep the current behaviour. However, IMO, if a function is not called by any global function it is effectively dead code, so we could simply remove it from the IRModule.
Sounds good, I think we can keep the current behavior for now.
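A hedged sketch of the behavior under discussion, assuming `tvm.tir.transform.BindTarget` can be run standalone; the module and function names are made up for illustration:

```python
import tvm
from tvm.script import ir as I, tir as T

@I.ir_module
class Mod:
    @T.prim_func(private=True)
    def orphan(a: T.int32) -> T.int32:  # hypothetical: called from neither host nor device
        return a + 1

target = tvm.target.Target("cuda", host="llvm")
after = tvm.tir.transform.BindTarget(target)(Mod)
# Per rule 4.4 as kept in this PR, `orphan` defaults to the target without its
# host part (cuda here) rather than being removed as dead code.
```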
@Hzfengsy LGTM! Just added a small comment and I think we're good to go.
[TIR][Target] Support device call compilation