Skip to content

Commit 5b51135

Browse files
committed
[FFI] Structural equal and hash based on reflection
This PR add initial support for structural equal and hash via the new reflection mechanism. It will helps us to streamline the structural equality/hash with broader support and clean error reports via AccessPath. It also gives us ability to unify all struct equal/hash registration into the extra meta-data in reflection registration.
1 parent 7707496 commit 5b51135

File tree

15 files changed

+1298
-23
lines changed

15 files changed

+1298
-23
lines changed

ffi/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ add_library(tvm_ffi_objs OBJECT
5757
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
5858
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
5959
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
60+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
61+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
62+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
6063
)
6164
set_target_properties(
6265
tvm_ffi_objs PROPERTIES

ffi/include/tvm/ffi/c_api.h

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@
5656
#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default")))
5757
#endif
5858

59+
/*!
60+
* \brief Marks the API as extra c++ api that is defined in cc files.
61+
*
62+
* These APIs are extra features that depend on, but are not required to
63+
* support essential core functionality, such as function calling and object
64+
* access.
65+
*
66+
* They are implemented in cc files to reduce compile-time overhead.
67+
* The input/output only uses POD/Any/ObjectRef for ABI stability.
68+
* However, these extra APIs may have an issue across MSVC/Itanium ABI,
69+
*
70+
* Related features are also available through reflection based function
71+
* that is fully based on C API
72+
*
73+
* The project aims to minimize the number of extra C++ APIs and only
74+
* restrict the use to non-core functionalities.
75+
*/
76+
#ifndef TVM_FFI_EXTRA_CXX_API
77+
#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
78+
#endif
79+
5980
#ifdef __cplusplus
6081
extern "C" {
6182
#endif
@@ -326,12 +347,89 @@ typedef enum {
326347
kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1,
327348
/*! \brief The field is a static method. */
328349
kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2,
350+
/*!
351+
* \brief The field should be ignored when performing structural eq/hash
352+
*
353+
* This is an optional meta-data for structural eq/hash.
354+
*/
355+
kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3,
356+
/*!
357+
* \brief The field enters a def region where var can be defined/matched.
358+
*
359+
* This is an optional meta-data for structural eq/hash.
360+
*/
361+
kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
329362
#ifdef __cplusplus
330363
};
331364
#else
332365
} TVMFFIFieldFlagBitMask;
333366
#endif
334367

368+
/*!
369+
* \brief Optional meta-data for structural eq/hash.
370+
*
371+
* This meta-data is only useful when we want to leverage the information
372+
* to perform richer semantics aware structural comparison and hash.
373+
* It can be safely ignored if such information is not needed.
374+
*
375+
* The meta-data record comparison method in tree node and DAG node.
376+
*
377+
* \code
378+
* x = VarNode()
379+
* v0 = AddNode(x, 1)
380+
* v1 = AddNode(x, 1)
381+
* v2 = AddNode(v0, v0)
382+
* v3 = AddNode(v1, v0)
383+
* \endcode
384+
*
385+
* Consider the construct sequence of AddNode below,
386+
* if AddNode is treated as a tree node, then v2 and v3
387+
* structural equals to each other, but if AddNode is
388+
* treated as a DAG node, then v2 and v3 does not
389+
* structural equals to each other.
390+
*/
391+
#ifdef __cplusplus
392+
enum TVMFFISEqHashKind : int32_t {
393+
#else
394+
typedef enum {
395+
#endif
396+
/*! \brief Do not support structural eq/hash. */
397+
kTVMFFISEqHashKindUnsupported = 0,
398+
/*!
399+
* \brief The object be compared as a tree node.
400+
*/
401+
kTVMFFISEqHashKindTreeNode = 1,
402+
/*!
403+
* \brief The object is treated as a free variable that can be mapped
404+
* to another free variable in the definition region.
405+
*/
406+
kTVMFFISEqHashKindFreeVar = 2,
407+
/*!
408+
* \brief The field should be compared as a DAG node.
409+
*/
410+
kTVMFFISEqHashKindDAGNode = 3,
411+
/*!
412+
* \brief The object is treated as a constant tree node.
413+
*
414+
* Same as tree node, but the object does not contain free var
415+
* as any of its nested children.
416+
*
417+
* That means we can use pointer equality for equality.
418+
*/
419+
kTVMFFISEqHashKindConstTreeNode = 4,
420+
/*!
421+
* \brief One can simply use pointer equality for equality.
422+
*
423+
* This is useful for "singleton"-style object that can
424+
* is only an unique copy of each value.
425+
*/
426+
kTVMFFISEqHashKindUniqueInstance = 5,
427+
#ifdef __cplusplus
428+
};
429+
#else
430+
} TVMFFISEqHashKind;
431+
#endif
432+
335433
/*!
336434
* \brief Information support for optional object reflection.
337435
*/
@@ -431,7 +529,11 @@ typedef struct {
431529
*
432530
* This field is set optional and set to 0 if not registered.
433531
*/
434-
int64_t total_size;
532+
int32_t total_size;
533+
/*!
534+
* \brief Optional meta-data for structural eq/hash.
535+
*/
536+
TVMFFISEqHashKind structural_eq_hash_kind;
435537
} TVMFFITypeExtraInfo;
436538

437539
/*!

ffi/include/tvm/ffi/object.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ class Object {
212212
static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject;
213213
// the static type depth of the class
214214
static constexpr int32_t _type_depth = 0;
215+
// the structural equality and hash kind of the type
216+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported;
215217
// extra fields used by plug-ins for attribute visiting
216218
// and structural information
217219
static constexpr const bool _type_has_method_sequal_reduce = false;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file tvm/ffi/reflection/registry.h
21+
* \brief Registry of reflection metadata.
22+
*/
23+
#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_
24+
#define TVM_FFI_REFLECTION_ACCESS_PATH_H_
25+
26+
#include <tvm/ffi/any.h>
27+
#include <tvm/ffi/c_api.h>
28+
#include <tvm/ffi/container/array.h>
29+
#include <tvm/ffi/container/tuple.h>
30+
#include <tvm/ffi/reflection/registry.h>
31+
32+
namespace tvm {
33+
namespace ffi {
34+
namespace reflection {
35+
36+
enum class AccessKind : int32_t {
37+
kObjectField = 0,
38+
kArrayIndex = 1,
39+
kMapKey = 2,
40+
// the following two are used for error reporting when
41+
// the supposed access field is not available
42+
kArrayIndexMissing = 3,
43+
kMapKeyMissing = 4,
44+
};
45+
46+
/*!
47+
* \brief Represent a single step in object field, map key, array index access.
48+
*/
49+
class AccessStepObj : public Object {
50+
public:
51+
/*!
52+
* \brief The kind of the access pattern.
53+
*/
54+
AccessKind kind;
55+
/*!
56+
* \brief The access key
57+
* \note for array access, it will always be integer
58+
* for field access, it will be string
59+
*/
60+
Any key;
61+
62+
AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {}
63+
64+
static void RegisterReflection() {
65+
namespace refl = tvm::ffi::reflection;
66+
refl::ObjectDef<AccessStepObj>()
67+
.def_ro("kind", &AccessStepObj::kind)
68+
.def_ro("key", &AccessStepObj::key);
69+
}
70+
71+
static constexpr const char* _type_key = "tvm.ffi.reflection.AccessStep";
72+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
73+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object);
74+
};
75+
76+
/*!
77+
* \brief ObjectRef class of AccessStepObj.
78+
*
79+
* \sa AccessStepObj
80+
*/
81+
class AccessStep : public ObjectRef {
82+
public:
83+
AccessStep(AccessKind kind, Any key) : ObjectRef(make_object<AccessStepObj>(kind, key)) {}
84+
85+
static AccessStep ObjectField(String field_name) {
86+
return AccessStep(AccessKind::kObjectField, field_name);
87+
}
88+
89+
static AccessStep ArrayIndex(int64_t index) { return AccessStep(AccessKind::kArrayIndex, index); }
90+
91+
static AccessStep ArrayIndexMissing(int64_t index) {
92+
return AccessStep(AccessKind::kArrayIndexMissing, index);
93+
}
94+
95+
static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey, key); }
96+
97+
static AccessStep MapKeyMissing(Any key) { return AccessStep(AccessKind::kMapKeyMissing, key); }
98+
99+
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj);
100+
};
101+
102+
using AccessPath = Array<AccessStep>;
103+
using AccessPathPair = Tuple<AccessPath, AccessPath>;
104+
105+
} // namespace reflection
106+
} // namespace ffi
107+
} // namespace tvm
108+
#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_

