Skip to content

Commit 89f9573

Browse files
authored
[TIR] Decouple DeepEqual from StructuralEqual (#18151)
This PR decouples deep equal from structural equal implementation by providing a more direct implementatio through functor. DeepEqual is being used at heart of arith simplification as subroutine and it performs more direct nested checking without doing var remapping as structural equal for efficiency reasons. It also do not need to trace the wrong comparison since the failed path is also expected to happen often. This step likely will improve the deep equal efficiency because of the more direct approach and gives us opportunity to run simplify future refactor of structural equal to focus on struct path tracing.
1 parent 1a1d27c commit 89f9573

File tree

1 file changed

+154
-25
lines changed

1 file changed

+154
-25
lines changed

src/tir/analysis/deep_equal.cc

Lines changed: 154 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,177 @@
2525
#include <tvm/ffi/reflection/registry.h>
2626
#include <tvm/node/object_path.h>
2727
#include <tvm/node/reflection.h>
28-
#include <tvm/node/structural_equal.h>
2928
#include <tvm/tir/analysis.h>
29+
#include <tvm/tir/expr_functor.h>
3030

3131
namespace tvm {
3232
namespace tir {
3333

34-
class DeepCmpSEqualHandler : public SEqualReducer::Handler {
34+
#define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode) \
35+
bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \
36+
const auto* prhs = rhs.as<OpNode>(); \
37+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a) && \
38+
VisitExpr(plhs->b, prhs->b); \
39+
}
40+
41+
#define DEFINE_DEEP_EQUAL_IMM_EXPR(OpNode) \
42+
bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \
43+
const auto* prhs = rhs.as<OpNode>(); \
44+
return plhs->dtype == prhs->dtype && plhs->value == prhs->value; \
45+
}
46+
47+
class ExprDeepEqualChecker : private ExprFunctor<bool(const PrimExpr&, const PrimExpr&)> {
3548
public:
36-
// use direct recursion.
37-
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
38-
const Optional<ObjectPathPair>&) final {
49+
static bool Check(const PrimExpr& lhs, const PrimExpr& rhs) {
50+
// quick path without constructing the object
3951
if (lhs.same_as(rhs)) return true;
4052
if (!lhs.defined() && rhs.defined()) return false;
4153
if (!rhs.defined() && lhs.defined()) return false;
4254
if (lhs->type_index() != rhs->type_index()) return false;
43-
return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) &&
44-
!fail_;
55+
if (auto* plhs = lhs.as<IntImmNode>()) {
56+
auto* prhs = rhs.as<IntImmNode>();
57+
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
58+
}
59+
return ExprDeepEqualChecker().VisitExpr(lhs, rhs);
4560
}
4661

47-
void DeferFail(const ObjectPathPair&) final { fail_ = true; }
48-
bool IsFailDeferralEnabled() final { return false; }
49-
50-
ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return lhs; }
51-
void MarkGraphNode() final {}
62+
bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final {
63+
if (lhs.same_as(rhs)) return true;
64+
if (!lhs.defined() && rhs.defined()) return false;
65+
if (!rhs.defined() && lhs.defined()) return false;
66+
if (lhs->type_index() != rhs->type_index()) return false;
67+
return ExprFunctor::VisitExpr(lhs, rhs);
68+
}
5269

