13 changes: 7 additions & 6 deletions ffi/examples/inline_module/main.py
@@ -23,8 +23,8 @@
 def main():
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
-        cpp_source=r"""
-        void AddOne(DLTensor* x, DLTensor* y) {
+        cpp_sources=r"""
+        void add_one_cpu(DLTensor* x, DLTensor* y) {
           // implementation of a library function
           TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
           DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -36,16 +36,18 @@ def main():
             static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
           }
         }
+
+        void add_one_cuda(DLTensor* x, DLTensor* y);
         """,
-        cuda_source=r"""
+        cuda_sources=r"""
         __global__ void AddOneKernel(float* x, float* y, int n) {
           int idx = blockIdx.x * blockDim.x + threadIdx.x;
           if (idx < n) {
             y[idx] = x[idx] + 1;
           }
         }

-        void AddOneCUDA(DLTensor* x, DLTensor* y) {
+        void add_one_cuda(DLTensor* x, DLTensor* y) {
           // implementation of a library function
           TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
           DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -67,8 +69,7 @@ def main():
                                       static_cast<float*>(y->data), n);
         }
         """,
-        cpp_functions={"add_one_cpu": "AddOne"},
-        cuda_functions={"add_one_cuda": "AddOneCUDA"},
+        functions=["add_one_cpu", "add_one_cuda"],
     )

     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
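The example above uses the list form of the new `functions` parameter, so both exports get empty docstrings. The mapping form attaches a docstring to each export; a sketch of the same call site (docstring text here is illustrative, and per the TODO in load_inline.py below, the docstrings are accepted but not yet embedded into the exported functions):

    functions={
        "add_one_cpu": "Add one to each element of a 1D float32 tensor on CPU.",
        "add_one_cuda": "Add one to each element of a 1D float32 tensor on CUDA.",
    },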
136 changes: 76 additions & 60 deletions ffi/python/tvm_ffi/cpp/load_inline.py
@@ -34,8 +34,7 @@
 def _hash_sources(
     cpp_source: str,
     cuda_source: str,
-    cpp_functions: Mapping[str, str],
-    cuda_functions: Mapping[str, str],
+    functions: Sequence[str] | Mapping[str, str],
     extra_cflags: Sequence[str],
     extra_cuda_cflags: Sequence[str],
     extra_ldflags: Sequence[str],
@@ -45,12 +44,13 @@ def _hash_sources(
     m = hashlib.sha256()
     m.update(cpp_source.encode("utf-8"))
     m.update(cuda_source.encode("utf-8"))
-    for name, doc in sorted(cpp_functions.items()):
-        m.update(name.encode("utf-8"))
-        m.update(doc.encode("utf-8"))
-    for name, doc in sorted(cuda_functions.items()):
-        m.update(name.encode("utf-8"))
-        m.update(doc.encode("utf-8"))
+    if isinstance(functions, Mapping):
+        for name in sorted(functions):
+            m.update(name.encode("utf-8"))
+            m.update(functions[name].encode("utf-8"))
+    else:
+        for name in sorted(functions):
+            m.update(name.encode("utf-8"))
     for flag in extra_cflags:
         m.update(flag.encode("utf-8"))
     for flag in extra_cuda_cflags:
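Both branches hash the function names in sorted order, so the derived cache key does not depend on the order in which functions are supplied. A minimal standalone sketch of that property, mirroring the mapping branch above:

    import hashlib

    def digest(functions):
        m = hashlib.sha256()
        for name in sorted(functions):
            m.update(name.encode("utf-8"))
            m.update(functions[name].encode("utf-8"))
        return m.hexdigest()

    # Reordering the exported functions does not change the key, so no rebuild.
    assert digest({"b": "doc b", "a": "doc a"}) == digest({"a": "doc a", "b": "doc b"})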
@@ -242,8 +242,10 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
         source,
     ]

-    for exported_name, func_name_in_source in functions.items():
-        sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({exported_name}, {func_name_in_source});")
+    for func_name, func_doc in functions.items():
+        sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({func_name}, {func_name});")
+        _ = func_doc  # TODO: embed the function docstring into the exported tvm ffi function.
+
     sources.append("")

     return "\n".join(sources)
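With the new signature, each exported function is registered under its own source-level name. A rough sketch of the resulting decoration (the exact preamble depends on the headers the helper prepends):

    decorated = _decorate_with_tvm_ffi(
        "void add_one_cpu(DLTensor* x, DLTensor* y);", {"add_one_cpu": ""}
    )
    # The decorated source now ends with the registration macro:
    #   TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, add_one_cpu);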
@@ -252,26 +254,26 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
 def load_inline(
     name: str,
     *,
-    cpp_source: str | None = None,
-    cuda_source: str | None = None,
-    cpp_functions: Mapping[str, str] | None = None,
-    cuda_functions: Mapping[str, str] | None = None,
+    cpp_sources: Sequence[str] | str | None = None,
+    cuda_sources: Sequence[str] | str | None = None,
+    functions: Mapping[str, str] | Sequence[str] | str | None = None,
     extra_cflags: Sequence[str] | None = None,
     extra_cuda_cflags: Sequence[str] | None = None,
     extra_ldflags: Sequence[str] | None = None,
     extra_include_paths: Sequence[str] | None = None,
     build_directory: Optional[str] = None,
 ) -> Module:
     """Compile and load a C++/CUDA tvm ffi module from inline source code.

-    This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_source and cuda_source
-    are compiled to an object file, and then linked together into a shared library. It's possible to only provide
-    cpp_source or cuda_source.
+    This function compiles the given C++ and/or CUDA source code into a shared library. The cpp_sources and
+    cuda_sources are each compiled to an object file and then linked into a single shared library. You may
+    provide only cpp_sources or only cuda_sources.

-    The `cpp_functions` and `cuda_functions` parameters are used to specify which functions in the source code
-    should be exported to the tvm ffi module. The keys of the mapping are the names of the exported functions, and the
-    values are the names of the functions in the source code. The exported name and the function name in the source code
-    must be different. The exported name must be a valid C identifier while the function name in the source code can
-    contain namespace qualifiers.
+    The `functions` parameter specifies which functions in the source code are exported to the tvm ffi module.
+    It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the
+    exported functions and the values are their docstrings. When a sequence is given, it lists the names of the
+    functions to export, and their docstrings default to empty strings. A single string exports just that one
+    function.

     Extra compiler and linker flags can be provided via the `extra_cflags`, `extra_cuda_cflags`, and `extra_ldflags`
     parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
@@ -281,22 +283,24 @@ def load_inline(
     any header from tvm ffi and dlpack in your source code. You can also provide additional include paths via the
     `extra_include_paths` parameter and include custom headers in your source code.

-    The compiled shared library is cached in a cache directory to avoid recompilation. The cache directory can be
-    specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified, the default cache directory is
-    `~/.cache/tvm-ffi`.
+    The compiled shared library is cached to avoid recompilation. The `build_directory` parameter selects the
+    build directory; when it is not given, a per-module subdirectory of the tvm ffi cache directory is used.
+    The cache directory defaults to `~/.cache/tvm-ffi` and can be overridden via the `TVM_FFI_CACHE_DIR`
+    environment variable.

     Parameters
     ----------
     name: str
         The name of the tvm ffi module.
-    cpp_source: str, optional
-        The C++ source code.
-    cuda_source: str, optional
-        The CUDA source code.
-    cpp_functions: Mapping[str, str], optional
-        The mapping from the exported function name to the function name in the C++ source code.
-    cuda_functions: Mapping[str, str], optional
-        The mapping from the exported function name to the function name in the CUDA source code.
+    cpp_sources: Sequence[str] | str, optional
+        The C++ source code, given as a list of sources or a single source.
+    cuda_sources: Sequence[str] | str, optional
+        The CUDA source code, given as a list of sources or a single source.
+    functions: Mapping[str, str] | Sequence[str] | str, optional
+        The functions that will be exported to the tvm ffi module; each must be defined in cpp_sources, or
+        forward-declared there and defined in cuda_sources. When a mapping is given, the keys are the names of
+        the exported functions and the values are their docstrings. When a sequence is given, it lists the names
+        of the functions to export, and their docstrings default to empty strings. A single string exports just
+        that one function.
     extra_cflags: Sequence[str], optional
         The extra compiler flags for C++ compilation.
         The default flags are:
@@ -316,46 +320,58 @@ def load_inline(
         The extra include paths.
         The default include paths are:
         - The include path of tvm ffi
+    build_directory: str, optional
+        The build directory. If not specified, a per-module subdirectory of the tvm ffi cache directory is used.
+        The cache directory defaults to `~/.cache/tvm-ffi` and can be overridden via the `TVM_FFI_CACHE_DIR`
+        environment variable.

     Returns
     -------
     mod: Module
         The loaded tvm ffi module.
     """
-    if cpp_source is None:
-        cpp_source = ""
-    if cuda_source is None:
-        cuda_source = ""
-    if cpp_functions is None:
-        cpp_functions = {}
-    if cuda_functions is None:
-        cuda_functions = {}
+    if cpp_sources is None:
+        cpp_sources = []
+    elif isinstance(cpp_sources, str):
+        cpp_sources = [cpp_sources]
+    cpp_source = "\n".join(cpp_sources)
+    if cuda_sources is None:
+        cuda_sources = []
+    elif isinstance(cuda_sources, str):
+        cuda_sources = [cuda_sources]
+    cuda_source = "\n".join(cuda_sources)
+    with_cuda = len(cuda_sources) > 0

     extra_ldflags = extra_ldflags or []
     extra_cflags = extra_cflags or []
     extra_cuda_cflags = extra_cuda_cflags or []
     extra_include_paths = extra_include_paths or []

-    # whether we have cuda source in this module
-    with_cuda = len(cuda_source.strip()) > 0
-
     # add function registration code to sources
-    cpp_source = _decorate_with_tvm_ffi(cpp_source, cpp_functions)
-    cuda_source = _decorate_with_tvm_ffi(cuda_source, cuda_functions)
+    if functions is None:
+        functions = {}
+    elif isinstance(functions, str):
+        functions = {functions: ""}
+    elif isinstance(functions, Sequence):
+        functions = {name: "" for name in functions}
+    cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
+    cuda_source = _decorate_with_tvm_ffi(cuda_source, {})

     # determine the cache dir for the built module
-    cache_dir = os.path.join(
-        os.environ.get("TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi"))
-    )
-    source_hash: str = _hash_sources(
-        cpp_source,
-        cuda_source,
-        cpp_functions,
-        cuda_functions,
-        extra_cflags,
-        extra_cuda_cflags,
-        extra_ldflags,
-        extra_include_paths,
-    )
-    build_dir: str = os.path.join(cache_dir, "{}_{}".format(name, source_hash))
+    if build_directory is None:
+        build_directory = os.environ.get(
+            "TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi")
+        )
+        source_hash: str = _hash_sources(
+            cpp_source,
+            cuda_source,
+            functions,
+            extra_cflags,
+            extra_cuda_cflags,
+            extra_ldflags,
+            extra_include_paths,
+        )
+        build_dir: str = os.path.join(build_directory, "{}_{}".format(name, source_hash))
+    else:
+        build_dir = os.path.abspath(build_directory)
     os.makedirs(build_dir, exist_ok=True)

     # generate build.ninja
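Because of the normalization at the top of load_inline, both source parameters also accept a list of snippets, which are joined with newlines before compilation; a sketch with illustrative snippet variables:

    mod = tvm_ffi.cpp.load_inline(
        name="demo",
        cpp_sources=[common_decls, cpu_impl],  # list entries are joined with newlines
        cuda_sources=cuda_impl,                # a single string also works
        functions=["my_func"],
    )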
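One usage note on the new caching behavior: when build_directory is omitted, the module is built under a hash-keyed subdirectory of the cache directory, so any change to the sources, exported functions, or flags lands in a fresh directory; when build_directory is given, that exact directory is used without a hash suffix. A sketch (paths are illustrative):

    import os

    # Default: cached under ~/.cache/tvm-ffi (or $TVM_FFI_CACHE_DIR), keyed by source hash.
    os.environ["TVM_FFI_CACHE_DIR"] = "/tmp/tvm-ffi-cache"

    # Pinned: built directly in ./build, bypassing the hash-keyed cache layout.
    mod = tvm_ffi.cpp.load_inline(
        name="hello",
        cpp_sources=cpp_src,  # illustrative source string
        functions="add_one_cpu",
        build_directory="./build",
    )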