Skip to content

Commit 022299b

Browse files
authored
[Arith] MLIR PresburgerSet compile fix mlir >= 160 (#15638)
Hi folks, Some fixes for MLIR based analyzer module introduced by #14690 . --- * Make CMake at par with LLVM info: ``` {...} -- Use llvm-config=llvm-config-64 -- LLVM libdir: /usr/lib64 -- Found MLIR -- Build with MLIR -- Set TVM_MLIR_VERSION=160 -- Found LLVM_INCLUDE_DIRS=/usr/include {...} -- USE_MKL : OFF -- USE_MLIR : ON -- USE_MSVC_MT : OFF {...} ``` * Fix several compilation errors: ``` error: cannot convert 'llvm::SmallVector<long int>' to 'llvm::ArrayRef<mlir::presburger::MPInt>' error: no matching function for call to 'tvm::IntImm::IntImm(tvm::runtime::DataType, mlir::presburger::MPInt&)' note: no known conversion for argument 2 from 'mlir::presburger::MPInt' to 'int64_t' {aka 'long int'} ``` Tested using: ```llvm/mlir 16.0.6```, ```llvm/mlir 15.0.7```, ```llvm/mlir 17.0.0rc3```
1 parent 79f9e57 commit 022299b

File tree

5 files changed

+58
-9
lines changed

5 files changed

+58
-9
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ tvm_option(USE_HEXAGON_EXTERNAL_LIBS "Path to git repo containing external Hexag
5454
tvm_option(USE_RPC "Build with RPC" ON)
5555
tvm_option(USE_THREADS "Build with thread support" ON)
5656
tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
57+
tvm_option(USE_MLIR "Build with MLIR support" OFF)
5758
tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
5859
tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON)
5960
tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF)

cmake/modules/LibInfo.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ function(add_lib_info src_file)
2424
else()
2525
string(STRIP ${TVM_INFO_LLVM_VERSION} TVM_INFO_LLVM_VERSION)
2626
endif()
27+
if (NOT DEFINED TVM_INFO_MLIR_VERSION)
28+
set(TVM_INFO_MLIR_VERSION "NOT-FOUND")
29+
else()
30+
string(STRIP ${TVM_INFO_MLIR_VERSION} TVM_INFO_MLIR_VERSION)
31+
endif()
2732
if (NOT DEFINED CUDA_VERSION)
2833
set(TVM_INFO_CUDA_VERSION "NOT-FOUND")
2934
else()
@@ -47,6 +52,7 @@ function(add_lib_info src_file)
4752
TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
4853
TVM_INFO_INSTALL_DEV="${INSTALL_DEV}"
4954
TVM_INFO_LLVM_VERSION="${TVM_INFO_LLVM_VERSION}"
55+
TVM_INFO_MLIR_VERSION="${TVM_INFO_MLIR_VERSION}"
5056
TVM_INFO_PICOJSON_PATH="${PICOJSON_PATH}"
5157
TVM_INFO_RANG_PATH="${RANG_PATH}"
5258
TVM_INFO_ROCM_PATH="${ROCM_PATH}"
@@ -86,6 +92,7 @@ function(add_lib_info src_file)
8692
TVM_INFO_USE_LIBBACKTRACE="${USE_LIBBACKTRACE}"
8793
TVM_INFO_USE_LIBTORCH="${USE_LIBTORCH}"
8894
TVM_INFO_USE_LLVM="${USE_LLVM}"
95+
TVM_INFO_USE_MLIR="${USE_MLIR}"
8996
TVM_INFO_USE_METAL="${USE_METAL}"
9097
TVM_INFO_USE_MICRO_STANDALONE_RUNTIME="${USE_MICRO_STANDALONE_RUNTIME}"
9198
TVM_INFO_USE_MICRO="${USE_MICRO}"

cmake/utils/FindLLVM.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ macro(find_llvm use_llvm)
150150
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
151151
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
152152
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
153+
message(STATUS "Build with MLIR")
154+
message(STATUS "Set TVM_MLIR_VERSION=" ${TVM_MLIR_VERSION})
153155
endif()
154156
endif()
155157
endif()

