@@ -124,17 +124,53 @@ def get_object_file_path(src: Path) -> Path:
124124 # ------------------------------------------------------------------------
125125 # 2) Include paths
126126 # ------------------------------------------------------------------------
127- tvm_home = os .environ ["TVM_SOURCE_DIR" ]
128127 include_paths = [
129128 FLASHINFER_INCLUDE_DIR ,
130129 FLASHINFER_CSRC_DIR ,
131130 FLASHINFER_TVM_BINDING_DIR ,
132- Path (tvm_home ).resolve () / "include" ,
133- Path (tvm_home ).resolve () / "ffi" / "include" ,
134- Path (tvm_home ).resolve () / "ffi" / "3rdparty" / "dlpack" / "include" ,
135- Path (tvm_home ).resolve () / "3rdparty" / "dmlc-core" / "include" ,
136131 ] + CUTLASS_INCLUDE_DIRS
137132
133+ if os .environ .get ("TVM_SOURCE_DIR" , None ) or os .environ .get ("TVM_HOME" , None ):
134+ # Respect TVM_SOURCE_DIR and TVM_HOME if they are set
135+ tvm_home = (
136+ os .environ ["TVM_SOURCE_DIR" ]
137+ if os .environ .get ("TVM_SOURCE_DIR" , None )
138+ else os .environ ["TVM_HOME" ]
139+ )
140+ include_paths += [
141+ Path (tvm_home ).resolve () / "include" ,
142+ Path (tvm_home ).resolve () / "ffi" / "include" ,
143+ Path (tvm_home ).resolve () / "ffi" / "3rdparty" / "dlpack" / "include" ,
144+ Path (tvm_home ).resolve () / "3rdparty" / "dmlc-core" / "include" ,
145+ ]
146+ else :
147+ # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path
148+ tvm_package_path = Path (tvm .__file__ ).resolve ().parent
149+ if (tvm_package_path / "include" ).exists ():
150+ # The package is installed from pip.
151+ import tvm_ffi
152+
153+ tvm_ffi_package_path = Path (tvm_ffi .__file__ ).resolve ().parent
154+ include_paths += [
155+ tvm_package_path / "include" ,
156+ tvm_package_path / "3rdparty" / "dmlc-core" / "include" ,
157+ tvm_ffi_package_path / "include" ,
158+ ]
159+ elif (tvm_package_path .parent .parent / "include" ).exists ():
160+ # The package is installed from source.
161+ include_paths += [
162+ tvm_package_path .parent .parent / "include" ,
163+ tvm_package_path .parent .parent / "ffi" / "include" ,
164+ tvm_package_path .parent .parent / "ffi" / "3rdparty" / "dlpack" / "include" ,
165+ tvm_package_path .parent .parent / "3rdparty" / "dmlc-core" / "include" ,
166+ ]
167+ else :
168+ # warning: TVM is not installed in the system.
169+ print (
170+ "Warning: Include path for TVM cannot be found. "
171+ "FlashInfer kernel compilation may fail due to missing headers."
172+ )
173+
138174 # ------------------------------------------------------------------------
139175 # 3) Function to compile a single source file
140176 # ------------------------------------------------------------------------
0 commit comments