Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,27 +424,6 @@ typedef enum {
* is only an unique copy of each value.
*/
kTVMFFISEqHashKindUniqueInstance = 5,
/*!
* \brief provide custom __s_equal__ and __s_hash__ functions through TypeAttrColumn.
*
* The function signatures are(defined via ffi::Function)
*
* \code
* bool __s_equal__(
* ObjectRefType self, ObjectRefType other,
* ffi::TypedFunction<bool(AnyView, AnyView, bool def_region, string field_name)> cmp,
* );
*
* uint64_t __s_hash__(
* ObjectRefType self, uint64_t type_key_hash,
* ffi::TypedFunction<uint64_t(AnyView, bool def_region)> hash
* );
* \endcode
*
* Where the extra string field in cmp is the name of the field that is being compared.
* The function should be registered through TVMFFITypeRegisterAttr via reflection::TypeAttrDef.
*/
kTVMFFISEqHashKindCustomTreeNode = 6,
#ifdef __cplusplus
};
#else
Expand Down
28 changes: 19 additions & 9 deletions ffi/src/ffi/reflection/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/ffi/string.h>

#include <cmath>
#include <unordered_map>

namespace tvm {
Expand All @@ -49,7 +50,12 @@ class StructEqualHandler {
if (lhs_data->type_index != rhs_data->type_index) {
return false;
}

if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
// specially handle nan for float, as there can be multiple representations of nan
if (lhs_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(lhs_data->v_float64)) {
return std::isnan(rhs_data->v_float64);
}
// this is POD data, we can just compare the value
return lhs_data->v_int64 == rhs_data->v_int64;
}
Expand Down Expand Up @@ -90,12 +96,18 @@ class StructEqualHandler {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
if (type_info->metadata == nullptr) {
return lhs.same_as(rhs);
TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `"
<< String(type_info->type_key)
<< "`, so StructuralHash is not supported for this type";
}
if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) {
TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `"
<< String(type_info->type_key)
<< "`, so StructuralHash is not supported for this type";
}
auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;

if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported ||
structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) {
auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;
if (structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) {
// use pointer comparison
return lhs.same_as(rhs);
}
Expand All @@ -118,8 +130,10 @@ class StructEqualHandler {
}
}

static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__");

bool success = true;
if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
if (custom_s_equal[type_info->type_index] == nullptr) {
// We recursively compare the fields the object
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
Expand Down Expand Up @@ -153,7 +167,6 @@ class StructEqualHandler {
}
});
} else {
static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__");
// run custom equal function defined via __s_equal__ type attribute
if (s_equal_callback_ == nullptr) {
s_equal_callback_ = ffi::Function::FromTyped(
Expand All @@ -179,9 +192,6 @@ class StructEqualHandler {
return success;
});
}
TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr)
<< "TypeAttr `__s_equal__` is not registered for type `" << String(type_info->type_key)
<< "`";
success = custom_s_equal[type_info->type_index]
.cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
.cast<bool>();
Expand Down
51 changes: 33 additions & 18 deletions ffi/src/ffi/reflection/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/ffi/string.h>

#include <cmath>
#include <limits>
#include <unordered_map>
#include <utility>

Expand All @@ -48,6 +50,13 @@ class StructuralHashHandler {
const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src);

if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
// specially handle nan for float, as there can be multiple representations of nan
// make sure they map to the same hash value
if (src_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(src_data->v_float64)) {
TVMFFIAny temp = *src_data;
temp.v_float64 = std::numeric_limits<double>::quiet_NaN();
return details::StableHashCombine(temp.type_index, temp.v_uint64);
}
// this is POD data, we can just hash the value
return details::StableHashCombine(src_data->type_index, src_data->v_uint64);
}
Expand Down Expand Up @@ -83,9 +92,16 @@ class StructuralHashHandler {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
if (type_info->metadata == nullptr) {
// Fallback to pointer hash
return std::hash<const Object*>()(obj.get());
TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `"
<< String(type_info->type_key)
<< "`, so StructuralHash is not supported for this type";
}
if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) {
TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `"
<< String(type_info->type_key)
<< "`, so StructuralHash is not supported for this type";
}

auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;
if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) {
// Fallback to pointer hash
Expand All @@ -97,9 +113,11 @@ class StructuralHashHandler {
return it->second;
}

static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__");

// compute the hash value
uint64_t hash_value = obj->GetTypeKeyHash();
if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
if (custom_s_hash[type_info->type_index] == nullptr) {
// go over the content and hash the fields
ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
Expand All @@ -119,22 +137,19 @@ class StructuralHashHandler {
}
});
} else {
static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__");
TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr)
<< "TypeAttr `__s_hash__` is not registered for type `" << String(type_info->type_key)
<< "`";
if (s_hash_callback_ == nullptr) {
s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool def_region) {
if (def_region) {
bool allow_free_var = true;
std::swap(allow_free_var, map_free_vars_);
uint64_t hash_value = HashAny(val);
std::swap(allow_free_var, map_free_vars_);
return hash_value;
} else {
return HashAny(val);
}
});
s_hash_callback_ =
ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) {
if (def_region) {
bool allow_free_var = true;
std::swap(allow_free_var, map_free_vars_);
uint64_t hash_value = HashAny(val);
std::swap(allow_free_var, map_free_vars_);
return details::StableHashCombine(init_hash, hash_value);
} else {
return details::StableHashCombine(init_hash, HashAny(val));
}
});
}
hash_value = custom_s_hash[type_info->type_index]
.cast<ffi::Function>()(obj, hash_value, s_hash_callback_)
Expand Down
11 changes: 6 additions & 5 deletions ffi/tests/cpp/testing_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,11 @@ class TCustomFuncObj : public Object {
return true;
}

uint64_t SHash(uint64_t type_key_hash, ffi::TypedFunction<uint64_t(AnyView, bool)> hash) const {
uint64_t hash_value = type_key_hash;
hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(params, true));
hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(body, false));
uint64_t SHash(uint64_t init_hash,
ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
uint64_t hash_value = init_hash;
hash_value = hash(params, hash_value, true);
hash_value = hash(body, hash_value, false);
return hash_value;
}

Expand All @@ -246,7 +247,7 @@ class TCustomFuncObj : public Object {
}

static constexpr const char* _type_key = "test.CustomFunc";
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindCustomTreeNode;
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object);
};

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class ConstIntBoundNode : public Object {
*/
static const constexpr int64_t kNegInf = -kPosInf;

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
};
Expand Down Expand Up @@ -222,6 +223,7 @@ class ModularSetNode : public Object {
return equal(coeff, other->coeff) && equal(base, other->base);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class IntGroupBoundsNode : public Object {
hash_reduce(upper);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntGroupBounds";
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
Expand Down Expand Up @@ -173,6 +174,7 @@ class IntConstraintsNode : public Object {
hash_reduce(relations);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
Expand Down Expand Up @@ -238,6 +240,7 @@ class IntConstraintsTransformNode : public Object {
hash_reduce(dst_to_src);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class IterMarkNode : public Object {
hash_reduce(extent);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "arith.IterMark";
Expand Down Expand Up @@ -176,6 +177,7 @@ class IterSplitExprNode : public IterMapExprNode {
hash_reduce(scale);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.IterSplitExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode);
};
Expand Down Expand Up @@ -239,6 +241,7 @@ class IterSumExprNode : public IterMapExprNode {
hash_reduce(base);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.IterSumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode);
};
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AttrFieldInfoNode : public Object {
}

static constexpr const char* _type_key = "ir.AttrFieldInfo";

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
Expand Down Expand Up @@ -122,6 +122,7 @@ class BaseAttrsNode : public Object {
TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs,
bool allow_unknown = false) = 0;

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "ir.Attrs";
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/diagnostic.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class DiagnosticNode : public Object {
equal(this->message, other->message);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "Diagnostic";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object);
};
Expand Down Expand Up @@ -214,6 +215,7 @@ class DiagnosticContextNode : public Object {
return equal(module, other->module) && equal(diagnostics, other->diagnostics);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "DiagnosticContext";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object);
};
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class EnvFuncNode : public Object {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<EnvFuncNode>().def_ro("name", &EnvFuncNode::name);
// func do not participate in structural equal and hash.
refl::ObjectDef<EnvFuncNode>()
.def_ro("name", &EnvFuncNode::name)
.def_ro("func", &EnvFuncNode::func, refl::AttachFieldFlag::SEqHashIgnore());
}

bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
Expand All @@ -64,6 +67,7 @@ class EnvFuncNode : public Object {
hash_reduce(name);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Expand Down
20 changes: 18 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,14 @@ class BaseExprNode : public Object {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BaseExprNode>().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()));
// span do not participate in structural equal and hash.
refl::ObjectDef<BaseExprNode>().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()),
refl::AttachFieldFlag::SEqHashIgnore());
}

static constexpr const char* _type_key = "ir.BaseExpr";

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 64;
Expand Down Expand Up @@ -428,7 +431,8 @@ class RelaxExprNode : public BaseExprNode {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<RelaxExprNode>().def_ro("struct_info_", &RelaxExprNode::struct_info_);
refl::ObjectDef<RelaxExprNode>().def_ro("struct_info_", &RelaxExprNode::struct_info_,
refl::AttachFieldFlag::SEqHashIgnore());
}

static constexpr const char* _type_key = "ir.RelaxExpr";
Expand Down Expand Up @@ -474,6 +478,17 @@ class GlobalVarNode : public RelaxExprNode {
hash_reduce.FreeVarHashImpl(this);
}

bool SEqual(const GlobalVarNode* other,
ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
return equal(name_hint, other->name_hint, false, "name_hint");
}

uint64_t SHash(uint64_t init_hash,
ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
return hash(name_hint, init_hash, false);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
static constexpr const char* _type_key = "ir.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode);
};
Expand Down Expand Up @@ -711,6 +726,7 @@ class RangeNode : public Object {
}

static constexpr const char* _type_key = "ir.Range";
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/global_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ using MemoryScope = String;
class GlobalInfoNode : public Object {
public:
static constexpr const char* _type_key = "ir.GlobalInfo";

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object);
Expand Down
Loading
Loading