@@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const {
126126 for (const IntegerRelation& disjunct : disjuncts) {
127127 PrimExpr union_entry = Bool (1 );
128128 for (unsigned i = 0 , e = disjunct.getNumEqualities (); i < e; ++i) {
129- PrimExpr linear_eq = IntImm (DataType::Int (32 ), 0 );
129+ PrimExpr linear_eq = IntImm (DataType::Int (64 ), 0 );
130130 if (disjunct.getNumCols () > 1 ) {
131131 for (unsigned j = 0 , f = disjunct.getNumCols () - 1 ; j < f; ++j) {
132+ #if TVM_MLIR_VERSION >= 160
133+ auto coeff = int64_t (disjunct.atEq (i, j));
134+ #else
132135 auto coeff = disjunct.atEq (i, j);
136+ #endif
133137 if (coeff >= 0 || is_zero (linear_eq)) {
134- linear_eq = linear_eq + IntImm (DataType::Int (32 ), coeff) * vars[j];
138+ linear_eq = linear_eq + IntImm (DataType::Int (64 ), coeff) * vars[j];
135139 } else {
136- linear_eq = linear_eq - IntImm (DataType::Int (32 ), -coeff) * vars[j];
140+ linear_eq = linear_eq - IntImm (DataType::Int (64 ), -coeff) * vars[j];
137141 }
138142 }
139143 }
144+ #if TVM_MLIR_VERSION >= 160
145+ auto c0 = int64_t (disjunct.atEq (i, disjunct.getNumCols () - 1 ));
146+ #else
140147 auto c0 = disjunct.atEq (i, disjunct.getNumCols () - 1 );
141- linear_eq = linear_eq + IntImm (DataType::Int (32 ), c0);
148+ #endif
149+ linear_eq = linear_eq + IntImm (DataType::Int (64 ), c0);
142150 union_entry = (union_entry && (linear_eq == 0 ));
143151 }
144152 for (unsigned i = 0 , e = disjunct.getNumInequalities (); i < e; ++i) {
145- PrimExpr linear_eq = IntImm (DataType::Int (32 ), 0 );
153+ PrimExpr linear_eq = IntImm (DataType::Int (64 ), 0 );
146154 if (disjunct.getNumCols () > 1 ) {
147155 for (unsigned j = 0 , f = disjunct.getNumCols () - 1 ; j < f; ++j) {
156+ #if TVM_MLIR_VERSION >= 160
157+ auto coeff = int64_t (disjunct.atIneq (i, j));
158+ #else
148159 auto coeff = disjunct.atIneq (i, j);
160+ #endif
149161 if (coeff >= 0 || is_zero (linear_eq)) {
150- linear_eq = linear_eq + IntImm (DataType::Int (32 ), coeff) * vars[j];
162+ linear_eq = linear_eq + IntImm (DataType::Int (64 ), coeff) * vars[j];
151163 } else {
152- linear_eq = linear_eq - IntImm (DataType::Int (32 ), -coeff) * vars[j];
164+ linear_eq = linear_eq - IntImm (DataType::Int (64 ), -coeff) * vars[j];
153165 }
154166 }
155167 }
168+ #if TVM_MLIR_VERSION >= 160
169+ auto c0 = int64_t (disjunct.atIneq (i, disjunct.getNumCols () - 1 ));
170+ #else
156171 auto c0 = disjunct.atIneq (i, disjunct.getNumCols () - 1 );
172+ #endif
157173 if (c0 >= 0 ) {
158- linear_eq = linear_eq + IntImm (DataType::Int (32 ), c0);
174+ linear_eq = linear_eq + IntImm (DataType::Int (64 ), c0);
159175 } else {
160- linear_eq = linear_eq - IntImm (DataType::Int (32 ), -c0);
176+ linear_eq = linear_eq - IntImm (DataType::Int (64 ), -c0);
161177 }
162178 union_entry = (union_entry && (linear_eq >= 0 ));
163179 }
@@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {
199215
200216IntSet EvalSet (const PrimExpr& e, const PresburgerSet& set) {
201217 Array<PrimExpr> tvm_coeffs = DetectLinearEquation (e, set->GetVars ());
218+ #if TVM_MLIR_VERSION >= 160
219+ SmallVector<mlir::presburger::MPInt> coeffs;
220+ #else
202221 SmallVector<int64_t > coeffs;
222+ #endif
223+
203224 coeffs.reserve (tvm_coeffs.size ());
204225 for (const PrimExpr& it : tvm_coeffs) {
226+ #if TVM_MLIR_VERSION >= 160
227+ coeffs.push_back (mlir::presburger::MPInt (*as_const_int (it)));
228+ #else
205229 coeffs.push_back (*as_const_int (it));
230+ #endif
206231 }
207232
208233 IntSet result = IntSet ().Nothing ();
@@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
211236 auto range = simplex.computeIntegerBounds (coeffs);
212237 auto maxRoundedDown (simplex.computeOptimum (Simplex::Direction::Up, coeffs));
213238 auto opt = range.first .getOptimumIfBounded ();
239+ #if TVM_MLIR_VERSION >= 160
240+ auto min = opt.has_value () ? IntImm (DataType::Int (64 ), int64_t (opt.value ())) : neg_inf ();
241+ #else
214242 auto min = opt.hasValue () ? IntImm (DataType::Int (64 ), opt.getValue ()) : neg_inf ();
243+ #endif
215244 opt = range.second .getOptimumIfBounded ();
245+ #if TVM_MLIR_VERSION >= 160
246+ auto max = opt.has_value () ? IntImm (DataType::Int (64 ), int64_t (opt.value ())) : pos_inf ();
247+ #else
216248 auto max = opt.hasValue () ? IntImm (DataType::Int (64 ), opt.getValue ()) : pos_inf ();
249+ #endif
217250 auto interval = IntervalSet (min, max);
218251 result = Union ({result, interval});
219252 }
0 commit comments