Merged. Changes from all commits.
ffi/CMakeLists.txt: 14 additions, 0 deletions

@@ -218,6 +218,20 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
   target_compile_features(tvm_ffi_cython PRIVATE cxx_std_17)
   target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_header)
   target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_shared)
+  # Set RPATH for tvm_ffi_cython to find tvm_ffi_shared.so relatively
+  if(APPLE)
+    # macOS uses @loader_path
+    set_target_properties(tvm_ffi_cython PROPERTIES
+      INSTALL_RPATH "@loader_path/lib"
+      BUILD_WITH_INSTALL_RPATH ON
+    )
+  elseif(LINUX)
+    # Linux uses $ORIGIN
+    set_target_properties(tvm_ffi_cython PROPERTIES
+      INSTALL_RPATH "\$ORIGIN/lib"
+      BUILD_WITH_INSTALL_RPATH ON
+    )
+  endif()
   install(TARGETS tvm_ffi_cython DESTINATION .)

 ########## Installing the source ##########
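The RPATH entries let the installed Cython extension resolve tvm_ffi_shared from the lib/ directory shipped next to it, without LD_LIBRARY_PATH or DYLD_LIBRARY_PATH tweaks. A minimal sketch to sanity-check that layout after installing a wheel; it assumes the package installs as tvm_ffi with the extension module and lib/ side by side, and the file names are illustrative:

    # Sketch: verify the layout that "@loader_path/lib" / "$ORIGIN/lib" rely on.
    import pathlib

    import tvm_ffi

    pkg_dir = pathlib.Path(tvm_ffi.__file__).parent
    extensions = sorted(p.name for p in pkg_dir.glob("*.so"))  # tvm_ffi_cython
    shared = sorted(p.name for p in (pkg_dir / "lib").glob("*tvm_ffi_shared*"))
    print("extension modules:", extensions)
    print("shared libraries:", shared)  # non-empty => relative RPATH can resolve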
ffi/pyproject.toml: 21 additions, 12 deletions

@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0a0"
+version = "0.1.0a2"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]

@@ -32,6 +32,7 @@ classifiers = [
 ]
 keywords = ["machine learning", "inference"]
 requires-python = ">=3.9"
+
 dependencies = []
 
 
@@ -40,8 +41,9 @@ Homepage = "https://github.com/apache/tvm/ffi"
 GitHub = "https://github.com/apache/tvm/ffi"
 
 [project.optional-dependencies]
-torch = ["torch"]
-test = ["pytest"]
+# setup tools is needed by torch jit for best perf
+torch = ["torch", "setuptools"]
+test = ["pytest", "numpy", "torch"]
 
 [project.scripts]
 tvm-ffi-config = "tvm_ffi.config:__main__"

@@ -122,20 +124,27 @@ skip_gitignore = true
 
 [tool.cibuildwheel]
 build-verbosity = 1
-# skip pp and low python version
-# sdist should be sufficient
 
+# only build up to cp312, cp312
+# will be abi3 and can be used in future versions
+build = [
+    "cp39-*",
+    "cp310-*",
+    "cp311-*",
+    "cp312-*",
+]
 skip = [
-    "cp36-*",
-    "cp37-*",
-    "cp38-*",
     "*musllinux*"
 ]
+# we only need to test on cp312
 test-skip = [
+    "cp39-*",
+    "cp310-*",
+    "cp311-*",
     "pp*",
     "*musllinux*",
-] # pypy doesn't play nice with pybind11
+]
 # focus on testing abi3 wheel
 build-frontend = "build[uv]"
-test-command = "pytest {project}/tests -m "
+test-command = "pytest {package}/tests/python -vvs"
 test-extras = ["test"]
 
 [tool.cibuildwheel.linux]
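Context for the cibuildwheel change: wheels are now built only for cp39 through cp312, the cp312 wheel is expected to carry the abi3 tag so later CPython releases can reuse it, and test-skip trims the test matrix down to cp312. A sketch of how such a tag reads, using the packaging library; the wheel filename below is hypothetical:

    # Parse a hypothetical abi3 wheel name to show the interpreter/ABI pair.
    from packaging.utils import parse_wheel_filename

    name, version, build, tags = parse_wheel_filename(
        "apache_tvm_ffi-0.1.0a2-cp312-abi3-manylinux_2_17_x86_64.whl"
    )
    for tag in sorted(tags, key=str):
        print(tag.interpreter, tag.abi, tag.platform)
    # cp312 abi3 manylinux_2_17_x86_64 -> one artifact for CPython >= 3.12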
ffi/python/tvm_ffi/cython/function.pxi: 2 additions, 2 deletions

@@ -27,8 +27,6 @@ except ImportError:
 def load_torch_get_current_cuda_stream():
     """Create a faster get_current_cuda_stream for torch through cpp extension.
     """
-    from torch.utils import cpp_extension
-
     source = """
     #include <c10/cuda/CUDAStream.h>
 

@@ -44,6 +42,7 @@ def load_torch_get_current_cuda_stream():
         """Fallback with python api"""
         return torch.cuda.current_stream(device_id).cuda_stream
     try:
+        from torch.utils import cpp_extension
         result = cpp_extension.load_inline(
             name="get_current_cuda_stream",
             cpp_sources=[source],

@@ -56,6 +55,7 @@ def load_torch_get_current_cuda_stream():
     except Exception:
         return fallback_get_current_cuda_stream
 
+
 if torch is not None:
     # when torch is available, jit compile the get_current_cuda_stream function
     # the torch caches the extension so second loading is faster
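Moving the cpp_extension import inside the try block means a torch install that lacks setuptools (which cpp_extension needs at runtime, hence the new optional dependency in pyproject.toml) degrades to the pure-Python stream query instead of failing when the module loads. A plain-Python sketch of the same pattern; the load_inline arguments and the C++ body are assumptions, since the diff elides part of the real call:

    # Sketch: defer the cpp_extension import so any failure falls back cleanly.
    import torch

    def load_torch_get_current_cuda_stream():
        source = """
        #include <c10/cuda/CUDAStream.h>
        int64_t get_current_cuda_stream(int64_t device_id) {
          return reinterpret_cast<int64_t>(
              c10::cuda::getCurrentCUDAStream(device_id).stream());
        }
        """

        def fallback_get_current_cuda_stream(device_id):
            """Fallback with python api"""
            return torch.cuda.current_stream(device_id).cuda_stream

        try:
            from torch.utils import cpp_extension  # needs setuptools at runtime
            result = cpp_extension.load_inline(
                name="get_current_cuda_stream",
                cpp_sources=[source],
                functions=["get_current_cuda_stream"],  # assumed binding list
            )
            return result.get_current_cuda_stream
        except Exception:
            return fallback_get_current_cuda_stream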
ffi/python/tvm_ffi/dtype.py: 23 additions, 17 deletions

@@ -17,7 +17,6 @@
 """dtype class."""
 # pylint: disable=invalid-name
 from enum import IntEnum
-import numpy as np
 
 from . import core
 

@@ -58,22 +57,7 @@ class dtype(str):
 
     __slots__ = ["__tvm_ffi_dtype__"]
 
-    NUMPY_DTYPE_TO_STR = {
-        np.dtype(np.bool_): "bool",
-        np.dtype(np.int8): "int8",
-        np.dtype(np.int16): "int16",
-        np.dtype(np.int32): "int32",
-        np.dtype(np.int64): "int64",
-        np.dtype(np.uint8): "uint8",
-        np.dtype(np.uint16): "uint16",
-        np.dtype(np.uint32): "uint32",
-        np.dtype(np.uint64): "uint64",
-        np.dtype(np.float16): "float16",
-        np.dtype(np.float32): "float32",
-        np.dtype(np.float64): "float64",
-    }
-    if hasattr(np, "float_"):
-        NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64"
+    NUMPY_DTYPE_TO_STR = {}
 
     def __new__(cls, content):
         content = str(content)

@@ -122,6 +106,28 @@ def lanes(self):
         return self.__tvm_ffi_dtype__.lanes
 
 
+try:
+    # this helps to make numpy as optional
+    # although almost in all cases we want numpy
+    import numpy as np
+
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32"
+    dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
+    if hasattr(np, "float_"):
+        dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64"
+except ImportError:
+    pass
+
 try:
     import ml_dtypes
 
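Populating NUMPY_DTYPE_TO_STR after the class body, inside a try/except ImportError, turns numpy into a soft dependency: importing tvm_ffi.dtype works without numpy installed, and the table only fills when numpy imports successfully. A short usage sketch under that assumption:

    # dtype works with plain strings; the numpy mapping is opt-in.
    from tvm_ffi.dtype import dtype

    d = dtype("float32")
    print(d, d.lanes)  # lanes accessor shown in the diff context

    try:
        import numpy as np
        print(dtype.NUMPY_DTYPE_TO_STR[np.dtype("float32")])  # -> "float32"
    except ImportError:
        print("numpy not installed; NUMPY_DTYPE_TO_STR stays empty")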
ffi/scripts/run_tests.sh: 3 additions, 4 deletions

@@ -19,8 +19,7 @@ set -euxo pipefail
 
 BUILD_TYPE=Release
 
-rm -rf build/CMakeFiles build/CMakeCache.txt
-cmake -G Ninja -S . -B build -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_FLAGS="-O3"
-cmake --build build --parallel 16 --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests
+cmake -G Ninja -S . -B build -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
+    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
+cmake --build build --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests
 GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir build --output-on-failure
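The dropped lines were redundant: CMAKE_BUILD_TYPE=Release already enables -O3 on GCC and Clang, --clean-first already rebuilds the targets from scratch, and keeping CMakeCache.txt between runs lets reconfiguration stay incremental. Dropping --parallel 16 defers to Ninja's default parallelism.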
ffi/tests/python/test_error.py: 11 additions, 11 deletions

@@ -51,10 +51,6 @@ def test_error_from_cxx():
         tvm_ffi.convert(lambda x: x)()
 
 
-@pytest.mark.xfail(
-    "32bit" in platform.architecture() or platform.system() == "Windows",
-    reason="May fail if debug symbols are missing",
-)
 def test_error_from_nested_pyfunc():
     fapply = tvm_ffi.convert(lambda f, *args: f(*args))
     cxx_test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error")

@@ -78,13 +74,17 @@ def raise_error():
     traceback = e.__tvm_ffi_error__.traceback
     assert e.__tvm_ffi_error__.same_as(record_object[0])
     assert traceback.count("TestRaiseError") == 1
-    assert traceback.count("TestApply") == 1
-    assert traceback.count("<lambda>") == 1
-    pos_cxx_raise = traceback.find("TestRaiseError")
-    pos_cxx_apply = traceback.find("TestApply")
-    pos_lambda = traceback.find("<lambda>")
-    assert pos_cxx_raise > pos_lambda
-    assert pos_lambda > pos_cxx_apply
+    # The following lines may fail if debug symbols are missing
+    try:
+        assert traceback.count("TestApply") == 1
+        assert traceback.count("<lambda>") == 1
+        pos_cxx_raise = traceback.find("TestRaiseError")
+        pos_cxx_apply = traceback.find("TestApply")
+        pos_lambda = traceback.find("<lambda>")
+        assert pos_cxx_raise > pos_lambda
+        assert pos_lambda > pos_cxx_apply
+    except Exception as e:
+        pytest.xfail("May fail if debug symbols are missing")
 
 
 def test_error_traceback_update():
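The platform-guessing xfail decorator is replaced by a runtime check: the test hard-asserts the frame that is always symbolized, then downgrades only the debug-symbol-dependent assertions to an expected failure. Distilled into a standalone helper:

    # Runtime-xfail pattern: hard-assert what always holds, xfail the rest.
    import pytest

    def check_traceback_symbols(traceback: str) -> None:
        assert traceback.count("TestRaiseError") == 1  # always present
        try:
            # These frame names require C++ debug symbols.
            assert traceback.count("TestApply") == 1
            assert traceback.find("TestRaiseError") > traceback.find("<lambda>")
        except AssertionError:
            pytest.xfail("May fail if debug symbols are missing")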