Skip to content

Commit 3935401

Browse files
committed
[FFI][REFACTOR] Migrate the Save/Load JSON to the new reflection
This PR migrates the Save/Load JSON to the new reflection based mechanism. This is a breaking change that updates the the JSON format to ffi/extra/serialization to handle the serialization, see the json graph schema comment in ffi/extra/serialization.h for the format, which roughly aligns with the old style. After this change, we no longer need node/reflection and reflection vtable. We can also phase out TVM_REGISTER_NODE and TVM_REGISTER_OBJECT to have a single place that defines the reflection.
1 parent 9c27523 commit 3935401

File tree

224 files changed

+350
-2060
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

224 files changed

+350
-2060
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file tvm/ffi/reflection/creator.h
21+
* \brief Reflection-based creator to create objects from type key and fields.
22+
*/
23+
#ifndef TVM_FFI_REFLECTION_CREATOR_H_
24+
#define TVM_FFI_REFLECTION_CREATOR_H_
25+
26+
#include <tvm/ffi/any.h>
27+
#include <tvm/ffi/container/map.h>
28+
#include <tvm/ffi/reflection/accessor.h>
29+
#include <tvm/ffi/string.h>
30+
31+
namespace tvm {
32+
namespace ffi {
33+
namespace reflection {
34+
/*!
35+
* \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection.
36+
*/
37+
class ObjectCreator {
38+
public:
39+
explicit ObjectCreator(std::string_view type_key)
40+
: ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {}
41+
42+
explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
43+
int32_t type_index = type_info->type_index;
44+
if (type_info->metadata == nullptr) {
45+
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
46+
<< "` does not have reflection registered";
47+
}
48+
if (type_info->metadata->creator == nullptr) {
49+
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
50+
<< "` does not support default constructor, "
51+
<< "as a result cannot be created via reflection";
52+
}
53+
}
54+
55+
/**
56+
* \brief Create an object from a map of fields.
57+
* \param fields The fields of the object.
58+
* \return The created object.
59+
*/
60+
Any operator()(const Map<String, Any>& fields) const {
61+
TVMFFIObjectHandle handle;
62+
TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
63+
ObjectPtr<Object> ptr =
64+
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
65+
size_t match_field_count = 0;
66+
ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
67+
String field_name(field_info->name);
68+
void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
69+
if (fields.count(field_name) != 0) {
70+
Any field_value = fields[field_name];
71+
field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
72+
++match_field_count;
73+
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
74+
field_info->setter(field_addr, &(field_info->default_value));
75+
} else {
76+
TVM_FFI_THROW(TypeError) << "Required field `"
77+
<< String(field_info->name.data, field_info->name.size)
78+
<< "` not set in type `"
79+
<< String(type_info_->type_key.data, type_info_->type_key.size)
80+
<< "`";
81+
}
82+
});
83+
if (match_field_count == fields.size()) return ObjectRef(ptr);
84+
// report error that checks if contains extra fields that are not in the type
85+
auto check_field_name = [&](const String& field_name) {
86+
bool found = false;
87+
ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) {
88+
if (field_name.compare(field_info->name) == 0) {
89+
found = true;
90+
return true;
91+
}
92+
return false;
93+
});
94+
return found;
95+
};
96+
for (const auto& [field_name, _] : fields) {
97+
if (!check_field_name(field_name)) {
98+
TVM_FFI_THROW(TypeError) << "Type `"
99+
<< String(type_info_->type_key.data, type_info_->type_key.size)
100+
<< "` does not have field `" << field_name << "`";
101+
}
102+
}
103+
TVM_FFI_UNREACHABLE();
104+
}
105+
106+
private:
107+
const TVMFFITypeInfo* type_info_;
108+
};
109+
} // namespace reflection
110+
} // namespace ffi
111+
} // namespace tvm
112+
#endif // TVM_FFI_REFLECTION_CREATOR_H_

