Skip to content

Commit e05064f

Browse files
authored
[FFI] Structural equal and hash based on reflection (#18156)
This PR add initial support for structural equal and hash via the new reflection mechanism. It will helps us to streamline the structural equality/hash with broader support and clean error reports via AccessPath. It also gives us ability to unify all struct equal/hash registration into the extra meta-data in reflection registration.
1 parent 89f9573 commit e05064f

File tree

15 files changed

+1298
-23
lines changed

15 files changed

+1298
-23
lines changed

ffi/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ add_library(tvm_ffi_objs OBJECT
5757
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
5858
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
5959
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
60+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
61+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
62+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
6063
)
6164
set_target_properties(
6265
tvm_ffi_objs PROPERTIES

ffi/include/tvm/ffi/c_api.h

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@
5656
#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default")))
5757
#endif
5858

59+
/*!
60+
* \brief Marks the API as extra c++ api that is defined in cc files.
61+
*
62+
* These APIs are extra features that depend on, but are not required to
63+
* support essential core functionality, such as function calling and object
64+
* access.
65+
*
66+
* They are implemented in cc files to reduce compile-time overhead.
67+
* The input/output only uses POD/Any/ObjectRef for ABI stability.
68+
* However, these extra APIs may have an issue across MSVC/Itanium ABI,
69+
*
70+
* Related features are also available through reflection based function
71+
* that is fully based on C API
72+
*
73+
* The project aims to minimize the number of extra C++ APIs and only
74+
* restrict the use to non-core functionalities.
75+
*/
76+
#ifndef TVM_FFI_EXTRA_CXX_API
77+
#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
78+
#endif
79+
5980
#ifdef __cplusplus
6081
extern "C" {
6182
#endif
@@ -326,12 +347,89 @@ typedef enum {
326347
kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1,
327348
/*! \brief The field is a static method. */
328349
kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2,
350+
/*!
351+
* \brief The field should be ignored when performing structural eq/hash
352+
*
353+
* This is an optional meta-data for structural eq/hash.
354+
*/
355+
kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3,
356+
/*!
357+
* \brief The field enters a def region where var can be defined/matched.
358+
*
359+
* This is an optional meta-data for structural eq/hash.
360+
*/
361+
kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
329362
#ifdef __cplusplus
330363
};
331364
#else
332365
} TVMFFIFieldFlagBitMask;
333366
#endif
334367

368+
/*!
369+
* \brief Optional meta-data for structural eq/hash.
370+
*
371+
* This meta-data is only useful when we want to leverage the information
372+
* to perform richer semantics aware structural comparison and hash.
373+
* It can be safely ignored if such information is not needed.
374+
*
375+
* The meta-data record comparison method in tree node and DAG node.
376+
*
377+
* \code
378+
* x = VarNode()
379+
* v0 = AddNode(x, 1)
380+
* v1 = AddNode(x, 1)
381+
* v2 = AddNode(v0, v0)
382+
* v3 = AddNode(v1, v0)
383+
* \endcode
384+
*
385+
* Consider the construct sequence of AddNode below,
386+
* if AddNode is treated as a tree node, then v2 and v3
387+
* structural equals to each other, but if AddNode is
388+
* treated as a DAG node, then v2 and v3 does not
389+
* structural equals to each other.
390+
*/
391+
#ifdef __cplusplus
392+
enum TVMFFISEqHashKind : int32_t {
393+
#else
394+
typedef enum {
395+
#endif
396+
/*! \brief Do not support structural eq/hash. */
397+
kTVMFFISEqHashKindUnsupported = 0,
398+
/*!
399+
* \brief The object be compared as a tree node.
400+
*/
401+
kTVMFFISEqHashKindTreeNode = 1,
402+
/*!
403+
* \brief The object is treated as a free variable that can be mapped
404+
* to another free variable in the definition region.
405+
*/
406+
kTVMFFISEqHashKindFreeVar = 2,
407+
/*!
408+
* \brief The field should be compared as a DAG node.
409+
*/
410+
kTVMFFISEqHashKindDAGNode = 3,
411+
/*!
412+
* \brief The object is treated as a constant tree node.
413+
*
414+
* Same as tree node, but the object does not contain free var
415+
* as any of its nested children.
416+
*
417+
* That means we can use pointer equality for equality.
418+
*/
419+
kTVMFFISEqHashKindConstTreeNode = 4,
420+
/*!
421+
* \brief One can simply use pointer equality for equality.
422+
*
423+
* This is useful for "singleton"-style object that can
424+
* is only an unique copy of each value.
425+
*/
426+
kTVMFFISEqHashKindUniqueInstance = 5,
427+
#ifdef __cplusplus
428+
};
429+
#else
430+
} TVMFFISEqHashKind;
431+
#endif
432+
335433
/*!
336434
* \brief Information support for optional object reflection.
337435
*/
@@ -431,7 +529,11 @@ typedef struct {
431529
*
432530
* This field is set optional and set to 0 if not registered.
433531
*/
434-
int64_t total_size;
532+
int32_t total_size;
533+
/*!
534+
* \brief Optional meta-data for structural eq/hash.
535+
*/
536+
TVMFFISEqHashKind structural_eq_hash_kind;
435537
} TVMFFITypeExtraInfo;
436538

437539
/*!

ffi/include/tvm/ffi/object.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ class Object {
212212
static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject;
213213
// the static type depth of the class
214214
static constexpr int32_t _type_depth = 0;
215+
// the structural equality and hash kind of the type
216+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported;
215217
// extra fields used by plug-ins for attribute visiting
216218
// and structural information
217219
static constexpr const bool _type_has_method_sequal_reduce = false;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file tvm/ffi/reflection/registry.h
21+
* \brief Registry of reflection metadata.
22+
*/
23+
#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_
24+
#define TVM_FFI_REFLECTION_ACCESS_PATH_H_
25+
26+
#include <tvm/ffi/any.h>
27+
#include <tvm/ffi/c_api.h>
28+
#include <tvm/ffi/container/array.h>
29+
#include <tvm/ffi/container/tuple.h>
30+
#include <tvm/ffi/reflection/registry.h>
31+
32+
namespace tvm {
33+
namespace ffi {
34+
namespace reflection {
35+
36+
enum class AccessKind : int32_t {
37+
kObjectField = 0,
38+
kArrayIndex = 1,
39+
kMapKey = 2,
40+
// the following two are used for error reporting when
41+
// the supposed access field is not available
42+
kArrayIndexMissing = 3,
43+
kMapKeyMissing = 4,
44+
};
45+
46+
/*!
47+
* \brief Represent a single step in object field, map key, array index access.
48+
*/
49+
class AccessStepObj : public Object {
50+
public:
51+
/*!
52+
* \brief The kind of the access pattern.
53+
*/
54+
AccessKind kind;
55+
/*!
56+
* \brief The access key
57+
* \note for array access, it will always be integer
58+
* for field access, it will be string
59+
*/
60+
Any key;
61+
62+
AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {}
63+
64+
static void RegisterReflection() {
65+
namespace refl = tvm::ffi::reflection;
66+
refl::ObjectDef<AccessStepObj>()
67+
.def_ro("kind", &AccessStepObj::kind)
68+
.def_ro("key", &AccessStepObj::key);
69+
}
70+
71+
static constexpr const char* _type_key = "tvm.ffi.reflection.AccessStep";
72+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
73+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object);
74+
};
75+
76+
/*!
77+
* \brief ObjectRef class of AccessStepObj.
78+
*
79+
* \sa AccessStepObj
80+
*/
81+
class AccessStep : public ObjectRef {
82+
public:
83+
AccessStep(AccessKind kind, Any key) : ObjectRef(make_object<AccessStepObj>(kind, key)) {}
84+
85+
static AccessStep ObjectField(String field_name) {
86+
return AccessStep(AccessKind::kObjectField, field_name);
87+
}
88+
89+
static AccessStep ArrayIndex(int64_t index) { return AccessStep(AccessKind::kArrayIndex, index); }
90+
91+
static AccessStep ArrayIndexMissing(int64_t index) {
92+
return AccessStep(AccessKind::kArrayIndexMissing, index);
93+
}
94+
95+
static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey, key); }
96+
97+
static AccessStep MapKeyMissing(Any key) { return AccessStep(AccessKind::kMapKeyMissing, key); }
98+
99+
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj);
100+
};
101+
102+
using AccessPath = Array<AccessStep>;
103+
using AccessPathPair = Tuple<AccessPath, AccessPath>;
104+
105+
} // namespace reflection
106+
} // namespace ffi
107+
} // namespace tvm
108+
#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_

