Conversation

@tlopex (Member) commented Aug 25, 2025

Overview

This PR implements native Python function support in TVM Relax through the @I.pyfunc decorator and BasePyModule, which enable seamless integration between TVM's compilation pipeline and Python/PyTorch runtime environments. With this enhancement, users can write Python functions directly in TVMScript that interoperate with Relax and TIR functions, providing enhanced debugging capabilities and the ability to leverage existing PyTorch operator libraries.

Key Features

TVMScript Parser Enhancement

  • @I.pyfunc decorator: Marks Python functions for integration into IRModules
  • Dual storage format: Stores both raw string representation (for TVMScript printing) and captured PackedFunc (for runtime execution)
  • ExternFunc representation: Each Python function is represented as an ExternFunc node with attributes storing source code and runtime wrapper
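The dual-storage idea can be sketched in plain Python. This is illustrative only: the actual parser hooks, attribute names, and runtime wrapper types inside TVM differ, and the `pyfunc` helper below is a hypothetical stand-in, not TVM's implementation.

```python
import inspect


def pyfunc(func):
    """Illustrative decorator: keep both the raw source text (for printing)
    and the callable itself (as a stand-in for the runtime PackedFunc wrapper)."""
    try:
        source = inspect.getsource(func)  # string form, used when printing the module
    except OSError:
        source = f"<source unavailable for {func.__name__}>"
    func.__pyfunc_source__ = source
    func.__pyfunc_callable__ = func  # runtime-dispatch form
    return func


@pyfunc
def scale(x, factor=2):
    return x * factor
```

Storing both forms means the printer never needs to decompile bytecode, while the executor never needs to re-parse source.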

Complete BasePyModule Implementation

  • DLPack-based tensor conversion: Seamless conversion between PyTorch tensors and TVM NDArrays
  • Cross-function interoperability: Python functions can call Relax/TIR functions and vice versa
  • JIT compilation: Delays compilation until module instantiation for flexible late-stage modifications
  • Dynamic function registration: Supports runtime addition of Python functions
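DLPack-based conversion is zero-copy: producer and consumer share one buffer rather than copying it. The same protocol that BasePyModule relies on between PyTorch tensors and TVM NDArrays can be demonstrated with NumPy's DLPack implementation, used here only because it needs no extra dependencies:

```python
import numpy as np

# Producer side: any array library exposing __dlpack__ can hand its buffer over.
src = np.arange(5, dtype="float32")

# Consumer side: from_dlpack wraps the same memory; no copy is made.
dst = np.from_dlpack(src)

src[0] = 42.0
assert dst[0] == 42.0  # both views observe the same buffer
```

In the PR, the analogous pair is `_convert_pytorch_to_tvm` / `_convert_tvm_to_pytorch`, so tensors cross the Python/TVM boundary without copying.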

Future Work

  • TVMScript printer for IRModules with Python functions: Print IRModules in proper format with high-level operator mapping from Relax ops to PyTorch ops, handling symbolic shapes
  • R.call_py_func primitive: Introduce a Relax primitive to invoke the PackedFunc corresponding to a specified Python function at runtime

Example: IRModule with pyfunc

import tvm
from tvm import relax, tir
from tvm.relax.base_py_module import BasePyModule
from tvm.script import ir as I, relax as R, tir as T
from tvm.runtime import Device
import torch


@I.ir_module
class IRModuleWithPyFunc(BasePyModule):
    """Example IRModule with Python function.
    The base class BasePyModule implements the logic of cross-function calls
    and JIT compilation in Python.
    We only allow Python functions in IRModules that subclass the BasePyModule.
    """

    @I.pyfunc
    def python_add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Python function that can be called from Relax functions."""
        # Convert inputs to TVM NDArrays via DLPack
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)
        
        # Call the compiled TIR function
        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], 
                             out_sinfo=R.Tensor((5,), "float32"))
        
        # Convert result back to original format
        return self._convert_tvm_to_pytorch(result)

    @T.prim_func
    def add_tir(
        var_x: T.handle,
        var_y: T.handle,
        var_out: T.handle,
    ):
        x = T.match_buffer(var_x, (5,), "float32")
        y = T.match_buffer(var_y, (5,), "float32")
        out = T.match_buffer(var_out, (5,), "float32")
        
        for i in range(5):
            out[i] = x[i] + y[i]

    @R.function
    def main_relax(x: R.Tensor((5,), "float32"), 
                   y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
        return R.add(x, y)


def main():
    """Main function showing IRModule with Python function support."""
    # Create the IRModuleWithPyFunc instance
    module = IRModuleWithPyFunc()
    
    # Execute DLPack conversion
    x_torch = torch.randn(5, dtype=torch.float32)
    y_torch = torch.randn(5, dtype=torch.float32)
    
    # Convert via DLPack
    x_tvm = module._convert_pytorch_to_tvm(x_torch)
    y_tvm = module._convert_pytorch_to_tvm(y_torch)
    
    # Convert back
    x_back = module._convert_tvm_to_pytorch(x_tvm)
    y_back = module._convert_tvm_to_pytorch(y_tvm)
    
    # Execute cross-function calls
    tir_result = module.call_tir("add_tir", [x_torch, y_torch], 
                                out_sinfo=R.Tensor((5,), "float32"))
    relax_result = module.main_relax(x_torch, y_torch)
    python_result = module.python_add(x_torch, y_torch)
    
    return module, (x_torch, y_torch, x_tvm, y_tvm, x_back, y_back), (tir_result, relax_result, python_result)


if __name__ == "__main__":
    main()



# Example usage with verification:
# result = main()
# assert result is not None, "Function should return results"
# module, dlpack_results, cross_call_results = result
# assert len(dlpack_results) == 6, "DLPack results should contain 6 elements"
# assert len(cross_call_results) == 3, "Cross-call results should contain 3 elements"

@tlopex (Member, Author) commented Aug 25, 2025

cc @MasterJH5574
Could you please have a look at CI?
Why does it say [2025-08-25T05:02:07.869Z] ImportError: Error importing plugin "tvm.testing.plugin": No module named 'torch'?

@MasterJH5574 (Contributor) commented:

> Could you please have a look at CI. Why it said [2025-08-25T05:02:07.869Z] ImportError: Error importing plugin "tvm.testing.plugin": No module named 'torch'

@tlopex Likely the docker image for wasm CI doesn't come with torch. Can we change to lazy import torch, i.e. import torch inside where it is used?
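The lazy-import pattern suggested above can be sketched as follows. The helper and the conversion function are illustrative, not TVM's actual code; the point is that `torch` is resolved only at call time, so merely importing the module succeeds in environments (such as the wasm CI image) where torch is absent.

```python
import importlib


def optional_import(module_name):
    """Import an optional dependency on first use, failing with a clear
    message only when the feature that needs it is actually exercised."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"{module_name!r} is required for this feature but is not installed"
        ) from err


def convert_pytorch_to_tvm(tensor):
    # torch is looked up here, not at module import time.
    torch = optional_import("torch")
    return torch.utils.dlpack.to_dlpack(tensor)
```

Modules without torch installed can still be collected by pytest; only tests that call the conversion will raise (or can be skipped).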

@tlopex tlopex requested a review from MasterJH5574 August 25, 2025 23:09
@MasterJH5574 (Contributor) left a comment
Overall looks very good, thank you @tlopex! It would be great to add an example of "IRModule with pyfunc" to the PR description, and you can also post the example in the forum discussion thread after this PR is merged.

@MasterJH5574 MasterJH5574 merged commit 2012d55 into apache:main Aug 27, 2025
17 checks passed
3 participants