ffi/src/ffi/extra/serialization.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <tvm/ffi/any.h>
2525
#include <tvm/ffi/container/array.h>
2626
#include <tvm/ffi/container/map.h>
27+
#include <tvm/ffi/container/shape.h>
2728
#include <tvm/ffi/dtype.h>
2829
#include <tvm/ffi/error.h>
2930
#include <tvm/ffi/extra/base64.h>
@@ -119,6 +120,12 @@ class ObjectGraphSerializer {
119120
node.Set("data", CreateMapData(map));
120121
break;
121122
}
123+
case TypeIndex::kTVMFFIShape: {
124+
ffi::Shape shape = details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::Shape>(value);
125+
node.Set("type", ffi::StaticTypeKey::kTVMFFIShape);
126+
node.Set("data", Array<int64_t>(shape->data, shape->data + shape->size));
127+
break;
128+
}
122129
default: {
123130
if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) {
124131
// serialize type key since type index is runtime dependent
@@ -157,10 +164,10 @@ class ObjectGraphSerializer {
157164

158165
// create the data for the object, if the type has a custom data to json function,
159166
// use it. otherwise, we go over the fields and create the data.
160-
json::Object CreateObjectData(const Any& value) {
167+
json::Value CreateObjectData(const Any& value) {
161168
static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__");
162169
if (data_to_json[value.type_index()] != nullptr) {
163-
return data_to_json[value.type_index()].cast<Function>()(value).cast<json::Object>();
170+
return data_to_json[value.type_index()].cast<Function>()(value);
164171
}
165172
// NOTE: invariant: lhs and rhs are already the same type
166173
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index());
@@ -286,6 +293,10 @@ class ObjectGraphDeserializer {
286293
case TypeIndex::kTVMFFIArray: {
287294
return DecodeArrayData(node["data"].cast<json::Array>());
288295
}
296+
case TypeIndex::kTVMFFIShape: {
297+
Array<int64_t> data = node["data"].cast<Array<int64_t>>();
298+
return ffi::Shape(data);
299+
}
289300
default: {
290301
return DecodeObjectData(type_index, node["data"]);
291302
}

ffi/tests/cpp/extra/test_serialization.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <gtest/gtest.h>
2020
#include <tvm/ffi/container/array.h>
2121
#include <tvm/ffi/container/map.h>
22+
#include <tvm/ffi/container/shape.h>
2223
#include <tvm/ffi/dtype.h>
2324
#include <tvm/ffi/extra/serialization.h>
2425
#include <tvm/ffi/extra/structural_equal.h>
@@ -271,6 +272,23 @@ TEST(Serialization, Maps) {
271272
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_map));
272273
}
273274

275+
TEST(Serialization, Shapes) {
276+
Shape empty_shape;
277+
278+
json::Object expected_empty_shape = json::Object{
279+
{"root_index", 0},
280+
{"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{}}}}}};
281+
EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_shape), expected_empty_shape));
282+
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty_shape), empty_shape));
283+
284+
Shape shape({1, 2, 3});
285+
json::Object expected_shape = json::Object{
286+
{"root_index", 0},
287+
{"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{1, 2, 3}}}}}};
288+
EXPECT_TRUE(StructuralEqual()(ToJSONGraph(shape), expected_shape));
289+
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shape), shape));
290+
}
291+
274292
TEST(Serialization, TestObjectVar) {
275293
TVar x = TVar("x");
276294
json::Object expected_x = json::Object{

ffi/tests/cpp/test_reflection_accessor.cc renamed to ffi/tests/cpp/test_reflection.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <tvm/ffi/container/map.h>
2222
#include <tvm/ffi/object.h>
2323
#include <tvm/ffi/reflection/accessor.h>
24+
#include <tvm/ffi/reflection/creator.h>
2425
#include <tvm/ffi/reflection/registry.h>
2526
#include <tvm/ffi/string.h>
2627

@@ -159,4 +160,9 @@ TEST(Reflection, FuncRegister) {
159160
EXPECT_EQ(fget_value(a).cast<int>(), 12);
160161
}
161162

163+
TEST(Reflection, ObjectCreator) {
164+
namespace refl = tvm::ffi::reflection;
165+
refl::ObjectCreator creator("test.Int");
166+
EXPECT_EQ(creator(Map<String, Any>({{"value", 1}})).cast<TInt>()->value, 1);
167+
}
162168
} // namespace

ffi/tests/cpp/testing_object.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class TIntObj : public TNumberObj {
5959
public:
6060
int64_t value;
6161

62+
TIntObj() = default;
6263
TIntObj(int64_t value) : value(value) {}
6364

6465
int64_t GetValue() const { return value; }

include/tvm/ir/env_func.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
#include <tvm/ffi/function.h>
2828
#include <tvm/ffi/reflection/registry.h>
29-
#include <tvm/node/reflection.h>
29+
#include <tvm/node/node.h>
3030

3131
#include <string>
3232
#include <utility>

include/tvm/ir/instrument.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
#include <tvm/ffi/reflection/registry.h>
3030
#include <tvm/ffi/string.h>
31-
#include <tvm/node/reflection.h>
31+
#include <tvm/node/node.h>
3232

3333
#include <utility>
3434
#include <vector>

include/tvm/ir/transform.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#define TVM_IR_TRANSFORM_H_
5858

5959
#include <tvm/ffi/container/array.h>
60+
#include <tvm/ffi/reflection/creator.h>
6061
#include <tvm/ffi/reflection/registry.h>
6162
#include <tvm/ffi/string.h>
6263
#include <tvm/ir/diagnostic.h>
@@ -244,11 +245,10 @@ class PassContext : public ObjectRef {
244245
// NOTE: we could further update the function later.
245246
if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
246247
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
247-
auto* reflection = ReflectionVTable::Global();
248248
auto type_key = ffi::TypeIndexToTypeKey(tindex);
249249
auto legalization = [=](ffi::Any value) -> ffi::Any {
250250
if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
251-
return reflection->CreateObject(type_key, opt_map.value());
251+
return ffi::reflection::ObjectCreator(type_key)(opt_map.value());
252252
} else {
253253
auto opt_val = value.try_cast<ValueType>();
254254
if (!opt_val.has_value()) {

include/tvm/meta_schedule/arg_info.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include <tvm/ffi/reflection/registry.h>
2424
#include <tvm/ir/module.h>
2525
#include <tvm/node/node.h>
26-
#include <tvm/node/reflection.h>
2726
#include <tvm/runtime/data_type.h>
2827
#include <tvm/runtime/object.h>
2928
#include <tvm/tir/function.h>

include/tvm/meta_schedule/builder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include <tvm/ffi/reflection/registry.h>
2727
#include <tvm/ffi/string.h>
2828
#include <tvm/ir/module.h>
29-
#include <tvm/node/reflection.h>
3029
#include <tvm/runtime/ndarray.h>
3130
#include <tvm/runtime/object.h>
3231
#include <tvm/target/target.h>

0 commit comments

Comments
 (0)