ffi/include/tvm/ffi/reflection/registry.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,39 @@ class DefaultValue : public FieldInfoTrait {
5555
Any value_;
5656
};
5757

58+
/*
59+
* \brief Trait that can be used to attach field flag
60+
*/
61+
class AttachFieldFlag : public FieldInfoTrait {
62+
public:
63+
/*!
64+
* \brief Attach a field flag to the field
65+
*
66+
* \param flag The flag to be set
67+
*
68+
* \return The trait object.
69+
*/
70+
explicit AttachFieldFlag(int32_t flag) : flag_(flag) {}
71+
72+
/*!
73+
* \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef
74+
*/
75+
TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
76+
return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef);
77+
}
78+
/*!
79+
* \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore
80+
*/
81+
TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() {
82+
return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore);
83+
}
84+
85+
TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; }
86+
87+
private:
88+
int32_t flag_;
89+
};
90+
5891
/*!
5992
* \brief Get the byte offset of a class member field.
6093
*
@@ -83,7 +116,11 @@ class ReflectionDefBase {
83116
template <typename T>
84117
static int FieldSetter(void* field, const TVMFFIAny* value) {
85118
TVM_FFI_SAFE_CALL_BEGIN();
86-
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
119+
if constexpr (std::is_same_v<T, Any>) {
120+
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value);
121+
} else {
122+
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
123+
}
87124
TVM_FFI_SAFE_CALL_END();
88125
}
89126

@@ -346,6 +383,7 @@ class ObjectDef : public ReflectionDefBase {
346383
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
347384
TVMFFITypeExtraInfo info;
348385
info.total_size = sizeof(Class);
386+
info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
349387
info.creator = nullptr;
350388
info.doc = TVMFFIByteArray{nullptr, 0};
351389
if constexpr (std::is_default_constructible_v<Class>) {

0 commit comments

Comments
 (0)