Skip to content

Commit c7a144e

Browse files
sungggjunrushao
authored andcommitted
[Runtime] Make CSourceModule and StaticLibraryModule Binary Serializable (apache#15693)
make csource module and static libary module binary serializable
1 parent 32ed4f0 commit c7a144e

File tree

5 files changed

+76
-12
lines changed

5 files changed

+76
-12
lines changed

src/runtime/static_library.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,39 @@ class StaticLibraryNode final : public runtime::ModuleNode {
5656
}
5757
}
5858

59+
void SaveToBinary(dmlc::Stream* stream) final {
60+
stream->Write(data_);
61+
std::vector<std::string> func_names;
62+
for (const auto func_name : func_names_) func_names.push_back(func_name);
63+
stream->Write(func_names);
64+
}
65+
66+
static Module LoadFromBinary(void* strm) {
67+
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
68+
auto n = make_object<StaticLibraryNode>();
69+
// load data
70+
std::string data;
71+
ICHECK(stream->Read(&data)) << "Loading data failed";
72+
n->data_ = std::move(data);
73+
74+
// load func names
75+
std::vector<std::string> func_names;
76+
ICHECK(stream->Read(&func_names)) << "Loading func names failed";
77+
for (auto func_name : func_names) n->func_names_.push_back(String(func_name));
78+
79+
return Module(n);
80+
}
81+
5982
void SaveToFile(const String& file_name, const String& format) final {
6083
VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames()
6184
<< " to '" << file_name << "'";
6285
SaveBinaryToFile(file_name, data_);
6386
}
6487

65-
// TODO(tvm-team): Make this module serializable
6688
/*! \brief Get the property of the runtime module .*/
67-
int GetPropertyMask() const override { return ModulePropertyMask::kDSOExportable; }
89+
int GetPropertyMask() const override {
90+
return runtime::ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable;
91+
}
6892

6993
bool ImplementsFunction(const String& name, bool query_imports) final {
7094
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
@@ -103,6 +127,8 @@ Module LoadStaticLibrary(const std::string& filename, Array<String> func_names)
103127
}
104128

105129
TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary);
130+
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_static_library")
131+
.set_body_typed(StaticLibraryNode::LoadFromBinary);
106132

107133
} // namespace runtime
108134
} // namespace tvm

src/target/codegen.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ class ModuleSerializer {
8383
// we will not produce import_tree_.
8484
bool has_import_tree = true;
8585

86-
if (mod_->IsDSOExportable()) {
87-
ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
86+
if (export_dso) {
8887
has_import_tree = !mod_->imports().empty();
8988
}
9089

src/target/source/source_module.cc

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,39 @@ class CSourceModuleNode : public runtime::ModuleNode {
119119

120120
String GetFormat() override { return fmt_; }
121121

122+
void SaveToBinary(dmlc::Stream* stream) final {
123+
stream->Write(code_);
124+
stream->Write(fmt_);
125+
126+
std::vector<std::string> func_names;
127+
for (const auto func_name : func_names_) func_names.push_back(func_name);
128+
std::vector<std::string> const_vars;
129+
for (auto const_var : const_vars_) const_vars.push_back(const_var);
130+
stream->Write(func_names);
131+
stream->Write(const_vars);
132+
}
133+
134+
static runtime::Module LoadFromBinary(void* strm) {
135+
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
136+
137+
std::string code, fmt;
138+
ICHECK(stream->Read(&code)) << "Loading code failed";
139+
ICHECK(stream->Read(&fmt)) << "Loading format failed";
140+
141+
std::vector<std::string> tmp_func_names, tmp_const_vars;
142+
CHECK(stream->Read(&tmp_func_names)) << "Loading func names failed";
143+
CHECK(stream->Read(&tmp_const_vars)) << "Loading const vars failed";
144+
145+
Array<String> func_names;
146+
for (auto func_name : tmp_func_names) func_names.push_back(String(func_name));
147+
148+
Array<String> const_vars;
149+
for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var));
150+
151+
auto n = make_object<CSourceModuleNode>(code, fmt, func_names, const_vars);
152+
return runtime::Module(n);
153+
}
154+
122155
void SaveToFile(const String& file_name, const String& format) final {
123156
std::string fmt = GetFileFormat(file_name, format);
124157
std::string meta_file = GetMetaFilePath(file_name);
@@ -130,7 +163,10 @@ class CSourceModuleNode : public runtime::ModuleNode {
130163
}
131164
}
132165

133-
int GetPropertyMask() const override { return runtime::ModulePropertyMask::kDSOExportable; }
166+
int GetPropertyMask() const override {
167+
return runtime::ModulePropertyMask::kBinarySerializable |
168+
runtime::ModulePropertyMask::kDSOExportable;
169+
}
134170

135171
bool ImplementsFunction(const String& name, bool query_imports) final {
136172
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
@@ -151,6 +187,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
151187
return runtime::Module(n);
152188
}
153189

190+
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_c")
191+
.set_body_typed(CSourceModuleNode::LoadFromBinary);
192+
154193
/*!
155194
* \brief A concrete class to get access to base methods of CodegenSourceBase.
156195
*

tests/python/unittest/test_roundtrip_runtime_module.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525

2626

2727
def test_csource_module():
28-
mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None)
29-
# source module that is not binary serializable.
30-
# Thus, it would raise an error.
31-
assert not mod.is_binary_serializable
32-
with pytest.raises(TVMError):
33-
tvm.ir.load_json(tvm.ir.save_json(mod))
28+
mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], [])
29+
assert mod.type_key == "c"
30+
assert mod.is_binary_serializable
31+
new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
32+
assert new_mod.type_key == "c"
33+
assert new_mod.is_binary_serializable
3434

3535

3636
def test_aot_module():

tests/python/unittest/test_runtime_module_property.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def create_aot_module():
4444
def test_property():
4545
checker(
4646
create_csource_module(),
47-
expected={"is_binary_serializable": False, "is_runnable": False, "is_dso_exportable": True},
47+
expected={"is_binary_serializable": True, "is_runnable": False, "is_dso_exportable": True},
4848
)
4949

5050
checker(

0 commit comments

Comments
 (0)