Conversation

@Hzfengsy (Member) commented Jun 12, 2025

This PR introduces support for device call compilation in TVM by enhancing the BindTarget pass to properly handle functions called from both host and device contexts. The key improvement is the ability to automatically create host-specific duplicates of functions that are called from both host and device code, ensuring proper target binding for heterogeneous compilation.

- **Function Classification**: Analyzes call patterns to identify functions called from host vs device contexts
- **Smart Target Binding**: Automatically binds appropriate targets based on calling context:
  - Functions called only from host → host target
  - Functions called only from device → device target
  - Functions called from both → device target + host duplicate
- **Call Site Updates**: Updates call sites in externally exposed functions to use appropriate duplicates

- Improved device function extraction and kernel generation
- Better handling of error propagation for different device types
- Enhanced buffer declaration and parameter management

- Support for `__device__` function calls in CUDA kernels
- Proper function signature generation for device functions
- Enhanced calling convention handling

- Updated build pipeline to handle device call compilation
- Improved target-specific compilation logic

The following example demonstrates how the BindTarget pass handles functions called from both host and device contexts:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
```

After applying `BindTarget(cuda, host="llvm")`, the pass automatically:

1. Creates a device version of `add` with the CUDA target
2. Creates a host duplicate `add_host` with the LLVM target
3. Updates the `main` function to call `add_host` from the host context and `add` from the device context
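
As a rough illustration (not taken from this PR's tests), the pass can be driven from Python along these lines; the exact target spelling and the use of `show()` are illustrative:

```python
import tvm

# Bind a CUDA device target with an LLVM host, mirroring BindTarget(cuda, host="llvm")
target = tvm.target.Target("cuda", host="llvm")
after = tvm.tir.transform.BindTarget(target)(Module)

# Inspect the result: `add` should carry the CUDA target, a host duplicate
# (named `add_host` in this PR) should carry the LLVM host target, and the
# host-side call in `main` should be rewritten to the duplicate.
after.show()
```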

This enables seamless compilation of mixed host/device code while maintaining proper target-specific optimizations and code generation.

- **Automatic Target Binding**: No manual target annotation required for most use cases
- **Heterogeneous Compilation**: Proper support for functions called from multiple contexts
- **Code Reuse**: Shared functions can be called from both host and device without duplication
- **Performance**: Maintains target-specific optimizations for each context
- **Developer Experience**: Simplifies writing mixed host/device code

The implementation is backward compatible and integrates seamlessly with existing TVM compilation pipelines.

@LeiWang1999 (Contributor)

Super helpful enhancement, thanks!

@Hzfengsy force-pushed the device_call branch 3 times, most recently from 4bfdbf7 to 5a8fbd4 on June 13, 2025 04:35
@tqchen requested a review from spectrometerHBH on June 13, 2025 18:15
```python
        return a + b

    @T.prim_func
    def main(
```
Contributor

I think it is currently quite implicit which function is the kernel function and which is the device function. It might be clearer if we could mark device functions explicitly with `@T.prim_func(kind="device")`.

Moreover, we could add a test case where the functions are not wrapped in a Module, and instead of compiling the Module we compile the kernel function directly.

Member

This is an interesting point; I added a follow-up comment on #18055 (comment) about the distinction before/after SplitHostDevice. It is probably fine before SplitHostDevice (as in this case), but it would be good to clarify in a comment.

```python
device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target  # This might be overridden by the last one
```
Contributor

We want to make sure in which cases different Target objects might have the same string representation `target_str`.

Member

For now, str uniquely identifies a target, so it is OK here, but it would be good to document this invariant, and as @Hzfengsy commented, we can fix it after target hash is supported.
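
To make the invariant concrete, here is a small illustration (not code from this PR) of the assumption that two different Target configurations never share the same string form:

```python
import tvm

t1 = tvm.target.Target("cuda")
t2 = tvm.target.Target("cuda -max_num_threads=256")

# The grouping keyed by target_str relies on this property; once Target
# hash/equality is available, the dict could be keyed by the Target itself.
assert str(t1) != str(t2)
```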

@tqchen (Member) commented Jun 14, 2025

After reading the comments so far on the host/device function split and the compiler phases:

- S0: In the beginning (before SplitHostDevice), we don't distinguish host/device functions; a function can contain kernels
- S1: The host/device function split becomes clear after the SplitHostDevice pass. Currently, in the case of a single device launch:
  - global kernels are annotated with the DeviceKernelLaunch calling convention
  - host functions are annotated with other calling conventions

After we enable the compiler to handle device functions, the first thing we need to ensure is what the behavior is after S1. It would be useful to clarify this in the PR with comments.

Summarizing the logic so far:

- At S0 (before SplitHostDevice), the decision is to not distinguish between host and device functions; the distinction stays implicit
- The distinction should become clear after S1, by checking the target annotation of each function, which marks the default convention

Here is an example of such a case:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        # temp var bound on the host side
        temp_var = T.float32()
        with T.LetStmt(
            Module.add(T.float32(1), T.float32(2)),
            var=temp_var,
        ):
            for bx in T.thread_binding(1024, "blockIdx.x"):
                for tx in T.thread_binding(1024, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) + temp_var
```

Because of the implicitness, we may need to cross-check the current behavior of SplitHostDevice for rare cases where, say, both host and device call the same function. In such cases we may either:
- S0a: place a constraint and report an error
- S0b: have the SplitHostDevice pass duplicate such a function and mark the target

In both cases, it would be good to enhance the SplitHostDevice test cases to ensure the target field is clear after S1.
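
A rough sketch of what such a test case could check, assuming the usual pass ordering in current TVM (BindTarget, AnnotateDeviceRegions, SplitHostDevice); the pass names here are existing upstream ones, not something this PR adds:

```python
import tvm


def assert_targets_after_split(mod: tvm.IRModule, target: tvm.target.Target):
    """After S1, every function should carry an explicit target annotation."""
    mod = tvm.tir.transform.BindTarget(target)(mod)
    mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
    mod = tvm.tir.transform.SplitHostDevice()(mod)
    for gvar, func in mod.functions.items():
        assert func.attrs is not None and func.attrs.get("target", None) is not None, (
            f"{gvar.name_hint} has no target after SplitHostDevice"
        )
```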

@Hzfengsy (Member, Author)

Thanks @tqchen and @Kathryn-cat for the valuable comments; I will refactor the PR to enhance SplitHostDevice systematically.

@Hzfengsy marked this pull request as draft on June 17, 2025 02:26
@Hzfengsy marked this pull request as ready for review on July 9, 2025 03:02
@Hzfengsy (Member, Author) commented Jul 9, 2025

@tqchen @Kathryn-cat I've updated the PR to detect whether a function is called from the host side or the device side at an early stage (in the BindTarget pass). Here is an example of a mixed call to the same function:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
```

Please review again when you have time

@Hzfengsy force-pushed the device_call branch 2 times, most recently from 65a1be6 to db18f3e on July 9, 2025 07:11
```cpp
    new_mod->Update(gvar,
                    WithAttr(std::move(prim_func), tvm::attr::kTarget, target_without_host));
  } else {
    // Rule 4.4: Not called by any context
```
@Kathryn-cat (Contributor) commented Jul 10, 2025

It would be good to document the behavior of rule 4.4 ("the function is not called from any host or device context"). Currently it defaults to target_without_host; it would be helpful to document the desired behavior and the likely issues that might need further checking.

Member (Author)

Thanks @Kathryn-cat for pointing it out. Actually, I'm not sure why the current behaviour works like this, and there is a test case for it. To keep the project working well, I chose to keep the current behaviour. However, IMO, if a function is not called by any global function it is effectively dead code, so we could directly remove it from the IRModule.
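
For illustration, a hypothetical sketch of that cleanup (not part of this PR), which would keep only functions that are called from somewhere in the module or are externally exposed via global_symbol:

```python
import tvm
from tvm import tir


def prune_uncalled_private_funcs(mod: tvm.IRModule) -> tvm.IRModule:
    called = set()

    def record_call(node):
        # Calls to other PrimFuncs in the module use a GlobalVar as the call op
        if isinstance(node, tir.Call) and isinstance(node.op, tvm.ir.GlobalVar):
            called.add(node.op)

    for _, func in mod.functions.items():
        tir.stmt_functor.post_order_visit(func.body, record_call)

    kept = {
        gvar: func
        for gvar, func in mod.functions.items()
        if gvar in called
        or (func.attrs is not None and func.attrs.get("global_symbol", None) is not None)
    }
    return tvm.IRModule(kept, attrs=mod.attrs)
```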

Contributor

Sounds good, I think we can keep the current behavior for now.

@Kathryn-cat (Contributor)

@Hzfengsy LGTM! Just added a small comment and I think we're good to go.

@Hzfengsy merged commit acdc164 into apache:main on Jul 11, 2025
15 checks passed
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025
[TIR][Target] Support device call compilation
