|
25 | 25 | #include <tvm/ffi/reflection/registry.h> |
26 | 26 | #include <tvm/node/object_path.h> |
27 | 27 | #include <tvm/node/reflection.h> |
28 | | -#include <tvm/node/structural_equal.h> |
29 | 28 | #include <tvm/tir/analysis.h> |
| 29 | +#include <tvm/tir/expr_functor.h> |
30 | 30 |
|
31 | 31 | namespace tvm { |
32 | 32 | namespace tir { |
33 | 33 |
|
34 | | -class DeepCmpSEqualHandler : public SEqualReducer::Handler { |
| 34 | +#define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode) \ |
| 35 | + bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \ |
| 36 | + const auto* prhs = rhs.as<OpNode>(); \ |
| 37 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a) && \ |
| 38 | + VisitExpr(plhs->b, prhs->b); \ |
| 39 | + } |
| 40 | + |
| 41 | +#define DEFINE_DEEP_EQUAL_IMM_EXPR(OpNode) \ |
| 42 | + bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \ |
| 43 | + const auto* prhs = rhs.as<OpNode>(); \ |
| 44 | + return plhs->dtype == prhs->dtype && plhs->value == prhs->value; \ |
| 45 | + } |
| 46 | + |
| 47 | +class ExprDeepEqualChecker : private ExprFunctor<bool(const PrimExpr&, const PrimExpr&)> { |
35 | 48 | public: |
36 | | - // use direct recursion. |
37 | | - bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
38 | | - const Optional<ObjectPathPair>&) final { |
| 49 | + static bool Check(const PrimExpr& lhs, const PrimExpr& rhs) { |
| 50 | + // quick path without constructing the object |
39 | 51 | if (lhs.same_as(rhs)) return true; |
40 | 52 | if (!lhs.defined() && rhs.defined()) return false; |
41 | 53 | if (!rhs.defined() && lhs.defined()) return false; |
42 | 54 | if (lhs->type_index() != rhs->type_index()) return false; |
43 | | - return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) && |
44 | | - !fail_; |
| 55 | + if (auto* plhs = lhs.as<IntImmNode>()) { |
| 56 | + auto* prhs = rhs.as<IntImmNode>(); |
| 57 | + return plhs->dtype == prhs->dtype && plhs->value == prhs->value; |
| 58 | + } |
| 59 | + return ExprDeepEqualChecker().VisitExpr(lhs, rhs); |
45 | 60 | } |
46 | 61 |
|
47 | | - void DeferFail(const ObjectPathPair&) final { fail_ = true; } |
48 | | - bool IsFailDeferralEnabled() final { return false; } |
49 | | - |
50 | | - ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return lhs; } |
51 | | - void MarkGraphNode() final {} |
| 62 | + bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { |
| 63 | + if (lhs.same_as(rhs)) return true; |
| 64 | + if (!lhs.defined() && rhs.defined()) return false; |
| 65 | + if (!rhs.defined() && lhs.defined()) return false; |
| 66 | + if (lhs->type_index() != rhs->type_index()) return false; |
| 67 | + return ExprFunctor::VisitExpr(lhs, rhs); |
| 68 | + } |
52 | 69 |
|
53 | 70 | private: |
54 | | - // reflection vtable |
55 | | - ReflectionVTable* vtable_ = ReflectionVTable::Global(); |
56 | | - bool fail_ = false; |
| 71 | + bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) { |
| 72 | + if (lhs.size() != rhs.size()) return false; |
| 73 | + for (size_t i = 0; i < lhs.size(); i++) { |
| 74 | + if (!VisitExpr(lhs[i], rhs[i])) return false; |
| 75 | + } |
| 76 | + return true; |
| 77 | + } |
| 78 | + |
| 79 | + bool ArrayDeepEqual(const Array<IterVar>& lhs, const Array<IterVar>& rhs) { |
| 80 | + // for iter var, we require pointer equality |
| 81 | + if (lhs.size() != rhs.size()) return false; |
| 82 | + for (size_t i = 0; i < lhs.size(); i++) { |
| 83 | + if (!lhs[i].same_as(rhs[i])) return true; |
| 84 | + } |
| 85 | + return true; |
| 86 | + } |
| 87 | + |
| 88 | + bool OptionalDeepEqual(const Optional<PrimExpr>& lhs, const Optional<PrimExpr>& rhs) { |
| 89 | + if (lhs.same_as(rhs)) return true; |
| 90 | + if (!lhs.defined() && rhs.defined()) return false; |
| 91 | + if (lhs.defined() && !rhs.defined()) return false; |
| 92 | + return VisitExpr(*lhs, *rhs); |
| 93 | + } |
| 94 | + |
| 95 | + bool VisitExpr_(const VarNode* plhs, const PrimExpr& rhs) final { |
| 96 | + // for var, we require pointer equality |
| 97 | + return plhs == rhs.get(); |
| 98 | + } |
| 99 | + |
| 100 | + bool VisitExpr_(const SizeVarNode* plhs, const PrimExpr& rhs) final { |
| 101 | + // for var, we require pointer equality |
| 102 | + return plhs == rhs.get(); |
| 103 | + } |
| 104 | + |
| 105 | + bool VisitExpr_(const BufferLoadNode* plhs, const PrimExpr& rhs) final { |
| 106 | + const auto* prhs = rhs.as<BufferLoadNode>(); |
| 107 | + // we run pointer comparison of the buffer |
| 108 | + return plhs->dtype == prhs->dtype && plhs->buffer.same_as(prhs->buffer) && |
| 109 | + ArrayDeepEqual(plhs->indices, prhs->indices) && |
| 110 | + OptionalDeepEqual(plhs->predicate, prhs->predicate); |
| 111 | + } |
| 112 | + |
| 113 | + bool VisitExpr_(const ProducerLoadNode* plhs, const PrimExpr& rhs) final { |
| 114 | + const auto* prhs = rhs.as<ProducerLoadNode>(); |
| 115 | + // run shallow pointer comparison of the producer |
| 116 | + return plhs->dtype == prhs->dtype && plhs->producer.same_as(prhs->producer) && |
| 117 | + ArrayDeepEqual(plhs->indices, prhs->indices); |
| 118 | + } |
| 119 | + |
| 120 | + bool VisitExpr_(const LetNode* plhs, const PrimExpr& rhs) final { |
| 121 | + const auto* prhs = rhs.as<LetNode>(); |
| 122 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->var, prhs->var) && |
| 123 | + VisitExpr(plhs->value, prhs->value) && VisitExpr(plhs->body, prhs->body); |
| 124 | + } |
| 125 | + |
| 126 | + bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final { |
| 127 | + const auto* prhs = rhs.as<CallNode>(); |
| 128 | + return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) && |
| 129 | + ArrayDeepEqual(plhs->args, prhs->args); |
| 130 | + } |
| 131 | + |
| 132 | + bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final { |
| 133 | + const auto* prhs = rhs.as<ReduceNode>(); |
| 134 | + return plhs->dtype == prhs->dtype && plhs->combiner.same_as(prhs->combiner) && |
| 135 | + ArrayDeepEqual(plhs->source, prhs->source) && ArrayDeepEqual(plhs->init, prhs->init) && |
| 136 | + ArrayDeepEqual(plhs->axis, prhs->axis) && VisitExpr(plhs->condition, prhs->condition) && |
| 137 | + plhs->value_index == prhs->value_index; |
| 138 | + } |
| 139 | + |
| 140 | + bool VisitExpr_(const CastNode* plhs, const PrimExpr& rhs) final { |
| 141 | + const auto* prhs = rhs.as<CastNode>(); |
| 142 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value); |
| 143 | + } |
| 144 | + |
| 145 | + bool VisitExpr_(const NotNode* plhs, const PrimExpr& rhs) final { |
| 146 | + const auto* prhs = rhs.as<NotNode>(); |
| 147 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a); |
| 148 | + } |
| 149 | + |
| 150 | + bool VisitExpr_(const SelectNode* plhs, const PrimExpr& rhs) final { |
| 151 | + const auto* prhs = rhs.as<SelectNode>(); |
| 152 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->condition, prhs->condition) && |
| 153 | + VisitExpr(plhs->true_value, prhs->true_value) && |
| 154 | + VisitExpr(plhs->false_value, prhs->false_value); |
| 155 | + } |
| 156 | + |
| 157 | + bool VisitExpr_(const RampNode* plhs, const PrimExpr& rhs) final { |
| 158 | + const auto* prhs = rhs.as<RampNode>(); |
| 159 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->base, prhs->base) && |
| 160 | + VisitExpr(plhs->stride, prhs->stride) && VisitExpr(plhs->lanes, prhs->lanes); |
| 161 | + } |
| 162 | + |
| 163 | + bool VisitExpr_(const ShuffleNode* plhs, const PrimExpr& rhs) final { |
| 164 | + const auto* prhs = rhs.as<ShuffleNode>(); |
| 165 | + return plhs->dtype == prhs->dtype && ArrayDeepEqual(plhs->vectors, prhs->vectors) && |
| 166 | + ArrayDeepEqual(plhs->indices, prhs->indices); |
| 167 | + } |
| 168 | + |
| 169 | + bool VisitExpr_(const BroadcastNode* plhs, const PrimExpr& rhs) final { |
| 170 | + const auto* prhs = rhs.as<BroadcastNode>(); |
| 171 | + return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value) && |
| 172 | + VisitExpr(plhs->lanes, prhs->lanes); |
| 173 | + } |
| 174 | + |
| 175 | + DEFINE_DEEP_EQUAL_BIN_EXPR(AddNode) |
| 176 | + DEFINE_DEEP_EQUAL_BIN_EXPR(SubNode) |
| 177 | + DEFINE_DEEP_EQUAL_BIN_EXPR(MulNode) |
| 178 | + DEFINE_DEEP_EQUAL_BIN_EXPR(DivNode) |
| 179 | + DEFINE_DEEP_EQUAL_BIN_EXPR(ModNode) |
| 180 | + DEFINE_DEEP_EQUAL_BIN_EXPR(FloorDivNode) |
| 181 | + DEFINE_DEEP_EQUAL_BIN_EXPR(FloorModNode) |
| 182 | + DEFINE_DEEP_EQUAL_BIN_EXPR(MinNode) |
| 183 | + DEFINE_DEEP_EQUAL_BIN_EXPR(MaxNode) |
| 184 | + DEFINE_DEEP_EQUAL_BIN_EXPR(EQNode) |
| 185 | + DEFINE_DEEP_EQUAL_BIN_EXPR(NENode) |
| 186 | + DEFINE_DEEP_EQUAL_BIN_EXPR(LTNode) |
| 187 | + DEFINE_DEEP_EQUAL_BIN_EXPR(LENode) |
| 188 | + DEFINE_DEEP_EQUAL_BIN_EXPR(GTNode) |
| 189 | + DEFINE_DEEP_EQUAL_BIN_EXPR(GENode) |
| 190 | + DEFINE_DEEP_EQUAL_BIN_EXPR(AndNode) |
| 191 | + DEFINE_DEEP_EQUAL_BIN_EXPR(OrNode) |
| 192 | + DEFINE_DEEP_EQUAL_IMM_EXPR(IntImmNode) |
| 193 | + DEFINE_DEEP_EQUAL_IMM_EXPR(FloatImmNode) |
| 194 | + DEFINE_DEEP_EQUAL_IMM_EXPR(StringImmNode) |
57 | 195 | }; |
58 | 196 |
|
59 | 197 | bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { |
60 | | - // quick path |
61 | | - if (lhs.same_as(rhs)) return true; |
62 | | - if (!lhs.defined() && rhs.defined()) return false; |
63 | | - if (!rhs.defined() && lhs.defined()) return false; |
64 | | - if (lhs->type_index() != rhs->type_index()) return false; |
65 | | - if (auto* plhs = lhs.as<IntImmNode>()) { |
66 | | - auto* prhs = rhs.as<IntImmNode>(); |
67 | | - return plhs->dtype == prhs->dtype && plhs->value == prhs->value; |
68 | | - } |
69 | | - return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt); |
| 198 | + return ExprDeepEqualChecker::Check(lhs, rhs); |
70 | 199 | } |
71 | 200 |
|
72 | 201 | TVM_FFI_STATIC_INIT_BLOCK({ |
|
0 commit comments