Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct ConstIntBoundAnalyzer::Entry {
class ConstIntBoundAnalyzer::Impl
: public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
public:
explicit Impl(Analyzer* parent) : parent_(parent) {}
/*! \brief additional bound info about expr in bound */
struct BoundInfo {
/*! \brief The expr */
Expand Down Expand Up @@ -129,8 +130,7 @@ class ConstIntBoundAnalyzer::Impl
auto it = var_map_.find(var);
if (it != var_map_.end()) {
ICHECK(it->second == info)
<< "Trying to update var \'" << var << "\'"
<< " with a different const bound: "
<< "Trying to update var \'" << var << "\'" << " with a different const bound: "
<< "original=" << ConstIntBound(it->second.min_value, it->second.max_value)
<< ", new=" << ConstIntBound(info.min_value, info.max_value);
}
Expand Down Expand Up @@ -278,6 +278,25 @@ class ConstIntBoundAnalyzer::Impl

if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);

// Try to get tighter bounds using modular set information
if (parent_ && b.min_value == b.max_value) {
ModularSet mod_a = parent_->modular_set(op->a);
int64_t modulus = b.min_value;
int64_t gcd_coeff_mod = ComputeGCD(mod_a->coeff, modulus);

// If gcd_coeff_mod > 1, we can get tighter bounds
// The result will be of the form gcd_coeff_mod * k + (base % modulus)
// where k ranges to cover [0, modulus - gcd_coeff_mod]
if (gcd_coeff_mod > 1) {
int64_t base_mod = mod_a->base % modulus;
if (base_mod < 0) base_mod += modulus;
int64_t tight_max = modulus - gcd_coeff_mod + base_mod;
if (tight_max >= modulus) tight_max -= modulus;
return MakeBound(base_mod, tight_max);
}
}

if (a.min_value >= 0) {
// 0 <= [a_min, a_max] < b_min
if (a.max_value < b.min_value) return a;
Expand Down Expand Up @@ -324,6 +343,24 @@ class ConstIntBoundAnalyzer::Impl

if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
// Try to get tighter bounds using modular set information
if (parent_ && b.min_value == b.max_value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we already have the bound analysis in IntervalSet, is the const int bound still necessary? I just want to get a sense of whether we need to introduce this bound.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding, const int bound is faster to analyze, and IntervalSet can build on those constant bounds for further analysis. Keeping them separate makes the design clearer in my view.

ModularSet mod_a = parent_->modular_set(op->a);
int64_t modulus = b.min_value;
int64_t gcd_coeff_mod = ComputeGCD(mod_a->coeff, modulus);

// If gcd_coeff_mod > 1, we can get tighter bounds
// The result will be of the form gcd_coeff_mod * k + (base % modulus)
// where k ranges to cover [0, modulus - gcd_coeff_mod]
if (gcd_coeff_mod > 1) {
int64_t base_mod = mod_a->base % modulus;
if (base_mod < 0) base_mod += modulus;
int64_t tight_max = modulus - gcd_coeff_mod + base_mod;
if (tight_max >= modulus) tight_max -= modulus;
return MakeBound(base_mod, tight_max);
}
}

if (a.min_value >= 0) {
// 0 <= [a_min, a_max] < b_min
if (a.max_value < b.min_value) return a;
Expand Down Expand Up @@ -458,6 +495,8 @@ class ConstIntBoundAnalyzer::Impl

private:
friend class ConstIntBoundAnalyzer;
// parent analyzer
Analyzer* parent_;
// internal variable map
std::unordered_map<Var, Entry> var_map_;
// additional bound info
Expand Down Expand Up @@ -525,6 +564,22 @@ class ConstIntBoundAnalyzer::Impl
// If the range of b does not have 0, use BinaryOpBoundary.
return BinaryOpBoundary(a, b, op);
}
/*!
 * \brief Compute the greatest common divisor of two integers.
 *
 * Signs are ignored: the result is the GCD of |a| and |b|, and
 * ComputeGCD(x, 0) == |x| (so ComputeGCD(0, 0) == 0).
 *
 * \note Magnitudes are formed in unsigned arithmetic, so INT64_MIN is a
 *       valid argument (negating it as a signed value would overflow).
 *       The only result that does not fit in int64_t is 2^63, which can
 *       only occur when neither argument contributes a smaller divisor,
 *       e.g. ComputeGCD(INT64_MIN, 0).
 *
 * \param a The first integer.
 * \param b The second integer.
 * \return The GCD of |a| and |b|; non-negative whenever it is representable.
 */
static int64_t ComputeGCD(int64_t a, int64_t b) {
  // |v| as an unsigned value; well defined for every int64_t, including INT64_MIN.
  auto magnitude = [](int64_t v) -> uint64_t {
    return v < 0 ? uint64_t{0} - static_cast<uint64_t>(v) : static_cast<uint64_t>(v);
  };
  uint64_t x = magnitude(a);
  uint64_t y = magnitude(b);
  // Euclid's algorithm: gcd(x, y) == gcd(y, x mod y) until the remainder is zero.
  while (y != 0) {
    const uint64_t remainder = x % y;
    x = y;
    y = remainder;
  }
  return static_cast<int64_t>(x);
}
/*!
* \brief Compute x + y, aware of inf.
* \param x The left operand.
Expand Down Expand Up @@ -805,7 +860,7 @@ std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con
return impl_->EnterConstraint(constraint);
}

ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {}
ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}

ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }

Expand Down
92 changes: 66 additions & 26 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <unordered_map>
Expand Down Expand Up @@ -111,8 +112,9 @@ TVM_DECLARE_LOGICAL_OP(Not);
* \brief Combine two interval set under arithmetic operations.
* \note this can possibly relax the set.
*/
template <typename Op>
inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) {
template <typename Op, typename OpNode>
inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) {
DataType dtype = op->dtype;
if (a->IsSinglePoint() && b->IsSinglePoint()) {
PrimExpr expr;
if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
Expand All @@ -134,7 +136,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, Dat

template <>
inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::AddNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value + b->min_value);
}
Expand All @@ -149,7 +151,7 @@ inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalS

template <>
inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::SubNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value - b->min_value);
}
Expand All @@ -164,7 +166,7 @@ inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalS

