4 changes: 4 additions & 0 deletions include/tvm/tir/builtin.h
@@ -45,6 +45,10 @@ namespace builtin {
  * \brief Return value.
  */
 TVM_DLL const Op& ret();
+/*!
+ * \brief Return from a GPU thread.
+ */
+TVM_DLL const Op& thread_return();
 /*!
  * \brief Reinterpret the value using the target type.
  */
8 changes: 8 additions & 0 deletions include/tvm/tir/op.h
@@ -91,6 +91,14 @@ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
  */
 TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
 
+/*!
+ * \brief Return from a thread.
+ *
+ * \param span The location of this operation in the source.
+ * \return The return expression.
+ */
+TVM_DLL PrimExpr thread_return(Span span = Span());
+
 /*!
  * Query the maximum possible value of dtype.
  * \param dtype The data type.
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
@@ -1927,6 +1927,7 @@ def wrapped(*args, **kwargs):
 sqrt = _op_wrapper(_tir_op.sqrt)
 tan = _op_wrapper(_tir_op.tan)
 tanh = _op_wrapper(_tir_op.tanh)
+thread_return = _op_wrapper(_tir_op.thread_return)
 trunc = _op_wrapper(_tir_op.trunc)
 truncdiv = _op_wrapper(_tir_op.truncdiv)
 truncmod = _op_wrapper(_tir_op.truncmod)
@@ -2205,6 +2206,7 @@ def wrapped(*args, **kwargs):
     "sqrt",
     "tan",
     "tanh",
+    "thread_return",
     "trunc",
     "truncdiv",
     "truncmod",
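
With the wrapper and the __all__ entry in place, TVMScript can call the op directly as T.thread_return(). A minimal sketch of what the parser should accept (hypothetical snippet; the CUDA test at the end of this diff exercises the same pattern end to end):

    from tvm.script import tir as T

    @T.prim_func
    def f(A: T.Buffer((8,), "float32")):
        for tx in T.thread_binding(16, "threadIdx.x"):
            if tx >= 8:
                T.thread_return()  # parsed into an Evaluate of the builtin call
            A[tx] = 0.0
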
17 changes: 17 additions & 0 deletions python/tvm/tir/op.py
@@ -1882,6 +1882,23 @@ def ret(val, span=None):
     return _ffi_api.ret(val, span)
 
 
+def thread_return(span=None):
+    """Return from a GPU thread.
+
+    Parameters
+    ----------
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    ret : PrimExpr
+        The return expression
+    """
+
+    return _ffi_api.thread_return(span)
+
+
 def any(*args, span=None):
     """Create a new expression of the union of all conditions in the arguments
 
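
Together with the tir.thread_return FFI registration in src/tir/op/op.cc below, this exposes a Python helper that builds a void tir.Call of the new builtin. A quick sketch of direct use (assumes a TVM build that includes this patch):

    import tvm
    from tvm.tir import op as tir_op

    expr = tir_op.thread_return()
    # The helper returns a void tir.Call; it carries no value (unlike tir.ret),
    # so it can only appear as a standalone statement in a PrimFunc body.
    print(type(expr))  # <class 'tvm.tir.expr.Call'>
    print(expr.op)     # Op(tir.thread_return)
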
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -1334,6 +1334,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes;
     }
     EndScope(ssa_scope);
+  } else if (op->op.same_as(builtin::thread_return())) {
+    os << "return";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
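
The codegen prints only "return"; the statement-level visitor that evaluates a void call is expected to append the trailing semicolon, so the intrinsic surfaces in the kernel as a bare return;. For the guarded kernel in the test at the bottom of this diff, the emitted CUDA should contain a block shaped roughly like the following (hand-written illustration, not actual compiler output):

    extern "C" __global__ void main_kernel(float* A, float* B) {
      if (((int)blockIdx.x) >= 16 || ((int)threadIdx.x) >= 16) {
        return;  // emitted for builtin::thread_return()
      }
      B[((int)blockIdx.x) * 16 + ((int)threadIdx.x)] =
          A[((int)blockIdx.x) * 16 + ((int)threadIdx.x)];
    }
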
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
@@ -48,6 +48,10 @@ TIR_DEFINE_BUILTIN_FUNC(ret)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
     .set_num_inputs(1);
 
+TIR_DEFINE_BUILTIN_FUNC(thread_return)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
+    .set_num_inputs(0);
+
 TIR_DEFINE_BUILTIN_FUNC(likely)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
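
kControlJump is the same effect kind that ret is registered with: it marks the call as altering control flow, so TIR passes must not hoist, deduplicate, or dead-code-eliminate it as if it were a pure expression. The registration can be inspected from Python (sketch; op name and attribute key as used in this patch):

    import tvm

    op = tvm.ir.Op.get("tir.thread_return")
    print(op.num_inputs)                   # 0 -- the builtin takes no operands
    print(op.get_attr("TCallEffectKind"))  # CallEffectKind::kControlJump
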
6 changes: 6 additions & 0 deletions src/tir/op/op.cc
@@ -247,6 +247,12 @@ PrimExpr ret(PrimExpr value, Span span) {
 
 TVM_FFI_REGISTER_GLOBAL("tir.ret").set_body_typed(ret);
 
+PrimExpr thread_return(Span span) {
+  return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span);
+}
+
+TVM_FFI_REGISTER_GLOBAL("tir.thread_return").set_body_typed(thread_return);
+
 // maximum and min limits
 PrimExpr max_value(const DataType& dtype, Span span) {
   using namespace tir;
17 changes: 17 additions & 0 deletions tests/python/codegen/test_target_codegen_cuda.py
@@ -839,5 +839,22 @@ def main(
     tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np)
 
 
+@tvm.testing.requires_cuda
+def test_thread_return():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
+            for bx in T.thread_binding(32, "blockIdx.x"):
+                for tx in T.thread_binding(32, "threadIdx.x"):
+                    if bx >= 16 or tx >= 16:
+                        T.thread_return()
+                    B[bx, tx] = A[bx, tx]
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imported_modules[0].get_source()
+    assert "return;" in cuda_code
+
+
 if __name__ == "__main__":
     tvm.testing.main()