Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self, functions=None, attrs=None, global_infos=None):
attrs,
global_infos,
)
self.pyfuncs = {}

def clone(self) -> "IRModule":
return _ffi_api.Module_Clone(self)
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@
# utils
from .utils import convert_to_expr

# BasePyModule
from .base_py_module import BasePyModule

# Import submodules in the last to avoid dependency
from . import exec_builder
from . import expr
Expand Down
385 changes: 385 additions & 0 deletions python/tvm/relax/base_py_module.py

Large diffs are not rendered by default.

69 changes: 69 additions & 0 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, Union

import tvm
from tvm.relax import ExternFunc
from ....ir.module import IRModule
from ...ir_builder import IRBuilder
from . import doc
Expand Down Expand Up @@ -86,12 +87,14 @@ def parse(
extra_vars = _default_globals()

ann = {}
all_pyfuncs = {}
if inspect.isfunction(program):
ann = {program.__name__: program.__annotations__}
elif inspect.isclass(program):
for name, func in program.__dict__.items():
if inspect.isfunction(func):
ann[name] = func.__annotations__
all_pyfuncs[name] = func

source = Source(program)
parser = Parser(source, ann)
Expand All @@ -101,6 +104,10 @@ def parse(
except ParserError as err:
parser.report_error(err.node, err.args[0])
ret = builder.get()
# Attach pyfuncs to the IRModule
if inspect.isclass(program) and isinstance(ret, IRModule):
_attach_pyfuncs_to_irmodule(ret, all_pyfuncs)

# check well-formedness in both Relax and TIR
if check_well_formed:
check_ret = ret
Expand All @@ -122,3 +129,65 @@ def parse(
err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}",
)
return ret


def _create_python_packed_func(pyfunc):
"""Create a PackedFunc wrapper for a Python function.

This function creates a PackedFunc that can be called from TVM runtime
and will execute the original Python function.

Parameters
----------
pyfunc : Callable
The Python function to wrap.

Returns
-------
PackedFunc
A PackedFunc that wraps the Python function.
"""

def packed_func_wrapper(*args, **kwargs):
"""Wrapper function that calls the original Python function."""
try:
result = pyfunc(*args, **kwargs)
return result
except Exception as error:
print(f"Error calling Python function {pyfunc.__name__}: {error}")
raise

return packed_func_wrapper


def _attach_pyfuncs_to_irmodule(irmodule, all_pyfuncs):
"""Attach Python functions to IRModule with reduced nesting."""
if not all_pyfuncs:
return

if not hasattr(irmodule, "pyfuncs"):
irmodule.pyfuncs = {}

for global_var, func in irmodule.functions_items():
if not isinstance(func, ExternFunc):
continue
if not func.attrs.get("is_pyfunc", False):
continue

pyfunc_name = global_var.name_hint
if pyfunc_name not in all_pyfuncs:
continue

pyfunc = all_pyfuncs[pyfunc_name]
irmodule.pyfuncs[pyfunc_name] = pyfunc

try:
source_code = inspect.getsource(pyfunc)
func = func.with_attr("python_source", source_code)
except (OSError, TypeError):
func = func.with_attr("python_source", f"# Source unavailable for {pyfunc_name}")

packed_func = _create_python_packed_func(pyfunc)
func = func.with_attr("python_packed_func", packed_func)

irmodule[global_var] = func
35 changes: 35 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ class Parser(doc.NodeVisitor):
function_annotations: Optional[Dict[str, Dict[str, Any]]]
var_table: VarTable
inside_function: bool # whether we are within a function
current_class: Optional[str] = None # current class being parsed
base_py_module_context: bool = False # whether current class inherits from BasePyModule

def __init__(
self,
Expand Down Expand Up @@ -414,6 +416,39 @@ def pop_token():

return _deferred(pop_token)

def set_class_context(self, class_name: str, is_base_py_module: bool = False):
"""Set the current class context for parsing.

Parameters
----------
class_name : str
The name of the current class being parsed.
is_base_py_module : bool
Whether the current class inherits from BasePyModule.
"""
self.current_class = class_name
self.base_py_module_context = is_base_py_module

def _get_current_class_context(self) -> Optional[str]:
"""Get the current class context.

Returns
-------
Optional[str]
The name of the current class, or None if not in a class context.
"""
return self.current_class

def _is_base_py_module_context(self) -> bool:
"""Check if the current class context allows Python functions.

Returns
-------
bool
True if Python functions are allowed in the current context.
"""
return self.base_py_module_context

def with_diag_source(self, source: Source):
"""Add a new source as with statement.

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/script/parser/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tvm.ir import Range
from ...ir_builder.ir import * # pylint: disable=redefined-builtin
from . import parser as _parser
from .entry import ir_module
from .entry import ir_module, pyfunc


__all__ = [
Expand All @@ -28,5 +28,6 @@
"dummy_global_info",
"Range",
"lookup_vdevice",
"pyfunc",
"vdevice",
]
94 changes: 91 additions & 3 deletions python/tvm/script/parser/ir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
"""The entry point of TVM parser for ir module."""

import inspect
from typing import Optional, Type
from typing import Callable, Optional, Type

from tvm.ir import IRModule
from tvm.ir import IRModule, GlobalVar
from tvm.relax.expr import ExternFunc
from tvm.relax.base_py_module import BasePyModule
from tvm import cpu, ir

from .._core import parse, utils

Expand Down Expand Up @@ -47,7 +50,86 @@ def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRM
def decorator_wrapper(mod):
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")

# Check BasePyModule inheritance
base_py_module_inherited = any(base.__name__ == "BasePyModule" for base in mod.__bases__)

m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)

if base_py_module_inherited:
# Collect pyfunc methods
pyfunc_methods = [
name
for name, attr in mod.__dict__.items()
if hasattr(attr, "dispatch_token") and attr.dispatch_token == "pyfunc"
]

mod._pyfunc_methods = pyfunc_methods

# Create ExternFunc nodes

for method_name in pyfunc_methods:
try:
existing_gvars = [
global_var
for global_var in m.get_global_vars()
if global_var.name_hint == method_name
]

extern_func = ExternFunc(method_name)
extern_func = extern_func.with_attr("is_pyfunc", True)
extern_func = extern_func.with_attr("function_type", "python")
extern_func = extern_func.with_attr("python_function_name", method_name)
extern_func = extern_func.with_attr(
"python_source", f"# Source for {method_name}"
)
extern_func = extern_func.with_attr("python_packed_func", None)

if existing_gvars:
m[existing_gvars[0]] = extern_func
else:
m[GlobalVar(method_name)] = extern_func

except Exception: # pylint: disable=broad-exception-caught
continue

class ModuleFactory:
"""Factory class for creating BasePyModule instances with Python functions."""

def __init__(self, module, pyfunc_methods, original_class):
self.ir_module = module
self.pyfunc_methods = pyfunc_methods
self.original_class = original_class

def __call__(self, device=None, target=None):

if device is None:
device = cpu(0)

instance_ir_mod = ir.IRModule()
for global_var, func in self.ir_module.functions_items():
instance_ir_mod[global_var] = func

instance = BasePyModule(instance_ir_mod, device, target)

for method_name in self.pyfunc_methods:
if hasattr(self.original_class, method_name):
method = getattr(self.original_class, method_name)
instance.add_python_function(method_name, method)

return instance

def __getattr__(self, name):
if hasattr(self.ir_module, name):
return getattr(self.ir_module, name)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)

factory = ModuleFactory(m, pyfunc_methods, mod)
setattr(factory, "__name__", mod.__name__)
return factory

setattr(m, "__name__", mod.__name__)
return m

Expand All @@ -61,4 +143,10 @@ def decorator_wrapper(mod):
return decorator_wrapper


setattr(ir_module, "dispatch_token", "ir")
def pyfunc(func: Callable):
# Set the dispatch_token on the decorated function
setattr(func, "dispatch_token", "pyfunc")
return func


setattr(pyfunc, "dispatch_token", "pyfunc")
Loading
Loading