Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions python/tvm/relax/base_py_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,11 @@ def call_dps_packed(self, func_name: str, args, out_sinfo):
return out[0] if len(out) == 1 else out

def call_py_func(self, func_name: str, args):
"""Call a Python function stored in the IRModule's pyfuncs."""
if func_name not in self.ir_mod.pyfuncs:
raise ValueError(f"Python function '{func_name}' not found in IRModule pyfuncs")
py_func = self.ir_mod.pyfuncs[func_name]
converted_args = self._convert_tvm_to_pytorch(args)
return py_func(*converted_args)
"""Call a Python function stored in the module's pyfuncs."""
if func_name not in self.pyfuncs:
raise ValueError(f"Python function '{func_name}' not found in module pyfuncs")
py_func = self.pyfuncs[func_name]
return py_func(self, *args)

def _create_output_tensors(self, out_sinfo, in_args=None):
# pylint: disable=import-outside-toplevel
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
call_dps_packed,
call_inplace_packed,
call_pure_packed,
call_py_func,
call_tir,
call_tir_inplace,
call_tir_with_grad,
Expand Down
36 changes: 36 additions & 0 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,42 @@ def call_dps_packed(
return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore


@args_converter.auto
def call_py_func(
func_name: str,
args: Expr,
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
) -> Call:
"""
Call a Python function and return the output.

Parameters
----------
func_name : str
The name of the Python function to call. This should correspond to a function
in the IRModule's pyfuncs attribute.

args : Expr
The input arguments.

out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_py_func output.
It should be a single or a list of TensorStructInfo. Each one denotes the
structure info of a returned tensor.

Returns
-------
ret: Call
A call node for the call_py_func operator.
"""
args = _wrap_inline_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]

return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore


@args_converter.auto
def call_builtin_with_ctx(
func: Union[str, Expr],
Expand Down
52 changes: 52 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Expr,
ExternFunc,
ShapeExpr,
StringImm,
TupleGetItem,
Var,
VarBinding,
Expand Down Expand Up @@ -64,6 +65,7 @@
call_dps_packed,
call_inplace_packed,
call_pure_packed,
call_py_func as _call_py_func,
call_tir,
call_tir_inplace,
call_tir_with_grad,
Expand Down Expand Up @@ -451,6 +453,55 @@ def call_packed(
return Call(op, args, attrs=attrs, sinfo_args=sinfo_args)


@args_converter.auto
def call_py_func(
py_func_name: py_str,
*args: Expr,
out_sinfo: Union[StructInfo, List[StructInfo]],
) -> Call:
"""Create a relax Call, which calls a Python function.

Parameters
----------
py_func_name: str
The name of the Python function to call. This should correspond to a function
in the IRModule's pyfuncs attribute.
*args : Expr
The arguments.
out_sinfo: Union[StructInfo, List[StructInfo]]
The structure info of the call_py_func output.
It should be a single or a list of TensorStructInfo. Each one denotes the
structure info of a returned tensor.

Returns
-------
call: Call
The created Relax Call for call_py_func operator.
"""
if isinstance(out_sinfo, py_tuple): # type: ignore
out_sinfo = list(out_sinfo)
elif not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]

out_sinfo = [
(
sinfo()
if callable(sinfo)
else sinfo.asobject()
if isinstance(sinfo, ObjectConvertible)
else sinfo
)
for sinfo in out_sinfo
]

# Convert string to StringImm
try:
func_name_imm = StringImm(py_func_name) if hasattr(py_func_name, "strip") else py_func_name
except (TypeError, ValueError, AttributeError):
func_name_imm = StringImm(py_func_name)
return _call_py_func(func_name_imm, args, out_sinfo)


def _sinfo_arg_wrapper(func):
"""A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args"""

Expand Down Expand Up @@ -743,6 +794,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"call_tir_inplace",
"call_tir_with_grad",
"call_dps_packed",
"call_py_func",
"call_builtin_with_ctx",
"ceil",
"clip",
Expand Down
64 changes: 64 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,70 @@ TVM_FFI_STATIC_INIT_BLOCK({
refl::GlobalDef().def("relax.op.call_dps_packed", MakeCallDPSPacked);
});

// call_py_func

StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exact 1 output struct info.");
}
return call->sinfo_args[0];
}

void ValidateCallPyFunc(Call call) {
// Validate that the function name is a string literal
auto func_name = call->args[0];
CHECK(func_name->IsInstance<StringImmNode>())
<< "Operation " << call->op << " expects the first argument to be a string literal "
<< "specifying the Python function name. However, the first argument " << func_name
<< " is not a string literal.";

// Validate that args is a tuple
Expr arg_tuple = call->args[1];
CHECK(arg_tuple->struct_info_.as<TupleStructInfoNode>())
<< "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. "
<< "However, the second argument " << arg_tuple << " has struct info "
<< arg_tuple->struct_info_ << ".";

CHECK(arg_tuple.as<TupleNode>() || arg_tuple.as<VarNode>())
<< "Operation " << call->op << " must hold its arguments as an in-line tuple. "
<< "However, " << call << " has arguments " << arg_tuple
<< ", which is neither an in-line tuple, "
<< "nor a variable binding that may be normalized to an in-line tuple.";
}

TVM_REGISTER_OP("relax.call_py_func")
.set_num_inputs(2)
.add_argument("func_name", "StringImm", "The name of the Python function to call.")
.add_argument("args", "Tuple", "The input arguments.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallPyFunc)
.set_attr<FValidate>("FValidate", ValidateCallPyFunc)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array<TensorStructInfo> out_sinfo_list) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined ShapeExpr as shape. "
"However, one given structure info is "
<< sinfo;
}

StructInfo out_sinfo{nullptr};
if (out_sinfo_list.size() == 1) {
out_sinfo = out_sinfo_list[0];
} else {
out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
}

static const Op& op = Op::Get("relax.call_py_func");
return Call(op, {func_name, args}, {}, {out_sinfo});
}

TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc);
});

// call builtin
StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() == 0) {
Expand Down
Loading
Loading