Skip to content

Commit 5e6f1b7

Browse files
committed
[Fix] Update FlashInfer JIT header lookup
This PR fixes the tvm/dlpack/dmlc header lookup in the FlashInfer kernel JIT compilation. Prior to this fix, the JIT compilation assumes the environment variable `TVM_SOURCE_DIR` is always defined, which is not always true. This PR fixes the behavior and considers multiple cases, including TVM source builds and pip-installed packages.
1 parent 585d6d2 commit 5e6f1b7

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

python/tvm/libinfo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ def find_include_path(name=None, search_path=None, optional=False):
195195
include_path : list(string)
196196
List of all found paths to header files.
197197
"""
198-
if os.environ.get("TVM_HOME", None):
198+
if os.environ.get("TVM_SOURCE_DIR", None):
199+
source_dir = os.environ["TVM_SOURCE_DIR"]
200+
elif os.environ.get("TVM_HOME", None):
199201
source_dir = os.environ["TVM_HOME"]
200202
else:
201203
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -204,7 +206,7 @@ def find_include_path(name=None, search_path=None, optional=False):
204206
if os.path.isdir(os.path.join(source_dir, "include")):
205207
break
206208
else:
207-
raise AssertionError("Cannot find the source directory given ffi_dir: {ffi_dir}")
209+
raise AssertionError(f"Cannot find the source directory given ffi_dir: {ffi_dir}")
208210
third_party_dir = os.path.join(source_dir, "3rdparty")
209211

210212
header_path = []

python/tvm/relax/backend/cuda/flashinfer.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,53 @@ def get_object_file_path(src: Path) -> Path:
124124
# ------------------------------------------------------------------------
125125
# 2) Include paths
126126
# ------------------------------------------------------------------------
127-
tvm_home = os.environ["TVM_SOURCE_DIR"]
128127
include_paths = [
129128
FLASHINFER_INCLUDE_DIR,
130129
FLASHINFER_CSRC_DIR,
131130
FLASHINFER_TVM_BINDING_DIR,
132-
Path(tvm_home).resolve() / "include",
133-
Path(tvm_home).resolve() / "ffi" / "include",
134-
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
135-
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
136131
] + CUTLASS_INCLUDE_DIRS
137132

133+
if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME", None):
134+
# Respect TVM_SOURCE_DIR and TVM_HOME if they are set
135+
tvm_home = (
136+
os.environ["TVM_SOURCE_DIR"]
137+
if os.environ.get("TVM_SOURCE_DIR", None)
138+
else os.environ["TVM_HOME"]
139+
)
140+
include_paths += [
141+
Path(tvm_home).resolve() / "include",
142+
Path(tvm_home).resolve() / "ffi" / "include",
143+
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
144+
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
145+
]
146+
else:
147+
# If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path
148+
tvm_package_path = Path(tvm.__file__).resolve().parent
149+
if (tvm_package_path / "include").exists():
150+
# The package is installed from pip.
151+
import tvm_ffi
152+
153+
tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent
154+
include_paths += [
155+
tvm_package_path / "include",
156+
tvm_package_path / "3rdparty" / "dmlc-core" / "include",
157+
tvm_ffi_package_path / "include",
158+
]
159+
elif (tvm_package_path.parent.parent / "include").exists():
160+
# The package is installed from source.
161+
include_paths += [
162+
tvm_package_path.parent.parent / "include",
163+
tvm_package_path.parent.parent / "ffi" / "include",
164+
tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include",
165+
tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include",
166+
]
167+
else:
168+
# warning: TVM is not installed in the system.
169+
print(
170+
"Warning: Include path for TVM cannot be found. "
171+
"FlashInfer kernel compilation may fail due to missing headers."
172+
)
173+
138174
# ------------------------------------------------------------------------
139175
# 3) Function to compile a single source file
140176
# ------------------------------------------------------------------------

0 commit comments

Comments
 (0)