Skip to content

Commit 1808a94

Browse files
authored
[Fix] Update FlashInfer JIT header lookup (#18244)
This PR fixes the tvm/dlpack/dmlc header lookup in the FlashInfer kernel JIT compilation. Prior to this fix, the JIT compilation assumes the environment variable `TVM_SOURCE_DIR` is always defined, which is not always true. This PR fixes the behavior and considers multiple cases, including TVM source builds and pip-installed packages.
1 parent 2012d55 commit 1808a94

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

python/tvm/libinfo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ def find_include_path(name=None, search_path=None, optional=False):
195195
include_path : list(string)
196196
List of all found paths to header files.
197197
"""
198-
if os.environ.get("TVM_HOME", None):
198+
if os.environ.get("TVM_SOURCE_DIR", None):
199+
source_dir = os.environ["TVM_SOURCE_DIR"]
200+
elif os.environ.get("TVM_HOME", None):
199201
source_dir = os.environ["TVM_HOME"]
200202
else:
201203
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -204,7 +206,7 @@ def find_include_path(name=None, search_path=None, optional=False):
204206
if os.path.isdir(os.path.join(source_dir, "include")):
205207
break
206208
else:
207-
raise AssertionError("Cannot find the source directory given ffi_dir: {ffi_dir}")
209+
raise AssertionError(f"Cannot find the source directory given ffi_dir: {ffi_dir}")
208210
third_party_dir = os.path.join(source_dir, "3rdparty")
209211

210212
header_path = []

python/tvm/relax/backend/cuda/flashinfer.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from pathlib import Path
2525
from typing import List
2626

27+
import tvm_ffi
28+
2729
import tvm
2830
from tvm.target import Target
2931

@@ -124,17 +126,51 @@ def get_object_file_path(src: Path) -> Path:
124126
# ------------------------------------------------------------------------
125127
# 2) Include paths
126128
# ------------------------------------------------------------------------
127-
tvm_home = os.environ["TVM_SOURCE_DIR"]
128129
include_paths = [
129130
FLASHINFER_INCLUDE_DIR,
130131
FLASHINFER_CSRC_DIR,
131132
FLASHINFER_TVM_BINDING_DIR,
132-
Path(tvm_home).resolve() / "include",
133-
Path(tvm_home).resolve() / "ffi" / "include",
134-
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
135-
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
136133
] + CUTLASS_INCLUDE_DIRS
137134

135+
if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME", None):
136+
# Respect TVM_SOURCE_DIR and TVM_HOME if they are set
137+
tvm_home = (
138+
os.environ["TVM_SOURCE_DIR"]
139+
if os.environ.get("TVM_SOURCE_DIR", None)
140+
else os.environ["TVM_HOME"]
141+
)
142+
include_paths += [
143+
Path(tvm_home).resolve() / "include",
144+
Path(tvm_home).resolve() / "ffi" / "include",
145+
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
146+
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
147+
]
148+
else:
149+
# If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path
150+
tvm_package_path = Path(tvm.__file__).resolve().parent
151+
if (tvm_package_path / "include").exists():
152+
# The package is installed from pip.
153+
tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent
154+
include_paths += [
155+
tvm_package_path / "include",
156+
tvm_package_path / "3rdparty" / "dmlc-core" / "include",
157+
tvm_ffi_package_path / "include",
158+
]
159+
elif (tvm_package_path.parent.parent / "include").exists():
160+
# The package is installed from source.
161+
include_paths += [
162+
tvm_package_path.parent.parent / "include",
163+
tvm_package_path.parent.parent / "ffi" / "include",
164+
tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include",
165+
tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include",
166+
]
167+
else:
168+
# warning: TVM is not installed in the system.
169+
print(
170+
"Warning: Include path for TVM cannot be found. "
171+
"FlashInfer kernel compilation may fail due to missing headers."
172+
)
173+
138174
# ------------------------------------------------------------------------
139175
# 3) Function to compile a single source file
140176
# ------------------------------------------------------------------------

0 commit comments

Comments
 (0)