Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions ffi/include/tvm/ffi/reflection/creator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ffi/reflection/creator.h
* \brief Reflection-based creator to create objects from type key and fields.
*/
#ifndef TVM_FFI_REFLECTION_CREATOR_H_
#define TVM_FFI_REFLECTION_CREATOR_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>

namespace tvm {
namespace ffi {
namespace reflection {
/*!
* \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection.
*/
class ObjectCreator {
public:
explicit ObjectCreator(std::string_view type_key)
: ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {}

explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
int32_t type_index = type_info->type_index;
if (type_info->metadata == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not have reflection registered";
}
if (type_info->metadata->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not support default constructor, "
<< "as a result cannot be created via reflection";
}
}

/**
* \brief Create an object from a map of fields.
* \param fields The fields of the object.
* \return The created object.
*/
Any operator()(const Map<String, Any>& fields) const {
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
size_t match_field_count = 0;
ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
String field_name(field_info->name);
void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
if (fields.count(field_name) != 0) {
Any field_value = fields[field_name];
field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
++match_field_count;
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
field_info->setter(field_addr, &(field_info->default_value));
} else {
TVM_FFI_THROW(TypeError) << "Required field `"
<< String(field_info->name.data, field_info->name.size)
<< "` not set in type `"
<< String(type_info_->type_key.data, type_info_->type_key.size)
<< "`";
}
});
if (match_field_count == fields.size()) return ObjectRef(ptr);
// report error that checks if contains extra fields that are not in the type
auto check_field_name = [&](const String& field_name) {
bool found = false;
ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) {
if (field_name.compare(field_info->name) == 0) {
found = true;
return true;
}
return false;
});
return found;
};
for (const auto& [field_name, _] : fields) {
if (!check_field_name(field_name)) {
TVM_FFI_THROW(TypeError) << "Type `"
<< String(type_info_->type_key.data, type_info_->type_key.size)
<< "` does not have field `" << field_name << "`";
}
}
TVM_FFI_UNREACHABLE();
}

private:
const TVMFFITypeInfo* type_info_;
};
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_CREATOR_H_
15 changes: 13 additions & 2 deletions ffi/src/ffi/extra/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/base64.h>
Expand Down Expand Up @@ -119,6 +120,12 @@ class ObjectGraphSerializer {
node.Set("data", CreateMapData(map));
break;
}
case TypeIndex::kTVMFFIShape: {
ffi::Shape shape = details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::Shape>(value);
node.Set("type", ffi::StaticTypeKey::kTVMFFIShape);
node.Set("data", Array<int64_t>(shape->data, shape->data + shape->size));
break;
}
default: {
if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) {
// serialize type key since type index is runtime dependent
Expand Down Expand Up @@ -157,10 +164,10 @@ class ObjectGraphSerializer {

// create the data for the object, if the type has a custom data to json function,
// use it. otherwise, we go over the fields and create the data.
json::Object CreateObjectData(const Any& value) {
json::Value CreateObjectData(const Any& value) {
static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__");
if (data_to_json[value.type_index()] != nullptr) {
return data_to_json[value.type_index()].cast<Function>()(value).cast<json::Object>();
return data_to_json[value.type_index()].cast<Function>()(value);
}
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index());
Expand Down Expand Up @@ -286,6 +293,10 @@ class ObjectGraphDeserializer {
case TypeIndex::kTVMFFIArray: {
return DecodeArrayData(node["data"].cast<json::Array>());
}
case TypeIndex::kTVMFFIShape: {
Array<int64_t> data = node["data"].cast<Array<int64_t>>();
return ffi::Shape(data);
}
default: {
return DecodeObjectData(type_index, node["data"]);
}
Expand Down
18 changes: 18 additions & 0 deletions ffi/tests/cpp/extra/test_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/extra/serialization.h>
#include <tvm/ffi/extra/structural_equal.h>
Expand Down Expand Up @@ -271,6 +272,23 @@ TEST(Serialization, Maps) {
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_map));
}

TEST(Serialization, Shapes) {
Shape empty_shape;

json::Object expected_empty_shape = json::Object{
{"root_index", 0},
{"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{}}}}}};
EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_shape), expected_empty_shape));
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty_shape), empty_shape));

Shape shape({1, 2, 3});
json::Object expected_shape = json::Object{
{"root_index", 0},
{"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{1, 2, 3}}}}}};
EXPECT_TRUE(StructuralEqual()(ToJSONGraph(shape), expected_shape));
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shape), shape));
}

TEST(Serialization, TestObjectVar) {
TVar x = TVar("x");
json::Object expected_x = json::Object{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/creator.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>

Expand Down Expand Up @@ -159,4 +160,9 @@ TEST(Reflection, FuncRegister) {
EXPECT_EQ(fget_value(a).cast<int>(), 12);
}

TEST(Reflection, ObjectCreator) {
namespace refl = tvm::ffi::reflection;
refl::ObjectCreator creator("test.Int");
EXPECT_EQ(creator(Map<String, Any>({{"value", 1}})).cast<TInt>()->value, 1);
}
} // namespace
1 change: 1 addition & 0 deletions ffi/tests/cpp/testing_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class TIntObj : public TNumberObj {
public:
int64_t value;

TIntObj() = default;
TIntObj(int64_t value) : value(value) {}

int64_t GetValue() const { return value; }
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/reflection.h>
#include <tvm/node/node.h>

#include <string>
#include <utility>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/instrument.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/node/reflection.h>
#include <tvm/node/node.h>

#include <utility>
#include <vector>
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#define TVM_IR_TRANSFORM_H_

#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/creator.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/diagnostic.h>
Expand Down Expand Up @@ -244,11 +245,10 @@ class PassContext : public ObjectRef {
// NOTE: we could further update the function later.
if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
auto* reflection = ReflectionVTable::Global();
auto type_key = ffi::TypeIndexToTypeKey(tindex);
auto legalization = [=](ffi::Any value) -> ffi::Any {
if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
return reflection->CreateObject(type_key, opt_map.value());
return ffi::reflection::ObjectCreator(type_key)(opt_map.value());
} else {
auto opt_val = value.try_cast<ValueType>();
if (!opt_val.has_value()) {
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/arg_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/function.h>
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/extracted_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/feature_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>

namespace tvm {
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/measure_candidate.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/schedule.h>
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>

namespace tvm {
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/schedule/schedule.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/meta_schedule/mutator.h>
#include <tvm/meta_schedule/postproc.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>
Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/support/random_engine.h>

Expand Down
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/object.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>
Expand Down
1 change: 0 additions & 1 deletion include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#define TVM_NODE_NODE_H_

#include <tvm/ffi/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
Expand Down
Loading
Loading