ffi/include/tvm/ffi/reflection/registry.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,39 @@ class DefaultValue : public FieldInfoTrait {
5555
Any value_;
5656
};
5757

58+
/*
59+
* \brief Trait that can be used to attach field flag
60+
*/
61+
class AttachFieldFlag : public FieldInfoTrait {
62+
public:
63+
/*!
64+
* \brief Attach a field flag to the field
65+
*
66+
* \param flag The flag to be set
67+
*
68+
* \return The trait object.
69+
*/
70+
explicit AttachFieldFlag(int32_t flag) : flag_(flag) {}
71+
72+
/*!
73+
* \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef
74+
*/
75+
TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
76+
return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef);
77+
}
78+
/*!
79+
* \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore
80+
*/
81+
TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() {
82+
return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore);
83+
}
84+
85+
TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; }
86+
87+
private:
88+
int32_t flag_;
89+
};
90+
5891
/*!
5992
* \brief Get the byte offset of a class member field.
6093
*
@@ -83,7 +116,11 @@ class ReflectionDefBase {
83116
template <typename T>
84117
static int FieldSetter(void* field, const TVMFFIAny* value) {
85118
TVM_FFI_SAFE_CALL_BEGIN();
86-
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
119+
if constexpr (std::is_same_v<T, Any>) {
120+
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value);
121+
} else {
122+
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
123+
}
87124
TVM_FFI_SAFE_CALL_END();
88125
}
89126

@@ -346,6 +383,7 @@ class ObjectDef : public ReflectionDefBase {
346383
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
347384
TVMFFITypeExtraInfo info;
348385
info.total_size = sizeof(Class);
386+
info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
349387
info.creator = nullptr;
350388
info.doc = TVMFFIByteArray{nullptr, 0};
351389
if constexpr (std::is_default_constructible_v<Class>) {

0 commit comments

Comments
 (0)