template <>
inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::MulNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value * b->min_value);
}
Expand Down Expand Up @@ -198,7 +200,7 @@ inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, Interval

template <>
inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::DivNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value / b->min_value);
}
Expand Down Expand Up @@ -232,7 +234,7 @@ inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, Interval

template <>
inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::ModNode* op) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
}
Expand Down Expand Up @@ -261,7 +263,7 @@ inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, Interval

template <>
inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::FloorDivNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
Expand Down Expand Up @@ -295,7 +297,7 @@ inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, Int

template <>
inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::FloorModNode* op) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
Expand All @@ -321,6 +323,39 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int
return IntervalSet(tmin, tmax);
}
}
// Enhanced: Use ModularSet analysis for better bounds
if (auto* div_imm = divisor.as<tir::IntImmNode>()) {
int64_t div_val = div_imm->value;

// Analyze the modular properties of the dividend
ModularSet dividend_mod = analyzer->modular_set(op->a);

if (dividend_mod.defined() && dividend_mod->coeff > 0) {
// Calculate GCD of dividend coefficient and divisor
int64_t gcd = 1;
if (dividend_mod->coeff != 0 && div_val != 0) {
int64_t a_coeff = std::abs(dividend_mod->coeff);
int64_t b_val = std::abs(div_val);
while (b_val != 0) {
int64_t temp = b_val;
b_val = a_coeff % b_val;
a_coeff = temp;
}
gcd = a_coeff;
}

if (gcd > 1 && div_val % gcd == 0) {
// The dividend is a multiple of gcd, and divisor is also a multiple of gcd
// So the result is also a multiple of gcd, with max value = (div_val/gcd - 1) * gcd
int64_t max_quotient = (div_val / gcd) - 1;
int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd);

if (max_mod_result >= 0 && max_mod_result < div_val) {
return IntervalSet(make_zero(op->dtype), make_const(op->dtype, max_mod_result));
}
}
}
}
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
PrimExpr bound = abs(divisor) - 1;
Expand All @@ -333,7 +368,7 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int

template <>
inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::MaxNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
}
Expand All @@ -344,7 +379,7 @@ inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, Interval