5370
private:
54-
// reflection vtable
55-
ReflectionVTable* vtable_ = ReflectionVTable::Global();
56-
bool fail_ = false;
71+
bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
72+
if (lhs.size() != rhs.size()) return false;
73+
for (size_t i = 0; i < lhs.size(); i++) {
74+
if (!VisitExpr(lhs[i], rhs[i])) return false;
75+
}
76+
return true;
77+
}
78+
79+
bool ArrayDeepEqual(const Array<IterVar>& lhs, const Array<IterVar>& rhs) {
80+
// for iter var, we require pointer equality
81+
if (lhs.size() != rhs.size()) return false;
82+
for (size_t i = 0; i < lhs.size(); i++) {
83+
if (!lhs[i].same_as(rhs[i])) return true;
84+
}
85+
return true;
86+
}
87+
88+
bool OptionalDeepEqual(const Optional<PrimExpr>& lhs, const Optional<PrimExpr>& rhs) {
89+
if (lhs.same_as(rhs)) return true;
90+
if (!lhs.defined() && rhs.defined()) return false;
91+
if (lhs.defined() && !rhs.defined()) return false;
92+
return VisitExpr(*lhs, *rhs);
93+
}
94+
95+
bool VisitExpr_(const VarNode* plhs, const PrimExpr& rhs) final {
96+
// for var, we require pointer equality
97+
return plhs == rhs.get();
98+
}
99+
100+
bool VisitExpr_(const SizeVarNode* plhs, const PrimExpr& rhs) final {
101+
// for var, we require pointer equality
102+
return plhs == rhs.get();
103+
}
104+
105+
bool VisitExpr_(const BufferLoadNode* plhs, const PrimExpr& rhs) final {
106+
const auto* prhs = rhs.as<BufferLoadNode>();
107+
// we run pointer comparison of the buffer
108+
return plhs->dtype == prhs->dtype && plhs->buffer.same_as(prhs->buffer) &&
109+
ArrayDeepEqual(plhs->indices, prhs->indices) &&
110+
OptionalDeepEqual(plhs->predicate, prhs->predicate);
111+
}
112+
113+
bool VisitExpr_(const ProducerLoadNode* plhs, const PrimExpr& rhs) final {
114+
const auto* prhs = rhs.as<ProducerLoadNode>();
115+
// run shallow pointer comparison of the producer
116+
return plhs->dtype == prhs->dtype && plhs->producer.same_as(prhs->producer) &&
117+
ArrayDeepEqual(plhs->indices, prhs->indices);
118+
}
119+
120+
bool VisitExpr_(const LetNode* plhs, const PrimExpr& rhs) final {
121+
const auto* prhs = rhs.as<LetNode>();
122+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->var, prhs->var) &&
123+
VisitExpr(plhs->value, prhs->value) && VisitExpr(plhs->body, prhs->body);
124+
}
125+
126+
bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final {
127+
const auto* prhs = rhs.as<CallNode>();
128+
return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) &&
129+
ArrayDeepEqual(plhs->args, prhs->args);
130+
}
131+
132+
bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final {
133+
const auto* prhs = rhs.as<ReduceNode>();
134+
return plhs->dtype == prhs->dtype && plhs->combiner.same_as(prhs->combiner) &&
135+
ArrayDeepEqual(plhs->source, prhs->source) && ArrayDeepEqual(plhs->init, prhs->init) &&
136+
ArrayDeepEqual(plhs->axis, prhs->axis) && VisitExpr(plhs->condition, prhs->condition) &&
137+
plhs->value_index == prhs->value_index;
138+
}
139+
140+
bool VisitExpr_(const CastNode* plhs, const PrimExpr& rhs) final {
141+
const auto* prhs = rhs.as<CastNode>();
142+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value);
143+
}
144+
145+
bool VisitExpr_(const NotNode* plhs, const PrimExpr& rhs) final {
146+
const auto* prhs = rhs.as<NotNode>();
147+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a);
148+
}
149+
150+
bool VisitExpr_(const SelectNode* plhs, const PrimExpr& rhs) final {
151+
const auto* prhs = rhs.as<SelectNode>();
152+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->condition, prhs->condition) &&
153+
VisitExpr(plhs->true_value, prhs->true_value) &&
154+
VisitExpr(plhs->false_value, prhs->false_value);
155+
}
156+
157+
bool VisitExpr_(const RampNode* plhs, const PrimExpr& rhs) final {
158+
const auto* prhs = rhs.as<RampNode>();
159+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->base, prhs->base) &&
160+
VisitExpr(plhs->stride, prhs->stride) && VisitExpr(plhs->lanes, prhs->lanes);
161+
}
162+
163+
bool VisitExpr_(const ShuffleNode* plhs, const PrimExpr& rhs) final {
164+
const auto* prhs = rhs.as<ShuffleNode>();
165+
return plhs->dtype == prhs->dtype && ArrayDeepEqual(plhs->vectors, prhs->vectors) &&
166+
ArrayDeepEqual(plhs->indices, prhs->indices);
167+
}
168+
169+
bool VisitExpr_(const BroadcastNode* plhs, const PrimExpr& rhs) final {
170+
const auto* prhs = rhs.as<BroadcastNode>();
171+
return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value) &&
172+
VisitExpr(plhs->lanes, prhs->lanes);
173+
}
174+
175+
DEFINE_DEEP_EQUAL_BIN_EXPR(AddNode)
176+
DEFINE_DEEP_EQUAL_BIN_EXPR(SubNode)
177+
DEFINE_DEEP_EQUAL_BIN_EXPR(MulNode)
178+
DEFINE_DEEP_EQUAL_BIN_EXPR(DivNode)
179+
DEFINE_DEEP_EQUAL_BIN_EXPR(ModNode)
180+
DEFINE_DEEP_EQUAL_BIN_EXPR(FloorDivNode)
181+
DEFINE_DEEP_EQUAL_BIN_EXPR(FloorModNode)
182+
DEFINE_DEEP_EQUAL_BIN_EXPR(MinNode)
183+
DEFINE_DEEP_EQUAL_BIN_EXPR(MaxNode)
184+
DEFINE_DEEP_EQUAL_BIN_EXPR(EQNode)
185+
DEFINE_DEEP_EQUAL_BIN_EXPR(NENode)
186+
DEFINE_DEEP_EQUAL_BIN_EXPR(LTNode)
187+
DEFINE_DEEP_EQUAL_BIN_EXPR(LENode)
188+
DEFINE_DEEP_EQUAL_BIN_EXPR(GTNode)
189+
DEFINE_DEEP_EQUAL_BIN_EXPR(GENode)
190+
DEFINE_DEEP_EQUAL_BIN_EXPR(AndNode)
191+
DEFINE_DEEP_EQUAL_BIN_EXPR(OrNode)
192+
DEFINE_DEEP_EQUAL_IMM_EXPR(IntImmNode)
193+
DEFINE_DEEP_EQUAL_IMM_EXPR(FloatImmNode)
194+
DEFINE_DEEP_EQUAL_IMM_EXPR(StringImmNode)
57195
};
58196

59197
bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
60-
// quick path
61-
if (lhs.same_as(rhs)) return true;
62-
if (!lhs.defined() && rhs.defined()) return false;
63-
if (!rhs.defined() && lhs.defined()) return false;
64-
if (lhs->type_index() != rhs->type_index()) return false;
65-
if (auto* plhs = lhs.as<IntImmNode>()) {
66-
auto* prhs = rhs.as<IntImmNode>();
67-
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
68-
}
69-
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt);
198+
return ExprDeepEqualChecker::Check(lhs, rhs);
70199
}
71200

72201
TVM_FFI_STATIC_INIT_BLOCK({

0 commit comments

Comments
 (0)