src/arith/presburger_set.cc

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const {
126126
for (const IntegerRelation& disjunct : disjuncts) {
127127
PrimExpr union_entry = Bool(1);
128128
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
129-
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
129+
PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
130130
if (disjunct.getNumCols() > 1) {
131131
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
132+
#if TVM_MLIR_VERSION >= 160
133+
auto coeff = int64_t(disjunct.atEq(i, j));
134+
#else
132135
auto coeff = disjunct.atEq(i, j);
136+
#endif
133137
if (coeff >= 0 || is_zero(linear_eq)) {
134-
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
138+
linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
135139
} else {
136-
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
140+
linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j];
137141
}
138142
}
139143
}
144+
#if TVM_MLIR_VERSION >= 160
145+
auto c0 = int64_t(disjunct.atEq(i, disjunct.getNumCols() - 1));
146+
#else
140147
auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1);
141-
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
148+
#endif
149+
linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
142150
union_entry = (union_entry && (linear_eq == 0));
143151
}
144152
for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) {
145-
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
153+
PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
146154
if (disjunct.getNumCols() > 1) {
147155
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
156+
#if TVM_MLIR_VERSION >= 160
157+
auto coeff = int64_t(disjunct.atIneq(i, j));
158+
#else
148159
auto coeff = disjunct.atIneq(i, j);
160+
#endif
149161
if (coeff >= 0 || is_zero(linear_eq)) {
150-
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
162+
linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
151163
} else {
152-
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
164+
linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j];
153165
}
154166
}
155167
}
168+
#if TVM_MLIR_VERSION >= 160
169+
auto c0 = int64_t(disjunct.atIneq(i, disjunct.getNumCols() - 1));
170+
#else
156171
auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1);
172+
#endif
157173
if (c0 >= 0) {
158-
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
174+
linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
159175
} else {
160-
linear_eq = linear_eq - IntImm(DataType::Int(32), -c0);
176+
linear_eq = linear_eq - IntImm(DataType::Int(64), -c0);
161177
}
162178
union_entry = (union_entry && (linear_eq >= 0));
163179
}
@@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {
199215

200216
IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
201217
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
218+
#if TVM_MLIR_VERSION >= 160
219+
SmallVector<mlir::presburger::MPInt> coeffs;
220+
#else
202221
SmallVector<int64_t> coeffs;
222+
#endif
223+
203224
coeffs.reserve(tvm_coeffs.size());
204225
for (const PrimExpr& it : tvm_coeffs) {
226+
#if TVM_MLIR_VERSION >= 160
227+
coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it)));
228+
#else
205229
coeffs.push_back(*as_const_int(it));
230+
#endif
206231
}
207232

208233
IntSet result = IntSet().Nothing();
@@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
211236
auto range = simplex.computeIntegerBounds(coeffs);
212237
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs));
213238
auto opt = range.first.getOptimumIfBounded();
239+
#if TVM_MLIR_VERSION >= 160
240+
auto min = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : neg_inf();
241+
#else
214242
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf();
243+
#endif
215244
opt = range.second.getOptimumIfBounded();
245+
#if TVM_MLIR_VERSION >= 160
246+
auto max = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : pos_inf();
247+
#else
216248
auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf();
249+
#endif
217250
auto interval = IntervalSet(min, max);
218251
result = Union({result, interval});
219252
}

src/support/libinfo.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
#define TVM_INFO_LLVM_VERSION "NOT-FOUND"
3232
#endif
3333

34+
#ifndef TVM_INFO_MLIR_VERSION
35+
#define TVM_INFO_MLIR_VERSION "NOT-FOUND"
36+
#endif
37+
3438
#ifndef TVM_INFO_USE_CUDA
3539
#define TVM_INFO_USE_CUDA "NOT-FOUND"
3640
#endif
@@ -271,6 +275,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
271275
{"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64},
272276
{"INSTALL_DEV", TVM_INFO_INSTALL_DEV},
273277
{"LLVM_VERSION", TVM_INFO_LLVM_VERSION},
278+
{"MLIR_VERSION", TVM_INFO_MLIR_VERSION},
274279
{"PICOJSON_PATH", TVM_INFO_PICOJSON_PATH},
275280
{"RANG_PATH", TVM_INFO_RANG_PATH},
276281
{"ROCM_PATH", TVM_INFO_ROCM_PATH},
@@ -311,6 +316,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
311316
{"USE_LIBBACKTRACE", TVM_INFO_USE_LIBBACKTRACE},
312317
{"USE_LIBTORCH", TVM_INFO_USE_LIBTORCH},
313318
{"USE_LLVM", TVM_INFO_USE_LLVM},
319+
{"USE_MLIR", TVM_INFO_USE_MLIR},
314320
{"USE_METAL", TVM_INFO_USE_METAL},
315321
{"USE_MICRO_STANDALONE_RUNTIME", TVM_INFO_USE_MICRO_STANDALONE_RUNTIME},
316322
{"USE_MICRO", TVM_INFO_USE_MICRO},

0 commit comments

Comments
 (0)