template <>
inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
DataType /* dtype */) {
const tir::MinNode* /* op */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
}
Expand Down Expand Up @@ -475,19 +510,25 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
if (op->lanes->IsInstance<IntImmNode>()) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if (vstride > 0) {
return Combine<Add>(analyzer_, base,
IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))),
op->dtype);
PrimExpr stride_expr = make_const(t, vstride * (lanes - 1));
auto add_op = tir::Add(op->base, stride_expr);
auto add_node = add_op.as<tir::AddNode>();
return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node);
} else {
return Combine<Add>(analyzer_, base,
IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)),
op->dtype);
PrimExpr stride_expr = make_const(t, vstride * (lanes - 1));
auto add_op = tir::Add(op->base, stride_expr);
auto add_node = add_op.as<tir::AddNode>();
return Combine<Add>(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node);
}
} else { /* Scalable vector */
if (vstride > 0) {
return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype);
auto add_op = tir::Add(op->base, make_zero(t));
auto add_node = add_op.as<tir::AddNode>();
return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node);
} else {
return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), op->dtype);
auto add_op = tir::Add(op->base, make_zero(t));
auto add_node = add_op.as<tir::AddNode>();
return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node);
}
}
}
Expand Down Expand Up @@ -563,7 +604,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntervalSet::SinglePoint(ffi::GetRef<PrimExpr>(op));
}
return Combine<TOp>(analyzer_, a, b, op->dtype);
return Combine<TOp>(analyzer_, a, b, op);
}

// recursive depth
Expand Down Expand Up @@ -640,13 +681,13 @@ void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_o

ICHECK(ExprDeepEqual()(old_info.min(), info.min()))
<< "Trying to update var \'" << var << "\'"
<< " with a different minimum value: "
<< "original=" << old_info.min() << ", new=" << info.min();
<< " with a different minimum value: " << "original=" << old_info.min()
<< ", new=" << info.min();

ICHECK(ExprDeepEqual()(old_info.max(), info.max()))
<< "Trying to update var \'" << var << "\'"
<< " with a different maximum value: "
<< "original=" << old_info.max() << ", new=" << info.max();
<< " with a different maximum value: " << "original=" << old_info.max()
<< ", new=" << info.max();
}
}
dom_map_.Set(var, info);
Expand Down Expand Up @@ -1194,8 +1235,7 @@ ffi::Array<IntSet> EstimateRegionUpperBound(const ffi::Array<Range>& region,
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", " << op->max_value << ']';
p->stream << "IntervalSet" << "[" << op->min_value << ", " << op->max_value << ']';
});

TVM_FFI_STATIC_INIT_BLOCK() {
Expand Down
12 changes: 12 additions & 0 deletions tests/python/arith/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,5 +298,17 @@ class TestRampBound(BaseCompare):
)


class TestModularSetBound(BaseCompare):
    """Const-int-bound case for a modulo whose operand terms share a factor with the modulus.

    Both coefficients (2048 and 16) are multiples of 16 and 7168 == 448 * 16,
    so the expected upper bound below is 7168 - 16 == 7152.
    """

    # NOTE(review): not referenced anywhere else in this class body; confirm
    # whether BaseCompare reads it or whether it can be dropped.
    analyzer = tvm.arith.Analyzer()
    tx = tvm.te.var("tx", dtype="int32")
    bx = tvm.te.var("bx", dtype="int32")

    # Every term of the dividend is a multiple of 16.
    expr = (bx * 2048 + tx * 16) % 7168

    # Presumably TestCase(expr, expected (min, max), known variable bounds);
    # confirm against TestCase's definition.
    test_case = tvm.testing.parameter(
        TestCase(expr, (0, 7152), {bx: (0, 3584), tx: (0, 128)}),
    )


if __name__ == "__main__":
tvm.testing.main()
10 changes: 10 additions & 0 deletions tests/python/arith/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,5 +387,15 @@ def test_union_lower_bound():
assert result.max_value.same_as(pos_inf)


def test_modular_set():
    """Check the interval set of a modulo whose dividend terms are all multiples of 16.

    With 7168 == 448 * 16, the expected interval is (0, 7168 - 16) == (0, 7152).
    """
    checker = IntSetChecker()
    x = tvm.te.var("x", dtype="int32")
    y = tvm.te.var("y", dtype="int32")
    dividend = x * 2048 + y * 16
    modulo_expr = dividend % 7168
    # Domains of the free variables, passed to the checker alongside the expression.
    domains = {
        x: tvm.arith.IntervalSet(0, 128),
        y: tvm.arith.IntervalSet(0, 3584),
    }
    expected_bounds = (0, 7152)
    checker.verify(modulo_expr, domains, expected_bounds)


if __name__ == "__main__":
tvm.testing.main()
Loading