Skip to content

Conversation

@Hzfengsy
Copy link
Member

This commit implements T.thread_return() functionality that allows threads to exit early from CUDA kernels. The feature is useful for cases where threads need to conditionally return based on thread indices or other conditions.

Key changes:

  • Add thread_return builtin in TIR
  • Implement CUDA codegen for thread_return
  • Add Python bindings for T.thread_return()
  • Update TIR IR builder to support thread_return
  • Add tests demonstrating thread_return usage

Example usage:

@T.prim_func
def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
    for i in T.thread_binding(16, thread="blockIdx.x"):
        for j in T.thread_binding(32, thread="threadIdx.x"):
            if j >= 16:
                T.thread_return()  # Early exit for threads with j >= 16
            B[i, j] = A[i, j]

and generate code is:

extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) {
  if (16 <= ((int)threadIdx.x)) {
    return;
  }
  B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))];
}

This commit implements T.thread_return() functionality that allows threads
to exit early from CUDA kernels. The feature is useful for cases where
threads need to conditionally return based on thread indices or other
conditions.

Key changes:
- Add thread_return builtin in TIR
- Implement CUDA codegen for thread_return
- Add Python bindings for T.thread_return()
- Update TIR IR builder to support thread_return
- Add tests demonstrating thread_return usage

Example usage:
```python
@T.prim_func
def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
    for i in T.thread_binding(16, thread="blockIdx.x"):
        for j in T.thread_binding(32, thread="threadIdx.x"):
            if j >= 16:
                T.thread_return()  # Early exit for threads with j >= 16
            B[i, j] = A[i, j]
```

and generate code is:

```cuda
extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) {
  if (16 <= ((int)threadIdx.x)) {
    return;
  }
  B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))];
}
```
@Hzfengsy
Copy link
Member Author

cc @LeiWang1999

@tqchen tqchen merged commit ea4369c into apache:main Jul 14, 2025
13 checks passed
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025
…pache#18134)

This commit implements T.thread_return() functionality that allows threads
to exit early from CUDA kernels. The feature is useful for cases where
threads need to conditionally return based on thread indices or other
conditions.

Key changes:
- Add thread_return builtin in TIR
- Implement CUDA codegen for thread_return
- Add Python bindings for T.thread_return()
- Update TIR IR builder to support thread_return
- Add tests demonstrating thread_return usage

Example usage:
```python
@T.prim_func
def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
    for i in T.thread_binding(16, thread="blockIdx.x"):
        for j in T.thread_binding(32, thread="threadIdx.x"):
            if j >= 16:
                T.thread_return()  # Early exit for threads with j >= 16
            B[i, j] = A[i, j]
```

and generate code is:

```cuda
extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) {
  if (16 <= ((int)threadIdx.x)) {
    return;
  }
  B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))];
}
```
@Hzfengsy Hzfengsy deleted the thread_return branch September 16, 2025 14:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants