|
24 | 24 | from pathlib import Path |
25 | 25 | from typing import List |
26 | 26 |
|
| 27 | +import tvm_ffi |
| 28 | + |
27 | 29 | import tvm |
28 | 30 | from tvm.target import Target |
29 | 31 |
|
@@ -124,17 +126,51 @@ def get_object_file_path(src: Path) -> Path: |
124 | 126 | # ------------------------------------------------------------------------ |
125 | 127 | # 2) Include paths |
126 | 128 | # ------------------------------------------------------------------------ |
127 | | - tvm_home = os.environ["TVM_SOURCE_DIR"] |
128 | 129 | include_paths = [ |
129 | 130 | FLASHINFER_INCLUDE_DIR, |
130 | 131 | FLASHINFER_CSRC_DIR, |
131 | 132 | FLASHINFER_TVM_BINDING_DIR, |
132 | | - Path(tvm_home).resolve() / "include", |
133 | | - Path(tvm_home).resolve() / "ffi" / "include", |
134 | | - Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include", |
135 | | - Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", |
136 | 133 | ] + CUTLASS_INCLUDE_DIRS |
137 | 134 |
|
| 135 | + if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME", None): |
| 136 | + # Respect TVM_SOURCE_DIR and TVM_HOME if they are set |
| 137 | + tvm_home = ( |
| 138 | + os.environ["TVM_SOURCE_DIR"] |
| 139 | + if os.environ.get("TVM_SOURCE_DIR", None) |
| 140 | + else os.environ["TVM_HOME"] |
| 141 | + ) |
| 142 | + include_paths += [ |
| 143 | + Path(tvm_home).resolve() / "include", |
| 144 | + Path(tvm_home).resolve() / "ffi" / "include", |
| 145 | + Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include", |
| 146 | + Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", |
| 147 | + ] |
| 148 | + else: |
| 149 | + # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path |
| 150 | + tvm_package_path = Path(tvm.__file__).resolve().parent |
| 151 | + if (tvm_package_path / "include").exists(): |
| 152 | + # The package is installed from pip. |
| 153 | + tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent |
| 154 | + include_paths += [ |
| 155 | + tvm_package_path / "include", |
| 156 | + tvm_package_path / "3rdparty" / "dmlc-core" / "include", |
| 157 | + tvm_ffi_package_path / "include", |
| 158 | + ] |
| 159 | + elif (tvm_package_path.parent.parent / "include").exists(): |
| 160 | + # The package is installed from source. |
| 161 | + include_paths += [ |
| 162 | + tvm_package_path.parent.parent / "include", |
| 163 | + tvm_package_path.parent.parent / "ffi" / "include", |
| 164 | + tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include", |
| 165 | + tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include", |
| 166 | + ] |
| 167 | + else: |
| 168 | + # warning: TVM is not installed in the system. |
| 169 | + print( |
| 170 | + "Warning: Include path for TVM cannot be found. " |
| 171 | + "FlashInfer kernel compilation may fail due to missing headers." |
| 172 | + ) |
| 173 | + |
138 | 174 | # ------------------------------------------------------------------------ |
139 | 175 | # 3) Function to compile a single source file |
140 | 176 | # ------------------------------------------------------------------------ |
|
0 commit comments