From 9596f97938427f8516223814915cdcba8021c648 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 25 Jul 2025 10:46:15 -0400 Subject: [PATCH] [FFI][REFACTOR] Migrate StructuralEqual/Hash to new reflection This PR migrates the StructuralEqual/Hash to new reflection based approach. The original mechanisms are still kept around and we will phase them out in followup PRs. The new mechanism unifies the structural equal/hash registration with the normal reflection registeration and also brings cleaner implementation for mismatch detection. --- ffi/include/tvm/ffi/c_api.h | 21 ---- ffi/src/ffi/reflection/structural_equal.cc | 28 +++-- ffi/src/ffi/reflection/structural_hash.cc | 51 +++++--- ffi/tests/cpp/testing_object.h | 11 +- include/tvm/arith/analyzer.h | 2 + include/tvm/arith/int_solver.h | 3 + include/tvm/arith/iter_affine_map.h | 3 + include/tvm/ir/attrs.h | 3 +- include/tvm/ir/diagnostic.h | 2 + include/tvm/ir/env_func.h | 6 +- include/tvm/ir/expr.h | 20 +++- include/tvm/ir/global_info.h | 2 + include/tvm/ir/module.h | 10 ++ include/tvm/ir/op.h | 13 +- include/tvm/ir/source_map.h | 3 + include/tvm/ir/type.h | 8 ++ include/tvm/relax/distributed/struct_info.h | 2 + include/tvm/relax/expr.h | 66 +++++++++-- include/tvm/relax/struct_info.h | 2 +- include/tvm/target/target.h | 2 +- include/tvm/target/target_kind.h | 11 +- include/tvm/te/tensor.h | 1 + include/tvm/tir/buffer.h | 14 ++- include/tvm/tir/expr.h | 9 +- include/tvm/tir/function.h | 11 +- include/tvm/tir/index_map.h | 7 +- include/tvm/tir/stmt.h | 37 +++--- include/tvm/tir/var.h | 6 +- src/contrib/msc/core/ir/graph.h | 4 + src/contrib/msc/core/ir/plugin.h | 4 + src/ir/module.cc | 52 ++++++++ src/ir/type.cc | 7 +- src/meta_schedule/module_equality.cc | 33 ++---- src/node/structural_equal.cc | 111 +++++++++++++++--- src/node/structural_hash.cc | 6 +- src/relax/ir/expr.cc | 28 +++++ src/relax/ir/struct_info.cc | 1 + src/relax/transform/lift_transform_params.cc | 7 +- tests/python/ir/test_node_reflection.py | 9 ++ ...test_tvmscript_printer_structural_equal.py | 10 ++ 40 files changed, 462 insertions(+), 164 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index e2de610a5df7..60743b82c67e 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -424,27 +424,6 @@ typedef enum { * is only an unique copy of each value. */ kTVMFFISEqHashKindUniqueInstance = 5, - /*! - * \brief provide custom __s_equal__ and __s_hash__ functions through TypeAttrColumn. - * - * The function signatures are(defined via ffi::Function) - * - * \code - * bool __s_equal__( - * ObjectRefType self, ObjectRefType other, - * ffi::TypedFunction cmp, - * ); - * - * uint64_t __s_hash__( - * ObjectRefType self, uint64_t type_key_hash, - * ffi::TypedFunction hash - * ); - * \endcode - * - * Where the extra string field in cmp is the name of the field that is being compared. - * The function should be registered through TVMFFITypeRegisterAttr via reflection::TypeAttrDef. - */ - kTVMFFISEqHashKindCustomTreeNode = 6, #ifdef __cplusplus }; #else diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc index 03cbdd95bee9..e44a0c3256f3 100644 --- a/ffi/src/ffi/reflection/structural_equal.cc +++ b/ffi/src/ffi/reflection/structural_equal.cc @@ -29,6 +29,7 @@ #include #include +#include #include namespace tvm { @@ -49,7 +50,12 @@ class StructEqualHandler { if (lhs_data->type_index != rhs_data->type_index) { return false; } + if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + // specially handle nan for float, as there can be multiple representations of nan + if (lhs_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(lhs_data->v_float64)) { + return std::isnan(rhs_data->v_float64); + } // this is POD data, we can just compare the value return lhs_data->v_int64 == rhs_data->v_int64; } @@ -90,12 +96,18 @@ class StructEqualHandler { // NOTE: invariant: lhs and rhs are already the same type const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index()); if (type_info->metadata == nullptr) { - return lhs.same_as(rhs); + TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" + << String(type_info->type_key) + << "`, so StructuralHash is not supported for this type"; + } + if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { + TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `" + << String(type_info->type_key) + << "`, so StructuralHash is not supported for this type"; } - auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; - if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported || - structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) { + auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; + if (structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) { // use pointer comparison return lhs.same_as(rhs); } @@ -118,8 +130,10 @@ class StructEqualHandler { } } + static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__"); + bool success = true; - if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) { + if (custom_s_equal[type_info->type_index] == nullptr) { // We recursively compare the fields the object ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { // skip fields that are marked as structural eq hash ignore @@ -153,7 +167,6 @@ class StructEqualHandler { } }); } else { - static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__"); // run custom equal function defined via __s_equal__ type attribute if (s_equal_callback_ == nullptr) { s_equal_callback_ = ffi::Function::FromTyped( @@ -179,9 +192,6 @@ class StructEqualHandler { return success; }); } - TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr) - << "TypeAttr `__s_equal__` is not registered for type `" << String(type_info->type_key) - << "`"; success = custom_s_equal[type_info->type_index] .cast()(lhs, rhs, s_equal_callback_) .cast(); diff --git a/ffi/src/ffi/reflection/structural_hash.cc b/ffi/src/ffi/reflection/structural_hash.cc index ba47de5146d4..e8ffcf6d2a72 100644 --- a/ffi/src/ffi/reflection/structural_hash.cc +++ b/ffi/src/ffi/reflection/structural_hash.cc @@ -30,6 +30,8 @@ #include #include +#include +#include #include #include @@ -48,6 +50,13 @@ class StructuralHashHandler { const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + // specially handle nan for float, as there can be multiple representations of nan + // make sure they map to the same hash value + if (src_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(src_data->v_float64)) { + TVMFFIAny temp = *src_data; + temp.v_float64 = std::numeric_limits::quiet_NaN(); + return details::StableHashCombine(temp.type_index, temp.v_uint64); + } // this is POD data, we can just hash the value return details::StableHashCombine(src_data->type_index, src_data->v_uint64); } @@ -83,9 +92,16 @@ class StructuralHashHandler { // NOTE: invariant: lhs and rhs are already the same type const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); if (type_info->metadata == nullptr) { - // Fallback to pointer hash - return std::hash()(obj.get()); + TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" + << String(type_info->type_key) + << "`, so StructuralHash is not supported for this type"; } + if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { + TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `" + << String(type_info->type_key) + << "`, so StructuralHash is not supported for this type"; + } + auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { // Fallback to pointer hash @@ -97,9 +113,11 @@ class StructuralHashHandler { return it->second; } + static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__"); + // compute the hash value uint64_t hash_value = obj->GetTypeKeyHash(); - if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) { + if (custom_s_hash[type_info->type_index] == nullptr) { // go over the content and hash the fields ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { // skip fields that are marked as structural eq hash ignore @@ -119,22 +137,19 @@ class StructuralHashHandler { } }); } else { - static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__"); - TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr) - << "TypeAttr `__s_hash__` is not registered for type `" << String(type_info->type_key) - << "`"; if (s_hash_callback_ == nullptr) { - s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool def_region) { - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - uint64_t hash_value = HashAny(val); - std::swap(allow_free_var, map_free_vars_); - return hash_value; - } else { - return HashAny(val); - } - }); + s_hash_callback_ = + ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) { + if (def_region) { + bool allow_free_var = true; + std::swap(allow_free_var, map_free_vars_); + uint64_t hash_value = HashAny(val); + std::swap(allow_free_var, map_free_vars_); + return details::StableHashCombine(init_hash, hash_value); + } else { + return details::StableHashCombine(init_hash, HashAny(val)); + } + }); } hash_value = custom_s_hash[type_info->type_index] .cast()(obj, hash_value, s_hash_callback_) diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index 63c2b42d4f77..3d8b4b23ed7c 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -227,10 +227,11 @@ class TCustomFuncObj : public Object { return true; } - uint64_t SHash(uint64_t type_key_hash, ffi::TypedFunction hash) const { - uint64_t hash_value = type_key_hash; - hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(params, true)); - hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(body, false)); + uint64_t SHash(uint64_t init_hash, + ffi::TypedFunction hash) const { + uint64_t hash_value = init_hash; + hash_value = hash(params, hash_value, true); + hash_value = hash(body, hash_value, false); return hash_value; } @@ -246,7 +247,7 @@ class TCustomFuncObj : public Object { } static constexpr const char* _type_key = "test.CustomFunc"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindCustomTreeNode; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object); }; diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 78eac07f4552..54cbab258680 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -106,6 +106,7 @@ class ConstIntBoundNode : public Object { */ static const constexpr int64_t kNegInf = -kPosInf; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.ConstIntBound"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object); }; @@ -222,6 +223,7 @@ class ModularSetNode : public Object { return equal(coeff, other->coeff) && equal(base, other->base); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.ModularSet"; TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); }; diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index dd9259cf97cb..e2f384b696ac 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -83,6 +83,7 @@ class IntGroupBoundsNode : public Object { hash_reduce(upper); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntGroupBounds"; TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); @@ -173,6 +174,7 @@ class IntConstraintsNode : public Object { hash_reduce(relations); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraints"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); @@ -238,6 +240,7 @@ class IntConstraintsTransformNode : public Object { hash_reduce(dst_to_src); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraintsTransform"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index b7f0e09e8323..3c666b430f13 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -116,6 +116,7 @@ class IterMarkNode : public Object { hash_reduce(extent); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const char* _type_key = "arith.IterMark"; @@ -176,6 +177,7 @@ class IterSplitExprNode : public IterMapExprNode { hash_reduce(scale); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.IterSplitExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode); }; @@ -239,6 +241,7 @@ class IterSumExprNode : public IterMapExprNode { hash_reduce(base); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.IterSumExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode); }; diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 952cea2a3021..6a43274cae46 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -83,7 +83,7 @@ class AttrFieldInfoNode : public Object { } static constexpr const char* _type_key = "ir.AttrFieldInfo"; - + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr bool _type_has_method_sequal_reduce = false; static constexpr bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); @@ -122,6 +122,7 @@ class BaseAttrsNode : public Object { TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs, bool allow_unknown = false) = 0; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const char* _type_key = "ir.Attrs"; diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 8429ac1a6214..e1d7abbead15 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -79,6 +79,7 @@ class DiagnosticNode : public Object { equal(this->message, other->message); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "Diagnostic"; TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object); }; @@ -214,6 +215,7 @@ class DiagnosticContextNode : public Object { return equal(module, other->module) && equal(diagnostics, other->diagnostics); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "DiagnosticContext"; TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object); }; diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 03cf3d625ad6..c1fdeb6d1c48 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -51,7 +51,10 @@ class EnvFuncNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name", &EnvFuncNode::name); + // func do not participate in structural equal and hash. + refl::ObjectDef() + .def_ro("name", &EnvFuncNode::name) + .def_ro("func", &EnvFuncNode::func, refl::AttachFieldFlag::SEqHashIgnore()); } bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { @@ -64,6 +67,7 @@ class EnvFuncNode : public Object { hash_reduce(name); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.EnvFunc"; static constexpr bool _type_has_method_sequal_reduce = true; static constexpr bool _type_has_method_shash_reduce = true; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 9a8e290cb9bd..cb62cbadf5bb 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -58,11 +58,14 @@ class BaseExprNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span())); + // span do not participate in structural equal and hash. + refl::ObjectDef().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()); } static constexpr const char* _type_key = "ir.BaseExpr"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 64; @@ -428,7 +431,8 @@ class RelaxExprNode : public BaseExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("struct_info_", &RelaxExprNode::struct_info_); + refl::ObjectDef().def_ro("struct_info_", &RelaxExprNode::struct_info_, + refl::AttachFieldFlag::SEqHashIgnore()); } static constexpr const char* _type_key = "ir.RelaxExpr"; @@ -474,6 +478,17 @@ class GlobalVarNode : public RelaxExprNode { hash_reduce.FreeVarHashImpl(this); } + bool SEqual(const GlobalVarNode* other, + ffi::TypedFunction equal) const { + return equal(name_hint, other->name_hint, false, "name_hint"); + } + + uint64_t SHash(uint64_t init_hash, + ffi::TypedFunction hash) const { + return hash(name_hint, init_hash, false); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const char* _type_key = "ir.GlobalVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode); }; @@ -711,6 +726,7 @@ class RangeNode : public Object { } static constexpr const char* _type_key = "ir.Range"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 4a0b9ffdae25..57eadf2b2992 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -43,6 +43,8 @@ using MemoryScope = String; class GlobalInfoNode : public Object { public: static constexpr const char* _type_key = "ir.GlobalInfo"; + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index a5c2477b8dd4..66c26b0629ba 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -138,12 +138,21 @@ class IRModuleNode : public Object { .def_ro("source_map", &IRModuleNode::source_map) .def_ro("attrs", &IRModuleNode::attrs) .def_ro("global_infos", &IRModuleNode::global_infos); + // register custom structural equal and hash. + refl::TypeAttrDef() + .def("__s_equal__", &IRModuleNode::SEqual) + .def("__s_hash__", &IRModuleNode::SHash); } TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; TVM_DLL void SHashReduce(SHashReducer hash_reduce) const; + TVM_DLL bool SEqual(const IRModuleNode* other, + ffi::TypedFunction equal) const; + TVM_DLL uint64_t SHash(uint64_t init_hash, + ffi::TypedFunction hash) const; + /*! * \brief Add a function to the global environment. * \param var The var of the global function. @@ -237,6 +246,7 @@ class IRModuleNode : public Object { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "ir.IRModule"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index a50f12b1678d..5903bed8d92e 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -95,12 +95,12 @@ class OpNode : public RelaxExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("name", &OpNode::name) - .def_ro("op_type", &OpNode::op_type) - .def_ro("description", &OpNode::description) - .def_ro("arguments", &OpNode::arguments) - .def_ro("attrs_type_key", &OpNode::attrs_type_key) - .def_ro("num_inputs", &OpNode::num_inputs) - .def_ro("support_level", &OpNode::support_level); + .def_ro("op_type", &OpNode::op_type, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("description", &OpNode::description, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("arguments", &OpNode::arguments, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("attrs_type_key", &OpNode::attrs_type_key, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("num_inputs", &OpNode::num_inputs, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore()); } bool SEqualReduce(const OpNode* other, SEqualReducer equal) const { @@ -113,6 +113,7 @@ class OpNode : public RelaxExprNode { hash_reduce(name); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance; static constexpr const char* _type_key = "ir.Op"; TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode); diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 87b353278619..d53c234690e2 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -59,6 +59,7 @@ class SourceNameNode : public Object { return equal(name, other->name); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.SourceName"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); }; @@ -118,6 +119,7 @@ class SpanNode : public Object { equal(end_column, other->end_column); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.Span"; TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object); }; @@ -233,6 +235,7 @@ class SourceMapObj : public Object { return equal(source_map, other->source_map); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.SourceMap"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object); }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index f879ab59117b..a2ab74a3aeb1 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -80,6 +80,14 @@ class TypeNode : public Object { */ mutable Span span; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // span do not participate in structural equal and hash. + refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.Type"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index 3b4a4a0d1d2a..7f843a9f2c75 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -61,6 +61,7 @@ class PlacementSpecNode : public Object { } static constexpr const char* _type_key = "relax.distributed.PlacementSpec"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(PlacementSpecNode, Object); @@ -119,6 +120,7 @@ class PlacementNode : public Object { static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; static constexpr const char* _type_key = "relax.distributed.Placement"; TVM_DECLARE_FINAL_OBJECT_INFO(PlacementNode, Object); }; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 34aea7981d76..06aba8618b66 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -57,7 +57,8 @@ class IdNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name_hint", &IdNode::name_hint); + refl::ObjectDef().def_ro("name_hint", &IdNode::name_hint, + refl::AttachFieldFlag::SEqHashIgnore()); } bool SEqualReduce(const IdNode* other, SEqualReducer equal) const { @@ -66,6 +67,7 @@ class IdNode : public Object { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const char* _type_key = "relax.Id"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -120,6 +122,13 @@ class StructInfoNode : public Object { */ mutable Span span; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("span", &StructInfoNode::span, + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.StructInfo"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -397,6 +406,10 @@ class VarNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("vid", &VarNode::vid); + // customize structural equal and hash to include struct_info_ + refl::TypeAttrDef() + .def("__s_equal__", &VarNode::SEqual) + .def("__s_hash__", &VarNode::SHash); } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { @@ -409,6 +422,21 @@ class VarNode : public LeafExprNode { hash_reduce(struct_info_); } + bool SEqual(const VarNode* other, + ffi::TypedFunction equal) const { + return equal(vid, other->vid, false, "vid") && + equal(struct_info_, other->struct_info_, false, "struct_info_"); + } + + uint64_t SHash(uint64_t init_hash, + ffi::TypedFunction hash) const { + uint64_t hash_value = init_hash; + hash_value = hash(vid, hash_value, false); + hash_value = hash(struct_info_, hash_value, false); + return hash_value; + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.Var"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -448,6 +476,7 @@ class DataflowVarNode : public VarNode { hash_reduce(struct_info_); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.DataflowVar"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -655,18 +684,19 @@ class DataTypeImm : public LeafExpr { /*! \brief The base class of a variable binding in Relax. */ class BindingNode : public Object { public: + mutable Span span; /*! \brief The return variable to bound to. */ Var var; - mutable Span span; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("var", &BindingNode::var) - .def_ro("span", &BindingNode::span); + .def_ro("span", &BindingNode::span, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef()); } static constexpr const char* _type_key = "relax.expr.Binding"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); @@ -701,9 +731,8 @@ class MatchCastNode : public BindingNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("var", &MatchCastNode::var) .def_ro("value", &MatchCastNode::value) - .def_ro("struct_info", &MatchCastNode::struct_info); + .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef()); } bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const; @@ -734,14 +763,21 @@ class VarBindingNode : public BindingNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("var", &VarBindingNode::var) - .def_ro("value", &VarBindingNode::value); + refl::ObjectDef().def_ro("value", &VarBindingNode::value); + // customize the SEqual and SHash methods for better error messages + refl::TypeAttrDef() + .def("__s_equal__", &VarBindingNode::SEqual) + .def("__s_hash__", &VarBindingNode::SHash); } bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const; void SHashReduce(SHashReducer hash_reduce) const; + bool SEqual(const VarBindingNode* other, + ffi::TypedFunction equal) const; + uint64_t SHash(uint64_t init_hash, + ffi::TypedFunction hash) const; + static constexpr const char* _type_key = "relax.expr.VarBinding"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -757,12 +793,15 @@ class VarBinding : public Binding { class BindingBlockNode : public Object { public: - mutable Span span; Array bindings; + mutable Span span; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("bindings", &BindingBlockNode::bindings); + refl::ObjectDef() + .def_ro("bindings", &BindingBlockNode::bindings) + .def_ro("span", &BindingBlockNode::span, refl::AttachFieldFlag::SEqHashIgnore(), + refl::DefaultValue(Span())); } bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { @@ -771,6 +810,7 @@ class BindingBlockNode : public Object { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "relax.expr.BindingBlock"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -906,6 +946,7 @@ class IfNode : public ExprNode { hash_reduce(struct_info_); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.If"; TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); }; @@ -960,7 +1001,7 @@ class FunctionNode : public BaseFuncNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("params", &FunctionNode::params) + .def_ro("params", &FunctionNode::params, refl::AttachFieldFlag::SEqHashDef()) .def_ro("body", &FunctionNode::body) .def_ro("ret_struct_info", &FunctionNode::ret_struct_info) .def_ro("is_pure", &FunctionNode::is_pure); @@ -983,6 +1024,7 @@ class FunctionNode : public BaseFuncNode { hash_reduce(struct_info_); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.Function"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 25a6b1ef4afc..cd9b05ab29f0 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -341,7 +341,7 @@ class FuncStructInfoNode : public StructInfoNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("params", &FuncStructInfoNode::params) + .def_ro("params", &FuncStructInfoNode::params, refl::AttachFieldFlag::SEqHashDef()) .def_ro("ret", &FuncStructInfoNode::ret) .def_ro("derive_func", &FuncStructInfoNode::derive_func) .def_ro("purity", &FuncStructInfoNode::purity); diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 98c984c1575f..9929791f31d3 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -177,7 +177,7 @@ class TargetNode : public Object { void SHashReduce(SHashReducer hash_reduce) const; static constexpr const char* _type_key = "target.Target"; - + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index c9785820b40e..15b5f62cd566 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -81,12 +81,14 @@ class TargetKindNode : public Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("name", &TargetKindNode::name) - .def_ro("default_device_type", &TargetKindNode::default_device_type) - .def_ro("default_keys", &TargetKindNode::default_keys); + .def_ro("default_device_type", &TargetKindNode::default_device_type, + refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("default_keys", &TargetKindNode::default_keys, + refl::AttachFieldFlag::SEqHashIgnore()); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance; static constexpr const char* _type_key = "target.TargetKind"; - TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); private: @@ -134,10 +136,11 @@ class TargetKind : public ObjectRef { * \return The TargetKind requested */ TVM_DLL static Optional Get(const String& target_kind_name); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode); /*! \brief Mutable access to the container class */ TargetKindNode* operator->() { return static_cast(data_.get()); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode); + private: TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer( const String& attr_name); diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 15be5bb069e8..f45a96df63d8 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -88,6 +88,7 @@ class TensorNode : public DataProducerNode { TVM_DLL String GetNameHint() const final; static constexpr const char* _type_key = "te.Tensor"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); }; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 6b49f619b5e8..cb16d2912aa0 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -115,13 +115,14 @@ class BufferNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("data", &BufferNode::data) + .def_ro("data", &BufferNode::data, refl::AttachFieldFlag::SEqHashDef()) .def_ro("dtype", &BufferNode::dtype) - .def_ro("shape", &BufferNode::shape) - .def_ro("strides", &BufferNode::strides) - .def_ro("axis_separators", &BufferNode::axis_separators) - .def_ro("elem_offset", &BufferNode::elem_offset) - .def_ro("name", &BufferNode::name) + .def_ro("shape", &BufferNode::shape, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("strides", &BufferNode::strides, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("axis_separators", &BufferNode::axis_separators, + refl::AttachFieldFlag::SEqHashDef()) + .def_ro("elem_offset", &BufferNode::elem_offset, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("name", &BufferNode::name, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("data_alignment", &BufferNode::data_alignment) .def_ro("offset_factor", &BufferNode::offset_factor) .def_ro("buffer_type", &BufferNode::buffer_type) @@ -163,6 +164,7 @@ class BufferNode : public Object { Array ElemOffset(Array index) const; static constexpr const char* _type_key = "tir.Buffer"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 9525f88784ff..3e6a07a6cd6b 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -833,7 +833,7 @@ class LetNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("var", &LetNode::var) + .def_ro("var", &LetNode::var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("value", &LetNode::value) .def_ro("body", &LetNode::body); } @@ -989,11 +989,11 @@ class CommReducerNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("lhs", &CommReducerNode::lhs) - .def_ro("rhs", &CommReducerNode::rhs) + .def_ro("lhs", &CommReducerNode::lhs, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("rhs", &CommReducerNode::rhs, refl::AttachFieldFlag::SEqHashDef()) .def_ro("result", &CommReducerNode::result) .def_ro("identity_element", &CommReducerNode::identity_element) - .def_ro("span", &CommReducerNode::span); + .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore()); } bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { @@ -1009,6 +1009,7 @@ class CommReducerNode : public Object { } static constexpr const char* _type_key = "tir.CommReducer"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index ead03e967663..2671f9879101 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -49,8 +49,6 @@ class PrimFuncNode : public BaseFuncNode { public: /*! \brief Function parameters */ Array params; - /*! \brief The body of the function */ - tir::Stmt body; /*! \brief The return type of the function. */ Type ret_type; /*! @@ -99,14 +97,16 @@ class PrimFuncNode : public BaseFuncNode { * flattened alias of the buffer. */ Map buffer_map; + /*! \brief The body of the function */ + tir::Stmt body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("params", &PrimFuncNode::params) - .def_ro("body", &PrimFuncNode::body) + .def_ro("params", &PrimFuncNode::params, refl::AttachFieldFlag::SEqHashDef()) .def_ro("ret_type", &PrimFuncNode::ret_type) - .def_ro("buffer_map", &PrimFuncNode::buffer_map); + .def_ro("buffer_map", &PrimFuncNode::buffer_map) + .def_ro("body", &PrimFuncNode::body); } bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { @@ -123,6 +123,7 @@ class PrimFuncNode : public BaseFuncNode { hash_reduce(body); hash_reduce(attrs); } + /*! * \brief Return the derived function annotation of this function. * diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 1cc6fa950ae8..55d083834dc9 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -154,9 +154,11 @@ class IndexMapNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("initial_indices", &IndexMapNode::initial_indices) + .def_ro("initial_indices", &IndexMapNode::initial_indices, + refl::AttachFieldFlag::SEqHashDef()) .def_ro("final_indices", &IndexMapNode::final_indices) - .def_ro("inverse_index_map", &IndexMapNode::inverse_index_map); + .def_ro("inverse_index_map", &IndexMapNode::inverse_index_map, + refl::AttachFieldFlag::SEqHashIgnore()); } bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const { @@ -169,6 +171,7 @@ class IndexMapNode : public Object { hash_reduce(final_indices); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "tir.IndexMap"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index b89fff003215..9d31d25c398d 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -54,6 +54,7 @@ class StmtNode : public Object { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "tir.Stmt"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 15; @@ -81,7 +82,7 @@ class LetStmtNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("var", &LetStmtNode::var) + .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("value", &LetStmtNode::value) .def_ro("body", &LetStmtNode::body); } @@ -371,7 +372,7 @@ class AllocateNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("buffer_var", &AllocateNode::buffer_var) + .def_ro("buffer_var", &AllocateNode::buffer_var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("dtype", &AllocateNode::dtype) .def_ro("extents", &AllocateNode::extents) .def_ro("condition", &AllocateNode::condition) @@ -460,7 +461,7 @@ class AllocateConstNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("buffer_var", &AllocateConstNode::buffer_var) + .def_ro("buffer_var", &AllocateConstNode::buffer_var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("data", &AllocateConstNode::data) .def_ro("irmod_storage_idx", &AllocateConstNode::irmod_storage_idx) .def_ro("dtype", &AllocateConstNode::dtype) @@ -896,7 +897,7 @@ class ForNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("loop_var", &ForNode::loop_var) + .def_ro("loop_var", &ForNode::loop_var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("min", &ForNode::min) .def_ro("extent", &ForNode::extent) .def_ro("kind", &ForNode::kind) @@ -1017,6 +1018,7 @@ class BufferRegionNode : public PrimExprConvertibleNode { TVM_DLL PrimExpr ToPrimExpr() const final; static constexpr const char* _type_key = "tir.BufferRegion"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, PrimExprConvertibleNode); @@ -1082,6 +1084,7 @@ class MatchBufferRegionNode : public Object { } static constexpr const char* _type_key = "tir.MatchBufferRegion"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object); @@ -1130,8 +1133,12 @@ class BlockNode : public StmtNode { Array writes; /*! \brief The name_hint of the block. */ String name_hint; - /*! \brief The body of the block. */ - Stmt body; + /*! \brief The buffer allocated in the block. */ + Array alloc_buffers; + /*! \brief The match buffer regions. */ + Array match_buffers; + /*! \brief The annotation of the block. */ + Map annotations; /*! * \brief The init statement is executed during the first iteration of reduction loops in a * reduction block. The optional init field allows us to represent initialization and @@ -1140,25 +1147,21 @@ class BlockNode : public StmtNode { * Init field is `std::nullopt` if there is no reduction iter_vars */ Optional init; - /*! \brief The buffer allocated in the block. */ - Array alloc_buffers; - /*! \brief The match buffer regions. */ - Array match_buffers; - /*! \brief The annotation of the block. */ - Map annotations; + /*! \brief The body of the block. */ + Stmt body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("iter_vars", &BlockNode::iter_vars) + .def_ro("iter_vars", &BlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef()) .def_ro("reads", &BlockNode::reads) .def_ro("writes", &BlockNode::writes) - .def_ro("name_hint", &BlockNode::name_hint) - .def_ro("body", &BlockNode::body) - .def_ro("init", &BlockNode::init) + .def_ro("name_hint", &BlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("alloc_buffers", &BlockNode::alloc_buffers) .def_ro("match_buffers", &BlockNode::match_buffers) - .def_ro("annotations", &BlockNode::annotations); + .def_ro("annotations", &BlockNode::annotations) + .def_ro("init", &BlockNode::init) + .def_ro("body", &BlockNode::body); } bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const { diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 24c7e6944d04..021b6c301a68 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -64,7 +64,7 @@ class VarNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("name", &VarNode::name_hint) + .def_ro("name", &VarNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("type_annotation", &VarNode::type_annotation); } @@ -80,6 +80,7 @@ class VarNode : public PrimExprNode { hash_reduce.FreeVarHashImpl(this); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const char* _type_key = "tir.Var"; static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); @@ -290,7 +291,7 @@ class IterVarNode : public PrimExprConvertibleNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("dom", &IterVarNode::dom) - .def_ro("var", &IterVarNode::var) + .def_ro("var", &IterVarNode::var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("iter_type", &IterVarNode::iter_type) .def_ro("thread_tag", &IterVarNode::thread_tag); } @@ -308,6 +309,7 @@ class IterVarNode : public PrimExprConvertibleNode { } static constexpr const char* _type_key = "tir.IterVar"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, PrimExprConvertibleNode); diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 2ee4aeb05b3f..1d6b1046d9e1 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -400,6 +400,7 @@ class MSCTensorNode : public Object { hash_reduce(prims); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.MSCTensor"; TVM_DECLARE_FINAL_OBJECT_INFO(MSCTensorNode, Object); }; @@ -514,6 +515,7 @@ class BaseJointNode : public Object { hash_reduce(children); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.BaseJoint"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -600,6 +602,7 @@ class MSCJointNode : public BaseJointNode { hash_reduce(weights); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.MSCJoint"; TVM_DECLARE_FINAL_OBJECT_INFO(MSCJointNode, BaseJointNode); }; @@ -833,6 +836,7 @@ class BaseGraphNode : public Object { hash_reduce(node_names); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.BaseGraph"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h index b926edf8e5af..291a0e196a24 100644 --- a/src/contrib/msc/core/ir/plugin.h +++ b/src/contrib/msc/core/ir/plugin.h @@ -290,6 +290,7 @@ class PluginAttrNode : public Object { hash_reduce(describe); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.PluginAttr"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginAttrNode, Object); }; @@ -371,6 +372,7 @@ class PluginTensorNode : public Object { hash_reduce(describe); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.PluginTensor"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginTensorNode, Object); }; @@ -454,6 +456,7 @@ class PluginExternNode : public Object { hash_reduce(describe); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.PluginExtern"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginExternNode, Object); }; @@ -565,6 +568,7 @@ class PluginNode : public Object { hash_reduce(options); } + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.Plugin"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginNode, Object); }; diff --git a/src/ir/module.cc b/src/ir/module.cc index 9eedbd5e303f..f17874724676 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -96,6 +96,30 @@ bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) return true; } +bool IRModuleNode::SEqual(const IRModuleNode* other, + ffi::TypedFunction equal) const { + if (!equal(this->attrs, other->attrs, false, "attrs")) { + return false; + } + if (!equal(this->global_infos, other->global_infos, false, "global_infos")) { + return false; + } + + // Define remaps for GlobalVar and GlobalTypeVar based on their string name. + for (const auto& gv : this->GetGlobalVars()) { + if (other->ContainGlobalVar(gv->name_hint)) { + if (!equal(gv, other->GetGlobalVar(gv->name_hint), true, "functions")) return false; + } + } + + // now check the functions with the GlobalVar remappped + if (!equal(this->functions, other->functions, false, "functions")) { + return false; + } + + return true; +} + void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { using KV = std::tuple; // hash the functions. @@ -127,6 +151,34 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { hash_reduce(this->global_infos); } +uint64_t IRModuleNode::SHash(uint64_t init_hash, + ffi::TypedFunction hash) const { + uint64_t hash_value = init_hash; + hash_value = hash(this->attrs, hash_value, false); + hash_value = hash(this->global_infos, hash_value, false); + + // hash the functions. + using KV = std::tuple; + std::vector temp; + for (const auto& kv : this->functions) { + temp.emplace_back(kv.first->name_hint, kv.first, kv.second); + } + // sort by the hash key of the keys. + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) < std::get<0>(rhs); }); + hash_value = hash(static_cast(temp.size()), hash_value, false); + // first need to define the GlobalVar in the order of the keys + for (size_t i = 0; i < temp.size(); ++i) { + hash_value = hash(std::get<1>(temp[i]), hash_value, true); + } + // hash the name and content + for (size_t i = 0; i < temp.size(); ++i) { + hash_value = hash(std::get<0>(temp[i]), hash_value, false); + hash_value = hash(std::get<2>(temp[i]), hash_value, false); + } + return hash_value; +} + bool IRModuleNode::ContainGlobalVar(const String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } diff --git a/src/ir/type.cc b/src/ir/type.cc index 4e580356ff7c..37b251f1f949 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -27,6 +27,7 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK({ + TypeNode::RegisterReflection(); PrimTypeNode::RegisterReflection(); PointerTypeNode::RegisterReflection(); TupleTypeNode::RegisterReflection(); @@ -50,8 +51,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ PointerType::PointerType(Type element_type, String storage_scope) { ObjectPtr n = make_object(); + if (storage_scope.empty()) { + n->storage_scope = "global"; + } else { + n->storage_scope = std::move(storage_scope); + } n->element_type = std::move(element_type); - n->storage_scope = std::move(storage_scope); data_ = std::move(n); } diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 986233ca490f..8901d5fd8d57 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -18,6 +18,8 @@ */ #include "module_equality.h" +#include +#include #include #include #include @@ -37,28 +39,15 @@ class ModuleEqualityStructural : public ModuleEquality { String GetName() const { return "structural"; } }; -class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault { - public: - SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {} - - protected: - bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional& current_paths) { - if (auto lhs_ptr = lhs.as(), - rhs_ptr = rhs.as(); - lhs_ptr && rhs_ptr) { - SEqualReducer reducer(this, nullptr, map_free_vars); - return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, false); - } - return SEqualHandlerDefault::DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); - } -}; - class ModuleEqualityIgnoreNDArray : public ModuleEquality { public: - size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); } + size_t Hash(IRModule mod) const { + return tvm::ffi::reflection::StructuralHash::Hash(mod, /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); + } bool Equal(IRModule lhs, IRModule rhs) const { - return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false); + return tvm::ffi::reflection::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } String GetName() const { return "ignore-ndarray"; } }; @@ -77,8 +66,10 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { auto anchor_block_lhs = tir::FindAnchorBlock(lhs); auto anchor_block_rhs = tir::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return SEqualHandlerIgnoreNDArray().Equal(GetRef(anchor_block_lhs), - GetRef(anchor_block_rhs), false); + return tvm::ffi::reflection::StructuralEqual::Equal(GetRef(anchor_block_lhs), + GetRef(anchor_block_rhs), + /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 43dee2eb3b64..5987692a0f78 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -20,7 +20,9 @@ * \file src/node/structural_equal.cc */ #include +#include #include +#include #include #include #include @@ -599,34 +601,107 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); } +Optional ObjectPathPairFromAccessPathPair( + Optional src) { + if (!src.has_value()) return std::nullopt; + auto translate_path = [](ffi::reflection::AccessPath path) { + ObjectPath result = ObjectPath::Root(); + for (const auto& step : path) { + switch (step->kind) { + case ffi::reflection::AccessKind::kObjectField: { + result = result->Attr(step->key.cast()); + break; + } + case ffi::reflection::AccessKind::kArrayIndex: { + result = result->ArrayIndex(step->key.cast()); + break; + } + case ffi::reflection::AccessKind::kMapKey: { + result = result->MapValue(step->key); + break; + } + case ffi::reflection::AccessKind::kArrayIndexMissing: { + result = result->MissingArrayElement(step->key.cast()); + break; + } + case ffi::reflection::AccessKind::kMapKeyMissing: { + result = result->MissingMapEntry(); + break; + } + default: { + LOG(FATAL) << "Invalid access path kind: " << static_cast(step->kind); + break; + } + } + } + return result; + }; + + return ObjectPathPair(translate_path((*src).get<0>()), translate_path((*src).get<1>())); +} + +bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool assert_mode, + bool map_free_vars) { + if (assert_mode) { + auto first_mismatch = ObjectPathPairFromAccessPathPair( + ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); + if (first_mismatch.has_value()) { + std::ostringstream oss; + oss << "StructuralEqual check failed, caused by lhs"; + oss << " at " << (*first_mismatch)->lhs_path; + { + // print lhs + PrinterConfig cfg; + cfg->syntax_sugar = false; + cfg->path_to_underline.push_back((*first_mismatch)->lhs_path); + // The TVMScriptPrinter::Script will fallback to Repr printer, + // if the root node to print is not supported yet, + // e.g. Relax nodes, ArrayObj, MapObj, etc. + oss << ":" << std::endl << TVMScriptPrinter::Script(lhs.cast(), cfg); + } + oss << std::endl << "and rhs"; + { + // print rhs + oss << " at " << (*first_mismatch)->rhs_path; + { + PrinterConfig cfg; + cfg->syntax_sugar = false; + cfg->path_to_underline.push_back((*first_mismatch)->rhs_path); + // The TVMScriptPrinter::Script will fallback to Repr printer, + // if the root node to print is not supported yet, + // e.g. Relax nodes, ArrayObj, MapObj, etc. + oss << ":" << std::endl << TVMScriptPrinter::Script(rhs.cast(), cfg); + } + } + TVM_FFI_THROW(ValueError) << oss.str(); + } + return true; + } else { + return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_vars); + } +} + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("node.StructuralEqual", - [](const Any& lhs, const Any& rhs, bool assert_mode, bool map_free_vars) { - // If we are asserting on failure, then the `defer_fails` option - // should be enabled, to provide better error messages. For - // example, if the number of bindings in a `relax::BindingBlock` - // differs, highlighting the first difference rather than the - // entire block. - bool defer_fails = assert_mode; - Optional first_mismatch; - return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails) - .Equal(lhs, rhs, map_free_vars); - }) + .def("node.StructuralEqual", NodeStructuralEqualAdapter) .def("node.GetFirstStructuralMismatch", [](const Any& lhs, const Any& rhs, bool map_free_vars) { - Optional first_mismatch; - bool equal = - SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); - ICHECK(equal == !first_mismatch.defined()); - return first_mismatch; + /* + Optional first_mismatch; + bool equal = + SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); + ICHECK(equal == !first_mismatch.defined()); + return first_mismatch; + */ + return ObjectPathPairFromAccessPathPair( + ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); }); }); bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_params) const { - return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, map_free_params); + return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_params); } bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 2a60754f7b44..6fb8d3678454 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -296,13 +297,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.StructuralHash", [](const Any& object, bool map_free_vars) -> int64_t { - uint64_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); - return static_cast(hashed_value); + return ffi::reflection::StructuralHash::Hash(object, map_free_vars); }); }); uint64_t StructuralHash::operator()(const ObjectRef& object) const { - return SHashHandlerDefault().Hash(object, false); + return ffi::reflection::StructuralHash::Hash(object, false); } void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool map_free_vars) { diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 6005497a8ee3..c905b8730571 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -517,6 +517,34 @@ void VarBindingNode::SHashReduce(SHashReducer hash_reduce) const { } } +bool VarBindingNode::SEqual(const VarBindingNode* other, + ffi::TypedFunction equal) const { + if (value->IsInstance()) { + // Recursive function definitions may reference the bound variable + // within the value being bound. In these cases, the + // var comparison must occur first to define the var, to ensure it is + // defined at point of use. + return equal(var, other->var, true, "var") && equal(value, other->value, false, "value"); + } else { + // In all other cases, visit the bound value before the variable + // it is bound to, in order to provide better error messages. + return equal(value, other->value, false, "value") && equal(var, other->var, true, "var"); + } +} + +uint64_t VarBindingNode::SHash(uint64_t init_hash, + ffi::TypedFunction hash) const { + uint64_t hash_value = init_hash; + if (value->IsInstance()) { + hash_value = hash(var, hash_value, true); + hash_value = hash(value, hash_value, false); + } else { + hash_value = hash(value, hash_value, false); + hash_value = hash(var, hash_value, true); + } + return hash_value; +} + TVM_REGISTER_NODE_TYPE(BindingBlockNode); BindingBlock::BindingBlock(Array bindings, Span span) { diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 00414fe33930..499dd47bb723 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -31,6 +31,7 @@ namespace tvm { namespace relax { TVM_FFI_STATIC_INIT_BLOCK({ + StructInfoNode::RegisterReflection(); ObjectStructInfoNode::RegisterReflection(); PrimStructInfoNode::RegisterReflection(); ShapeStructInfoNode::RegisterReflection(); diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 1a426ec5da27..83d978f27d0c 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -540,8 +541,8 @@ class ParamRemapper : private ExprFunctor { } else { var_remap_.Set(GetRef(lhs_var), rhs_var); } - CHECK(structural_equal.Equal(lhs_var->struct_info_, rhs_var->struct_info_, - /*map_free_vars=*/true)) + CHECK(tvm::ffi::reflection::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, + /*map_free_vars=*/true)) << "The struct info of the parameters should be the same for all target functions"; auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(GetRef(lhs_var))); auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr)); @@ -555,8 +556,6 @@ class ParamRemapper : private ExprFunctor { } } - SEqualHandlerDefault structural_equal{/*assert_mode=*/false, /*first_mismatch=*/nullptr, - /*defer_fail=*/false}; Map var_remap_; Map tir_var_remap_; }; diff --git a/tests/python/ir/test_node_reflection.py b/tests/python/ir/test_node_reflection.py index 741e61b2eb48..be00bc3a4777 100644 --- a/tests/python/ir/test_node_reflection.py +++ b/tests/python/ir/test_node_reflection.py @@ -181,6 +181,15 @@ def test_ndarray_dict(): tvm.ir.assert_structural_equal(m1, m2) +def test_free_var_equal(): + x = tvm.tir.Var("x", dtype="int32") + y = tvm.tir.Var("y", dtype="int32") + z = tvm.tir.Var("z", dtype="int32") + v1 = x + y + v1 = y + z + tvm.ir.assert_structural_equal(x, z, map_free_vars=True) + + def test_alloc_const(): dev = tvm.cpu(0) dtype = "float32" diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py index 58d9402e6fde..bbf95801ed0a 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py +++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py @@ -45,6 +45,9 @@ def func2(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 256)) + func1 = func1.with_attr("global_symbol", "main") + func2 = func2.with_attr("global_symbol", "main") + with pytest.raises(ValueError) as ve: assert_structural_equal(func1, func2) assert _error_message(ve.value) == _expected_result( @@ -109,8 +112,12 @@ def func2(): a_data = T.allocate((256, 128), dtype="float32") a = T.decl_buffer((256, 128), dtype="float32", data=a_data) + func1 = func1.with_attr("global_symbol", "main") + func2 = func2.with_attr("global_symbol", "main") + with pytest.raises(ValueError) as ve: assert_structural_equal(func1, func2) + assert _error_message(ve.value) == _expected_result( func1, func2, @@ -132,6 +139,9 @@ def func2(): with T.block(): pass + func1 = func1.with_attr("global_symbol", "main") + func2 = func2.with_attr("global_symbol", "main") + with pytest.raises(ValueError) as ve: assert_structural_equal(func1, func2) assert _error_message(ve.value) == _expected_result(