Skip to content

Commit f171197

Browse files
committed
[ARITH] Allow Analyzer to MarkGlobalPositiveValue
This PR introduces an utility function MarkGlobalPositiveValue. This function allows analyzer to mark buffer shapes in function arguments as positive globally and opens doors for more symbolic simplification.
1 parent e178375 commit f171197

File tree

7 files changed

+89
-12
lines changed

7 files changed

+89
-12
lines changed

include/tvm/arith/analyzer.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,22 @@ class TVM_DLL Analyzer {
618618
TransitiveComparisonAnalyzer transitive_comparisons;
619619
/*! \brief constructor */
620620
Analyzer();
621+
/*!
622+
* \brief Mark the value as positive value globally in analyzer.
623+
*
624+
* Only call this function if the positive condition is global and
625+
* not context-dependent.
626+
*
627+
* This function does best-effort propagations to the sub-analyzers
628+
*
629+
* \note We expose this function because positive global values,
630+
* such as symbolic buffer shapes in function arguments are really
631+
* important to ensure the best simplification, and usually they
632+
* can be handled in a simpler way than the generic constraints.
633+
*
634+
* This function may call into the Update function of the sub-analyzers.
635+
*/
636+
void MarkGlobalPositiveValue(const PrimExpr& value);
621637
/*!
622638
* \brief Notify all the sub-analyzers that var
623639
* is created and binded to expr.

src/arith/analyzer.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <tvm/tir/expr.h>
2626
#include <tvm/tir/op.h>
2727

28+
#include "const_fold.h"
2829
#include "product_normal_form.h"
2930

3031
namespace tvm {
@@ -63,6 +64,35 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
6364
// skip rewrite simplify
6465
}
6566

67+
void Analyzer::MarkGlobalPositiveValue(const PrimExpr& value) {
68+
// split out the symbolic and non-symbolic part
69+
int64_t cscale = 1;
70+
PrimExpr symbolic = tir::make_const(value.dtype(), 1);
71+
auto fcollect = [&](PrimExpr val) {
72+
if (const auto* intimm = val.as<IntImmNode>()) {
73+
cscale *= intimm->value;
74+
} else {
75+
symbolic = symbolic * val;
76+
}
77+
};
78+
UnpackReduction<tir::MulNode>(value, fcollect);
79+
if (cscale <= 0) return;
80+
// override the constant int bound by marking it as non-negative
81+
// NOTE: there might be future opportunities of more bound hint
82+
// this is a simple step and covers all the current needs
83+
//
84+
// We may consider enhance the sub analyzer to directly take
85+
// MarkPositiveVar so their bounds do not overlap
86+
if (const auto* var_ptr = symbolic.as<VarNode>()) {
87+
Var var = GetRef<Var>(var_ptr);
88+
bool allow_override = true;
89+
// mark the constant bound is sufficient
90+
this->const_int_bound.Update(var, ConstIntBound(1, ConstIntBound::kPosInf), allow_override);
91+
this->int_set.Update(var, IntSet::Interval(tir::make_const(var.dtype(), 1), pos_inf()),
92+
allow_override);
93+
}
94+
}
95+
6696
void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
6797
for (const auto& iter : variables) {
6898
this->Bind(iter.first, iter.second, allow_override);

src/arith/ir_mutator_with_analyzer.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ namespace arith {
3030

3131
using namespace tir;
3232

33+
void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
34+
// Mark the all the symbolic buffer shape values in the buffer map as positive value.
35+
for (auto kv : func->buffer_map) {
36+
for (PrimExpr shape : kv.second->shape) {
37+
analyzer_->MarkGlobalPositiveValue(shape);
38+
}
39+
}
40+
}
41+
3342
Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
3443
// record the loop variable as iterators
3544
Range dom = Range::FromMinExtent(op->min, op->extent);

src/arith/ir_mutator_with_analyzer.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
6262
PrimExpr VisitExpr_(const tir::ReduceNode* op) override;
6363

6464
protected:
65+
/*!
66+
* \brief Mark the all the buffer shape values in the buffer map as positive value.
67+
*
68+
* \note call this function before Visit function's body to maximize
69+
* simplification efficiency
70+
*/
71+
void MarkBufferMapShapes(const tir::PrimFunc& func);
72+
6573
/*! \brief internal analyzer field. */
6674
Analyzer* analyzer_;
6775
// the following two fields are useful in case we want

src/tir/transforms/flatten_buffer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
4242
arith::Analyzer ana;
4343
auto pass = BufferFlattener(&ana);
4444
auto writer = func.CopyOnWrite();
45+
pass.MarkBufferMapShapes(func);
4546
writer->body = pass.VisitStmt(func->body);
4647
// The buffers in func->buffer_map are deliberately left
4748
// unflattened, as they are used for validation of user-provided

src/tir/transforms/simplify.cc

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,24 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig);
142142

143143
class StmtSimplifier : public IRMutatorWithAnalyzer {
144144
public:
145-
static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
145+
static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
146+
Optional<SimplifyConfig> config_opt = NullOpt) {
146147
auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
147148
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
148149

149150
std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
150151
if (config->propagate_knowns_to_prove_conditional ||
151152
config->propagate_knowns_to_simplify_expressions) {
152-
touch_pattern = ControlFlowGraph(stmt);
153+
touch_pattern = ControlFlowGraph(func->body);
153154
}
154155

155-
std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
156+
std::unordered_set<const VarNode*> used_in_buffer_def =
157+
CollectVarsUsedInBufferDefinition(func->body);
156158
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
157159
std::move(used_in_buffer_def));
158-
return simplifier(std::move(stmt));
160+
simplifier.MarkBufferMapShapes(func);
161+
func.CopyOnWrite()->body = simplifier(func->body);
162+
return func;
159163
}
160164

161165
private:
@@ -335,21 +339,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
335339
} // namespace arith
336340

337341
namespace tir {
338-
339-
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
340-
return arith::StmtSimplifier::Apply(stmt, analyzer);
341-
}
342-
343342
namespace transform {
344343

345344
Pass Simplify() {
346345
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
347346
arith::Analyzer analyzer;
348347
auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify");
349348

350-
auto* n = f.CopyOnWrite();
351-
n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg);
352-
return f;
349+
return arith::StmtSimplifier::Apply(f, &analyzer, cfg);
353350
};
354351
return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
355352
}

tests/python/unittest/test_tir_transform_simplify.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,5 +1733,21 @@ def before(A_ptr: T.handle("float32"), A_stride: T.int32):
17331733
expected = before
17341734

17351735

1736+
class TestBufferShapeConstraint(BaseBeforeAfter):
1737+
"""If enabled, rewrite boolean expressions into AND of OR"""
1738+
1739+
convert_boolean_to_and_of_ors = True
1740+
1741+
def before(a: T.handle):
1742+
n = T.int64()
1743+
A = T.match_buffer(a, (n * 32,), "float32")
1744+
A[T.min(T.int64(0), n)] = T.float32(0)
1745+
1746+
def expected(a: T.handle):
1747+
n = T.int64()
1748+
A = T.match_buffer(a, (n * 32,), "float32")
1749+
A[T.int64(0)] = T.float32(0)
1750+
1751+
17361752
if __name__ == "__main__":
17371753
tvm.testing.main()

0 commit comments

Comments
 (0)