diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5255d3f4b10a..feed44498917 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -37,6 +37,10 @@ #include "../ffi/src/ffi/container.cc" #include "../ffi/src/ffi/dtype.cc" #include "../ffi/src/ffi/error.cc" +#include "../ffi/src/ffi/extra/library_module.cc" +#include "../ffi/src/ffi/extra/library_module_dynamic_lib.cc" +#include "../ffi/src/ffi/extra/library_module_system_lib.cc" +#include "../ffi/src/ffi/extra/module.cc" #include "../ffi/src/ffi/function.cc" #include "../ffi/src/ffi/ndarray.cc" #include "../ffi/src/ffi/object.cc" @@ -44,13 +48,10 @@ #include "../ffi/src/ffi/traceback.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/device_api.cc" -#include "../src/runtime/dso_library.cc" #include "../src/runtime/file_utils.cc" -#include "../src/runtime/library_module.cc" #include "../src/runtime/logging.cc" #include "../src/runtime/memory/memory_manager.cc" #include "../src/runtime/minrpc/minrpc_logger.cc" -#include "../src/runtime/module.cc" #include "../src/runtime/ndarray.cc" #include "../src/runtime/profiling.cc" #include "../src/runtime/registry.cc" @@ -62,7 +63,6 @@ #include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" #include "../src/runtime/workspace_pool.cc" diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index e5a5154acbf2..c4a43dc9f39f 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -147,7 +147,7 @@ RPCEnv::RPCEnv(const std::string& wd) { std::string file_name = this->GetPath(path); file_name = BuildSharedLibrary(file_name); LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); + return ffi::Module::LoadFromFile(file_name); })); ffi::Function::SetGlobal("tvm.rpc.server.download_linked_module", diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index 56242082cca3..fa2c3d8e3300 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -144,7 +144,7 @@ const tvm::ffi::Function get_runtime_func(const std::string& name) { } const tvm::ffi::Function get_module_func(tvm::runtime::Module module, const std::string& name) { - return module.GetFunction(name, false); + return module->GetFunction(name, false).value_or(tvm::ffi::Function()); } void reset_device_api() { @@ -153,7 +153,7 @@ void reset_device_api() { } tvm::runtime::Module load_module(const std::string& file_name) { - static const tvm::ffi::Function loader = get_runtime_func("runtime.module.loadfile_hexagon"); + static const tvm::ffi::Function loader = get_runtime_func("ffi.Module.load_from_file.hexagon"); tvm::ffi::Any rv = loader(file_name); if (rv.type_code() == kTVMModuleHandle) { ICHECK_EQ(rv.type_code(), kTVMModuleHandle) diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index c6f62515736c..09ee55390959 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -33,7 +33,7 @@ #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 // internal TVM header to achieve Library class -#include <../../../src/runtime/library_module.h> +#include <../../../ffi/src/ffi/extra/library_module.h> #include #endif @@ -70,7 +70,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s NSBundle* bundle = [NSBundle mainBundle]; base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; - if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) { + if (tvm::ffi::Function::GetGlobal("ffi.Module.load_from_file.dylib_custom")) { // Custom dso loader is present. Will use it. base = NSTemporaryDirectory(); fmt = "dylib_custom"; @@ -114,11 +114,11 @@ void Init(const std::string& name) { // Add UnsignedDSOLoader plugin in global registry TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("runtime.module.loadfile_dylib_custom", + refl::GlobalDef().def_packed("ffi.Module.load_from_file.dylib_custom", [](ffi::PackedArgs args, ffi::Any* rv) { auto n = make_object(); n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); + *rv = tvm::ffi::CreateLibraryModule(n); }); }); diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index af9943476e3d..ce4f4d4e208a 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -69,6 +69,10 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/module.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" ) endif() diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h new file mode 100644 index 000000000000..5d5d908f78ba --- /dev/null +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/c_env_api.h + * \brief Extra environment API. + */ +#ifndef TVM_FFI_EXTRA_C_ENV_API_H_ +#define TVM_FFI_EXTRA_C_ENV_API_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief FFI function to lookup a function from a module's imports. + * + * This is a helper function that is used by generated code. + * + * \param library_ctx The library context module handle. + * \param func_name The name of the function. + * \param out The result function. + * \note The returned function is a weak reference that is cached/owned by the module. + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, + TVMFFIObjectHandle* out); + +/* + * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. + * + * This function can be used to make context functions to be available in the library + * module that wants to avoid an explicit link dependency + * + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol); + +/*! + * \brief Register a symbol that will be initialized when a system library is loaded. + * + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* symbol); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // TVM_FFI_EXTRA_C_ENV_API_H_ diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h new file mode 100644 index 000000000000..f220c582a91f --- /dev/null +++ b/ffi/include/tvm/ffi/extra/module.h @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/module.h + * \brief A managed dynamic module in the TVM FFI. + */ +#ifndef TVM_FFI_EXTRA_MODULE_H_ +#define TVM_FFI_EXTRA_MODULE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +// forward declare Module +class Module; + +/*! + * \brief A module that can dynamically load ffi::Functions or exportable source code. + */ +class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { + public: + /*! + * \return The per module type key. + * \note This key is used to for serializing custom modules. + */ + virtual const char* kind() const = 0; + /*! + * \brief Get the property mask of the module. + * \return The property mask of the module. + * + * \sa Module::ModulePropertyMask + */ + virtual int GetPropertyMask() const { return 0b000; } + /*! + * \brief Get a ffi::Function from the module. + * \param name The name of the function. + * \return The function. + */ + virtual Optional GetFunction(const String& name) = 0; + /*! + * \brief Returns true if this module has a definition for a function of \p name. + * + * Note that even if this function returns true the corresponding \p GetFunction result + * may be nullptr if the function is not yet callable without further compilation. + * + * The default implementation just checks if \p GetFunction is non-null. + * \param name The name of the function. + * \return True if the module implements the function, false otherwise. + */ + virtual bool ImplementsFunction(const String& name) { return GetFunction(name).defined(); } + /*! + * \brief Write the current module to file with given format (for further compilation). + * + * \param file_name The file to be saved to. + * \param format The format of the file. + * + * \note This function is mainly used by modules that + */ + virtual void WriteToFile(const String& file_name, const String& format) const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; + } + /*! + * \brief Get the possible write formats of the module, when available. + * \return Possible write formats when available. + */ + virtual Array GetWriteFormats() const { return Array(); } + /*! + * \brief Serialize the the module to bytes. + * \return The serialized module. + */ + virtual Bytes SaveToBytes() const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; + TVM_FFI_UNREACHABLE(); + } + /*! + * \brief Get the source code of module, when available. + * \param format Format of the source code, can be empty by default. + * \return Possible source code when available, or empty string if not available. + */ + virtual String InspectSource(const String& format = "") const { return String(); } + /*! + * \brief Import another module. + * \param other The module to import. + */ + virtual void ImportModule(const Module& other); + /*! + * \brief Clear all imported modules. + */ + virtual void ClearImports(); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return The function. + */ + Optional GetFunction(const String& name, bool query_imports); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return True if the module implements the function, false otherwise. + */ + bool ImplementsFunction(const String& name, bool query_imports); + /*! + * \brief Get the imports of the module. + * \return The imports of the module. + * \note Note the signature is not part of the public API. + */ + const Array& imports() const { return this->imports_; } + + struct InternalUnsafe; + + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIModule; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleObj, Object); + + protected: + friend struct InternalUnsafe; + + /*! + * \brief The modules that this module depends on. + * \note Use ObjectRef to avoid circular dep on Module. + */ + Array imports_; + + private: + /*! + * \brief cache used by TVMFFIModuleLookupFromImports + */ + Map import_lookup_cache_; +}; + +/*! + * \brief Reference to module object. + */ +class Module : public ObjectRef { + public: + /*! + * \brief Property of ffi::Module + */ + enum ModulePropertyMask : int { + /*! + * \brief The module can be serialized to bytes. + * + * This prooperty indicates that module implements SaveToBytes. + * The system also registers a GlobalDef function + * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. + */ + kBinarySerializable = 0b001, + /*! + * \brief The module can directly get runnable functions. + * + * This property indicates that module implements GetFunction that returns + * runnable ffi::Functions. + */ + kRunnable = 0b010, + /*! + * \brief The module can be exported to a object file or source file that then be compiled. + * + * This property indicates that module implements WriteToFile with a given format + * that can be queried by GetLibExportFormat. + * + * Examples include modules that can be exported to .o, .cc, .cu files. + * + * Such modules can be exported, compiled and loaded back as a dynamic library module. + */ + kCompilationExportable = 0b100 + }; + + /*! + * \brief Load a module from file. + * \param file_name The name of the host function module. + * \param format The format of the file. + * \note This function won't load the import relationship. + * Re-create import relationship by calling Import. + */ + TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name); + /* + * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. + * \param callback The callback to be called with the symbol name and address. + * \note This helper can be used to implement custom Module that needs to access context symbols. + */ + TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( + const ffi::TypedFunction& callback); + + TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Module, ObjectRef, ModuleObj); +}; + +/* + * \brief Symbols for library module. + */ +namespace symbol { +/*! \brief Global variable to store context pointer for a library module. */ +constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; +/*! \brief Global variable to store binary data alongside a library module. */ +constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; +/*! \brief Default entry function of a library module. */ +constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; +} // namespace symbol +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 4b7b56209af5..abf7f489038b 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -52,6 +52,8 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIRawStr = "const char*"; static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; + static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; + static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; static constexpr const char* kTVMFFIBytes = "ffi.Bytes"; static constexpr const char* kTVMFFIStr = "ffi.String"; static constexpr const char* kTVMFFIShape = "ffi.Shape"; @@ -60,8 +62,7 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIFunction = "ffi.Function"; static constexpr const char* kTVMFFIArray = "ffi.Array"; static constexpr const char* kTVMFFIMap = "ffi.Map"; - static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; - static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; + static constexpr const char* kTVMFFIModule = "ffi.Module"; }; /*! @@ -671,10 +672,10 @@ struct ObjectPtrEqual { */ #define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() = default; \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ ObjectName* operator->() const { return static_cast(data_.get()); } \ - using ContainerType = ObjectName; + using ContainerType = ObjectName /* * \brief Define object reference methods that is both not nullable and mutable. @@ -685,11 +686,11 @@ struct ObjectPtrEqual { */ #define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ ObjectName* operator->() const { return static_cast(data_.get()); } \ ObjectName* get() const { return operator->(); } \ static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName; + using ContainerType = ObjectName namespace details { template diff --git a/ffi/src/ffi/extra/buffer_stream.h b/ffi/src/ffi/extra/buffer_stream.h new file mode 100644 index 000000000000..f6f162676607 --- /dev/null +++ b/ffi/src/ffi/extra/buffer_stream.h @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file buffer_stream.h + * \brief Internal minimal stream helper to read from a buffer. + */ +#ifndef TVM_FFI_EXTRA_BUFFER_STREAM_H_ +#define TVM_FFI_EXTRA_BUFFER_STREAM_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Lightweight stream helper to read from a buffer. + */ +class BufferInStream { + public: + /*! + * \brief constructor + * \param p_buffer the head pointer of the memory region. + * \param buffer_size the size of the memorybuffer + */ + BufferInStream(const void* data, size_t size) + : data_(reinterpret_cast(data)), size_(size) {} + /*! + * \brief Reads raw from stream. + * \param ptr pointer to the data to be read + * \param size the size of the data to be read + * \return the number of bytes read + */ + size_t Read(void* ptr, size_t size) { + size_t nread = std::min(size_ - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, data_ + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + /*! + * \brief Reads arithmetic data from stream in endian-aware manner. + * \param data data to be read + * \tparam T the data type to be read + * \return whether the read was successful + */ + template >> + bool Read(T* data) { + bool ret = Read(static_cast(data), sizeof(T)) == sizeof(T); // NOLINT(*) + if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + ByteSwap(&data, sizeof(T), 1); + } + return ret; + } + /*! + * \brief Reads an array of data from stream in endian-aware manner. + * \param data data to be read + * \param size the size of the data to be read + * \return whether the read was successful + */ + template >> + bool ReadArray(T* data, size_t size) { + bool ret = + this->Read(static_cast(data), sizeof(T) * size) == sizeof(T) * size; // NOLINT(*) + if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + ByteSwap(data, sizeof(T), size); + } + return ret; + } + /*! + * \brief Reads a string from stream. + * \param data data to be read + * \return whether the read was successful + */ + bool Read(std::string* data) { + // use uint64_t to ensure platform independent size + uint64_t size = 0; + if (!this->Read(&size)) return false; + data->resize(size); + if (!this->Read(data->data(), size)) return false; + return true; + } + /*! + * \brief Reads a vector of data from stream in endian-aware manner. + * \param data data to be read + * \return whether the read was successful + */ + template >> + bool Read(std::vector* data) { + uint64_t size = 0; + if (!this->Read(&size)) return false; + data->resize(size); + return this->ReadArray(data->data(), size); + } + + private: + /*! \brief in memory buffer */ + const char* data_; + /*! \brief size of the buffer */ + size_t size_; + /*! \brief current pointer */ + size_t curr_ptr_{0}; +}; // class BytesInStream + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_BUFFER_STREAM_H_ diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc new file mode 100644 index 000000000000..34286d6d0eb2 --- /dev/null +++ b/ffi/src/ffi/extra/library_module.cc @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/library_module.cc + * + * \brief Library module implementation. + */ +#include +#include +#include + +#include "buffer_stream.h" +#include "module_internal.h" + +namespace tvm { +namespace ffi { + +class LibraryModuleObj final : public ModuleObj { + public: + explicit LibraryModuleObj(ObjectPtr lib) : lib_(lib) {} + + const char* kind() const final { return "library"; } + + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return Module::kBinarySerializable | Module::kRunnable; }; + + Optional GetFunction(const String& name) final { + TVMFFISafeCallType faddr; + faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); + // ensure the function keeps the Library Module alive + Module self_strong_ref = GetRef(this); + if (faddr != nullptr) { + return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, + ffi::Any* rv) { + TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); + TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), + args.size(), reinterpret_cast(rv))); + }); + } + return std::nullopt; + } + + private: + ObjectPtr lib_; +}; + +Module LoadModuleFromBytes(const std::string& kind, const Bytes& bytes) { + std::string loader_key = "ffi.Module.load_from_bytes." + kind; + const auto floader = tvm::ffi::Function::GetGlobal(loader_key); + if (!floader.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Library binary was created using {" << kind + << "} but a loader of that name is not registered. " + << "Make sure to have runtime that registers " << loader_key; + } + return (*floader)(bytes).cast(); +} + +/*! + * \brief Process libary binary to recover binary-serialized modules + * \param library_bin The binary embedded in the library. + * \param opt_lib The library, can be nullptr in which case we expect to deserialize + * all binary-serialized modules + * \param library_ctx_addr the pointer to library module as ctx addr + * \return the root module + * + */ +Module ProcessLibraryBin(const char* library_bin, ObjectPtr opt_lib, + void** library_ctx_addr = nullptr) { + // Layout of the library binary: + // ... + // key can be: "_lib", or a module kind + // - "_lib" indicate this location places the library module + // - other keys are module kinds + // Import tree structure (CSR structure of child indices): + // = > > + TVM_FFI_ICHECK(library_bin != nullptr); + uint64_t nbytes = 0; + for (size_t i = 0; i < sizeof(nbytes); ++i) { + uint64_t c = library_bin[i]; + nbytes |= (c & 0xffUL) << (i * 8); + } + + BufferInStream stream(library_bin + sizeof(nbytes), static_cast(nbytes)); + std::vector import_tree_indptr; + std::vector import_tree_child_indices; + TVM_FFI_ICHECK(stream.Read(&import_tree_indptr)); + TVM_FFI_ICHECK(stream.Read(&import_tree_child_indices)); + size_t num_modules = import_tree_indptr.size() - 1; + std::vector modules; + modules.reserve(num_modules); + + for (uint64_t i = 0; i < num_modules; ++i) { + std::string kind; + TVM_FFI_ICHECK(stream.Read(&kind)); + // "_lib" serves as a placeholder in the module import tree to indicate where + // to place the DSOModule + if (kind == "_lib") { + TVM_FFI_ICHECK(opt_lib != nullptr) << "_lib is not allowed during module serialization"; + auto lib_mod_ptr = make_object(opt_lib); + if (library_ctx_addr) { + *library_ctx_addr = lib_mod_ptr.get(); + } + modules.emplace_back(Module(lib_mod_ptr)); + } else { + std::string module_bytes; + TVM_FFI_ICHECK(stream.Read(&module_bytes)); + Module m = LoadModuleFromBytes(kind, Bytes(module_bytes)); + modules.emplace_back(m); + } + } + for (size_t i = 0; i < modules.size(); ++i) { + for (size_t j = import_tree_indptr[i]; j < import_tree_indptr[i + 1]; ++j) { + Array* module_imports = ModuleObj::InternalUnsafe::GetImports(modules[i].operator->()); + auto child_index = import_tree_child_indices[j]; + TVM_FFI_ICHECK(child_index < modules.size()); + module_imports->emplace_back(modules[child_index]); + } + } + return modules[0]; +} + +// registry to store context symbols +class ContextSymbolRegistry { + public: + void InitContextSymbols(ObjectPtr lib) { + for (const auto& [name, symbol] : context_symbols_) { + if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name.c_str()))) { + *symbol_addr = symbol; + } + } + } + + void VisitContextSymbols(const ffi::TypedFunction& callback) { + for (const auto& [name, symbol] : context_symbols_) { + callback(name, symbol); + } + } + + void Register(String name, void* symbol) { context_symbols_.emplace_back(name, symbol); } + + static ContextSymbolRegistry* Global() { + static ContextSymbolRegistry* inst = new ContextSymbolRegistry(); + return inst; + } + + private: + std::vector> context_symbols_; +}; + +void Module::VisitContextSymbols(const ffi::TypedFunction& callback) { + ContextSymbolRegistry::Global()->VisitContextSymbols(callback); +} + +Module CreateLibraryModule(ObjectPtr lib) { + const char* library_bin = + reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_bin)); + void** library_ctx_addr = + reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_ctx)); + + ContextSymbolRegistry::Global()->InitContextSymbols(lib); + if (library_bin != nullptr) { + // we have embedded binaries that needs to be deserialized + return ProcessLibraryBin(library_bin, lib, library_ctx_addr); + } else { + // Only have one single DSO Module + auto lib_mod_ptr = make_object(lib); + Module root_mod = Module(lib_mod_ptr); + if (library_ctx_addr) { + *library_ctx_addr = root_mod.operator->(); + } + return root_mod; + } +} + +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::String s_name(name); + tvm::ffi::ContextSymbolRegistry::Global()->Register(s_name, symbol); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/src/runtime/dso_library.cc b/ffi/src/ffi/extra/library_module_dynamic_lib.cc similarity index 51% rename from src/runtime/dso_library.cc rename to ffi/src/ffi/extra/library_module_dynamic_lib.cc index 0d74fc87a0fe..25463a7e5f92 100644 --- a/src/runtime/dso_library.cc +++ b/ffi/src/ffi/extra/library_module_dynamic_lib.cc @@ -18,15 +18,14 @@ */ /*! - * \file dso_libary.cc + * \file library_module_dynamic_lib.cc * \brief Create library module to load from dynamic shared library. */ -#include #include #include -#include +#include -#include "library_module.h" +#include "module_internal.h" #if defined(_WIN32) #include @@ -41,46 +40,21 @@ extern "C" { #endif namespace tvm { -namespace runtime { +namespace ffi { -/*! - * \brief Dynamic shared library object used to load - * and retrieve symbols by name. This is the default - * module TVM uses for host-side AOT compilation. - */ class DSOLibrary final : public Library { public: - ~DSOLibrary(); - /*! - * \brief Initialize by loading and storing - * a handle to the underlying shared library. - * \param name The string name/path to the - * shared library over which to initialize. - */ - void Init(const std::string& name); - /*! - * \brief Returns the symbol address within - * the shared library for a given symbol name. - * \param name The name of the symbol. - * \return The symbol. - */ - void* GetSymbol(const char* name) final; + explicit DSOLibrary(const String& name) { Load(name); } + ~DSOLibrary() { + if (lib_handle_) Unload(); + } + + void* GetSymbol(const char* name) final { return GetSymbol_(name); } private: - /*! \brief Private implementation of symbol lookup. - * Implementation is operating system dependent. - * \param The name of the symbol. - * \return The symbol. - */ + // private system dependent implementation void* GetSymbol_(const char* name); - /*! \brief Implementation of shared library load. - * Implementation is operating system dependent. - * \param The name/path of the shared library. - */ - void Load(const std::string& name); - /*! \brief Implementation of shared library unload. - * Implementation is operating system dependent. - */ + void Load(const String& name); void Unload(); #if defined(_WIN32) @@ -92,25 +66,17 @@ class DSOLibrary final : public Library { #endif }; -DSOLibrary::~DSOLibrary() { - if (lib_handle_) Unload(); -} - -void DSOLibrary::Init(const std::string& name) { Load(name); } - -void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } - #if defined(_WIN32) void* DSOLibrary::GetSymbol_(const char* name) { return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } -void DSOLibrary::Load(const std::string& name) { +void DSOLibrary::Load(const String& name) { // use wstring version that is needed by LLVM. - std::wstring wname(name.begin(), name.end()); + std::wstring wname(name.data(), name.data() + name.size()); lib_handle_ = LoadLibraryW(wname.c_str()); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; + TVM_FFI_ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } void DSOLibrary::Unload() { @@ -120,10 +86,10 @@ void DSOLibrary::Unload() { #else -void DSOLibrary::Load(const std::string& name) { +void DSOLibrary::Load(const String& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " - << dlerror(); + TVM_FFI_ICHECK(lib_handle_ != nullptr) + << "Failed to load dynamic shared library " << name << " " << dlerror(); #if defined(__hexagon__) int p; int rc = dlinfo(lib_handle_, RTLD_DI_LOAD_ADDR, &p); @@ -140,21 +106,13 @@ void DSOLibrary::Unload() { dlclose(lib_handle_); lib_handle_ = nullptr; } - #endif -ObjectPtr CreateDSOLibraryObject(std::string library_path) { - auto n = make_object(); - n->Init(library_path); - return n; -} - TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadfile_so", [](std::string library_path, std::string) { - ObjectPtr n = CreateDSOLibraryObject(library_path); - return CreateModuleFromLibrary(n); + refl::GlobalDef().def("ffi.Module.load_from_file.so", [](String library_path, String) { + return CreateLibraryModule(make_object(library_path)); }); }); -} // namespace runtime +} // namespace ffi } // namespace tvm diff --git a/src/runtime/system_library.cc b/ffi/src/ffi/extra/library_module_system_lib.cc similarity index 63% rename from src/runtime/system_library.cc rename to ffi/src/ffi/extra/library_module_system_lib.cc index 65df96f96375..64b95a122d56 100644 --- a/src/runtime/system_library.cc +++ b/ffi/src/ffi/extra/library_module_system_lib.cc @@ -21,35 +21,34 @@ * \file system_library.cc * \brief Create library module that directly get symbol from the system lib. */ -#include +#include +#include #include #include -#include +#include #include -#include "library_module.h" +#include "module_internal.h" namespace tvm { -namespace runtime { +namespace ffi { class SystemLibSymbolRegistry { public: void RegisterSymbol(const std::string& name, void* ptr) { - std::lock_guard lock(mutex_); auto it = symbol_table_.find(name); - if (it != symbol_table_.end() && ptr != it->second) { - LOG(WARNING) << "SystemLib symbol " << name << " get overriden to a different address " << ptr - << "->" << it->second; + if (it != symbol_table_.end() && ptr != (*it).second) { + std::cerr << "Warning:SystemLib symbol " << name << " get overriden to a different address " + << ptr << "->" << (*it).second << std::endl; } - symbol_table_[name] = ptr; + symbol_table_.Set(name, ptr); } void* GetSymbol(const char* name) { - std::lock_guard lock(mutex_); auto it = symbol_table_.find(name); if (it != symbol_table_.end()) { - return it->second; + return (*it).second; } else { return nullptr; } @@ -61,19 +60,17 @@ class SystemLibSymbolRegistry { } private: - // Internal mutex - std::mutex mutex_; // Internal symbol table - std::unordered_map symbol_table_; + Map symbol_table_; }; -class SystemLibrary : public Library { +class SystemLibrary final : public Library { public: - explicit SystemLibrary(const std::string& symbol_prefix) : symbol_prefix_(symbol_prefix) {} + explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {} void* GetSymbol(const char* name) { if (symbol_prefix_.length() != 0) { - std::string name_with_prefix = symbol_prefix_ + name; + String name_with_prefix = symbol_prefix_ + name; void* symbol = reg_->GetSymbol(name_with_prefix.c_str()); if (symbol != nullptr) return symbol; } @@ -82,19 +79,19 @@ class SystemLibrary : public Library { private: SystemLibSymbolRegistry* reg_ = SystemLibSymbolRegistry::Global(); - std::string symbol_prefix_; + String symbol_prefix_; }; class SystemLibModuleRegistry { public: - runtime::Module GetOrCreateModule(std::string symbol_prefix) { + Module GetOrCreateModule(String symbol_prefix) { std::lock_guard lock(mutex_); auto it = lib_map_.find(symbol_prefix); if (it != lib_map_.end()) { - return it->second; + return (*it).second; } else { - auto mod = CreateModuleFromLibrary(make_object(symbol_prefix)); - lib_map_[symbol_prefix] = mod; + Module mod = CreateLibraryModule(make_object(symbol_prefix)); + lib_map_.Set(symbol_prefix, mod); return mod; } } @@ -107,26 +104,26 @@ class SystemLibModuleRegistry { private: // Internal mutex std::mutex mutex_; + // maps prefix to the library module // we need to make sure each lib map have an unique // copy through out the entire lifetime of the process - // so the cached ffi::Function in the system do not get out dated. - std::unordered_map lib_map_; + Map lib_map_; }; TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("runtime.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { - std::string symbol_prefix = ""; + refl::GlobalDef().def_packed("ffi.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { + String symbol_prefix = ""; if (args.size() != 0) { - symbol_prefix = args[0].cast(); + symbol_prefix = args[0].cast(); } *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); }); }); -} // namespace runtime +} // namespace ffi } // namespace tvm -int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { - tvm::runtime::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); +int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr) { + tvm::ffi::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); return 0; } diff --git a/ffi/src/ffi/extra/module.cc b/ffi/src/ffi/extra/module.cc new file mode 100644 index 000000000000..a7f6d4460079 --- /dev/null +++ b/ffi/src/ffi/extra/module.cc @@ -0,0 +1,139 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include +#include + +#include "module_internal.h" + +namespace tvm { +namespace ffi { + +Optional ModuleObj::GetFunction(const String& name, bool query_imports) { + if (auto opt_func = this->GetFunction(name)) { + return opt_func; + } + if (query_imports) { + for (const Any& import : imports_) { + if (auto opt_func = import.cast()->GetFunction(name, query_imports)) { + return *opt_func; + } + } + } + return std::nullopt; +} + +void ModuleObj::ImportModule(const Module& other) { + std::unordered_set visited{other.operator->()}; + std::vector stack{other.operator->()}; + while (!stack.empty()) { + const ModuleObj* n = stack.back(); + stack.pop_back(); + for (const Any& m : n->imports_) { + const ModuleObj* next = m.cast(); + if (visited.count(next)) continue; + visited.insert(next); + stack.push_back(next); + } + } + if (visited.count(this)) { + TVM_FFI_THROW(RuntimeError) << "Cyclic dependency detected during import"; + } + imports_.push_back(other); +} + +void ModuleObj::ClearImports() { imports_.clear(); } + +bool ModuleObj::ImplementsFunction(const String& name, bool query_imports) { + if (this->ImplementsFunction(name)) { + return true; + } + if (query_imports) { + for (const Any& import : imports_) { + if (import.cast()->ImplementsFunction(name, query_imports)) { + return true; + } + } + } + return false; +} + +Module Module::LoadFromFile(const String& file_name) { + String format = [&file_name]() -> String { + const char* data = file_name.data(); + for (size_t i = file_name.size(); i > 0; i--) { + if (data[i - 1] == '.') { + return String(data + i, file_name.size() - i); + } + } + TVM_FFI_THROW(RuntimeError) << "Failed to get file format from " << file_name; + TVM_FFI_UNREACHABLE(); + }(); + + if (format == "dll" || format == "dylib" || format == "dso") { + format = "so"; + } + String loader_name = "ffi.Module.load_from_file." + format; + const auto floader = tvm::ffi::Function::GetGlobal(loader_name); + if (!floader.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Loader for `." << format << "` files is not registered," + << " resolved to (" << loader_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + } + return (*floader)(file_name, format).cast(); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + ModuleObj::InternalUnsafe::RegisterReflection(); + + refl::GlobalDef() + .def("ffi.ModuleLoadFromFile", &Module::LoadFromFile) + .def_method("ffi.ModuleImplementsFunction", + [](Module mod, String name, bool query_imports) { + return mod->ImplementsFunction(name, query_imports); + }) + .def_method("ffi.ModuleGetFunction", + [](Module mod, String name, bool query_imports) { + return mod->GetFunction(name, query_imports); + }) + .def_method("ffi.ModuleGetPropertyMask", &ModuleObj::GetPropertyMask) + .def_method("ffi.ModuleInspectSource", &ModuleObj::InspectSource) + .def_method("ffi.ModuleGetKind", [](const Module& mod) -> String { return mod->kind(); }) + .def_method("ffi.ModuleGetWriteFormats", &ModuleObj::GetWriteFormats) + .def_method("ffi.ModuleWriteToFile", &ModuleObj::WriteToFile) + .def_method("ffi.ModuleImportModule", &ModuleObj::ImportModule) + .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports); +}); +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, + TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::ModuleObj::InternalUnsafe::GetFunctionFromImports( + reinterpret_cast(library_ctx), func_name); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h new file mode 100644 index 000000000000..f43d3a3d2c42 --- /dev/null +++ b/ffi/src/ffi/extra/module_internal.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file library_module.h + * \brief Module that builds from a libary of symbols. + */ +#ifndef TVM_FFI_EXTRA_MODULE_INTERNAL_H_ +#define TVM_FFI_EXTRA_MODULE_INTERNAL_H_ + +#include +#include + +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Library is the common interface + * for storing data in the form of shared libaries. + * + * \sa src/ffi/extra/dso_library.cc + * \sa src/ffi/extra/system_library.cc + */ +class Library : public Object { + public: + // destructor. + virtual ~Library() {} + /*! + * \brief Get the symbol address for a given name. + * \param name The name of the symbol. + * \return The symbol. + */ + virtual void* GetSymbol(const char* name) = 0; + // NOTE: we do not explicitly create an type index and type_key here for libary. + // This is because we do not need dynamic type downcasting and only need to use the refcounting +}; + +struct ModuleObj::InternalUnsafe { + static Array* GetImports(ModuleObj* module) { return &(module->imports_); } + + static void* GetFunctionFromImports(ModuleObj* module, const char* name) { + // backend implementation for TVMFFIEnvLookupFromImports + static std::mutex mutex_; + std::lock_guard lock(mutex_); + String s_name(name); + auto it = module->import_lookup_cache_.find(s_name); + if (it != module->import_lookup_cache_.end()) { + return const_cast((*it).second.operator->()); + } + + auto opt_func = [&]() -> std::optional { + for (const Any& import : module->imports_) { + if (auto opt_func = import.cast()->GetFunction(s_name, true)) { + return *opt_func; + } + } + // try global at last + return tvm::ffi::Function::GetGlobal(s_name); + }(); + if (!opt_func.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Cannot find function " << name + << " in the imported modules or global registry."; + } + module->import_lookup_cache_.Set(s_name, *opt_func); + return const_cast((*opt_func).operator->()); + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("imports_", &ModuleObj::imports_); + } +}; + +/*! + * \brief Create a library module from a given library. + * + * \param lib The library. + * + * \return The corresponding loaded module. + */ +Module CreateLibraryModule(ObjectPtr lib); + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_MODULE_INTERNAL_H_ diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index 0d84b55fe318..e44fe465bc96 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -54,7 +54,7 @@ TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, * \param ptr The symbol address. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); +TVM_DLL int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr); /*! * \brief Backend function to allocate temporal workspace. diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index d594de7247c8..bc0faf2413e5 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -62,7 +62,7 @@ inline std::string ReduceKind2String(ReduceKind kind) { * \param device The default device used to initialize the RelaxVM * \return The RelaxVM as a runtime Module */ -TVM_DLL Module LoadVMModule(std::string path, Optional device); +TVM_DLL ffi::Module LoadVMModule(std::string path, Optional device); /*! * \brief Create an uninitialized empty NDArray * \param shape The shape of the NDArray diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 80c03ea75132..f805ec988d37 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,251 +27,19 @@ #define TVM_RUNTIME_MODULE_H_ #include +#include +#include #include #include #include #include #include -#include -#include -#include -#include #include -#include namespace tvm { namespace runtime { -/*! - * \brief Property of runtime module - * We classify the property of runtime module into the following categories. - */ -enum ModulePropertyMask : int { - /*! \brief kBinarySerializable - * we can serialize the module to the stream of bytes. CUDA/OpenCL/JSON - * runtime are representative examples. A binary exportable module can be integrated into final - * runtime artifact by being serialized as data into the artifact, then deserialized at runtime. - * This class of modules must implement SaveToBinary, and have a matching deserializer registered - * as 'runtime.module.loadbinary_'. - */ - kBinarySerializable = 0b001, - /*! \brief kRunnable - * we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g, - * virtual machine) runtimes are runnable. Non-runnable modules, such as CSourceModule, requires a - * few extra steps (e.g,. compilation, link) to make it runnable. - */ - kRunnable = 0b010, - /*! \brief kDSOExportable - * we can export the module as DSO. A DSO exportable module (e.g., a - * CSourceModuleNode of type_key 'c') can be incorporated into the final runtime artifact (ie - * shared library) by compilation and/or linking using the external compiler (llvm, nvcc, etc). - * DSO exportable modules must implement SaveToFile. In general, DSO exportable modules are not - * runnable unless there is a special support like JIT for `LLVMModule`. - */ - kDSOExportable = 0b100 -}; - -class ModuleNode; - -/*! - * \brief Module container of TVM. - */ -class Module : public ObjectRef { - public: - Module() {} - // constructor from container. - explicit Module(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief Get packed function from current module by name. - * - * \param name The name of the function. - * \param query_imports Whether also query dependency modules. - * \return The result function. - * This function will return ffi::Function(nullptr) if function do not exist. - * \note Implemented in packed_func.cc - */ - inline ffi::Function GetFunction(const String& name, bool query_imports = false); - // The following functions requires link with runtime. - /*! - * \brief Import another module into this module. - * \param other The module to be imported. - * - * \note Cyclic dependency is not allowed among modules, - * An error will be thrown when cyclic dependency is detected. - */ - inline void Import(Module other); - /*! \return internal container */ - inline ModuleNode* operator->(); - /*! \return internal container */ - inline const ModuleNode* operator->() const; - /*! - * \brief Load a module from file. - * \param file_name The name of the host function module. - * \param format The format of the file. - * \note This function won't load the import relationship. - * Re-create import relationship by calling Import. - */ - TVM_DLL static Module LoadFromFile(const String& file_name, const String& format = ""); - // refer to the corresponding container. - using ContainerType = ModuleNode; - friend class ModuleNode; -}; - -/*! - * \brief Base container of module. - * - * Please subclass ModuleNode to create a specific runtime module. - * - * \code - * - * class MyModuleNode : public ModuleNode { - * public: - * // implement the interface - * }; - * - * // use make_object to create a specific - * // instace of MyModuleNode. - * Module CreateMyModule() { - * ObjectPtr n = - * tvm::ffi::make_object(); - * return Module(n); - * } - * - * \endcode - */ -class TVM_DLL ModuleNode : public Object { - public: - /*! \brief virtual destructor */ - virtual ~ModuleNode() = default; - /*! - * \return The per module type key. - * \note This key is used to for serializing custom modules. - */ - virtual const char* type_key() const = 0; - /*! - * \brief Get a ffi::Function from module. - * - * The ffi::Function may not be fully initialized, - * there might still be first time running overhead when - * executing the function on certain devices. - * For benchmarking, use prepare to eliminate - * - * \param name the name of the function. - * \param sptr_to_self The ObjectPtr that points to this module node. - * - * \return ffi::Function(nullptr) when it is not available. - * - * \note The function will always remain valid. - * If the function need resource from the module(e.g. late linking), - * it should capture sptr_to_self. - */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) = 0; - /*! - * \brief Save the module to file. - * \param file_name The file to be saved to. - * \param format The format of the file. - */ - virtual void SaveToFile(const String& file_name, const String& format); - /*! - * \brief Save the module to binary stream. - * \param stream The binary stream to save to. - * \note It is recommended to implement this for device modules, - * but not necessarily host modules. - * We can use this to do AOT loading of bundled device functions. - */ - virtual void SaveToBinary(dmlc::Stream* stream); - /*! - * \brief Get the source code of module, when available. - * \param format Format of the source code, can be empty by default. - * \return Possible source code when available. - */ - virtual String GetSource(const String& format = ""); - /*! - * \brief Get the format of the module, when available. - * \return Possible format when available. - */ - virtual String GetFormat(); - /*! - * \brief Get packed function from current module by name. - * - * \param name The name of the function. - * \param query_imports Whether also query dependency modules. - * \return The result function. - * This function will return ffi::Function(nullptr) if function do not exist. - * \note Implemented in packed_func.cc - */ - ffi::Function GetFunction(const String& name, bool query_imports = false); - /*! - * \brief Import another module into this module. - * \param other The module to be imported. - * - * \note Cyclic dependency is not allowed among modules, - * An error will be thrown when cyclic dependency is detected. - */ - void Import(Module other); - /*! - * \brief Get a function from current environment - * The environment includes all the imports as well as Global functions. - * - * \param name name of the function. - * \return The corresponding function. - */ - const ffi::Function* GetFuncFromEnv(const String& name); - - /*! \brief Clear all imports of the module. */ - void ClearImports() { imports_.clear(); } - - /*! \return The module it imports from */ - const std::vector& imports() const { return imports_; } - - /*! - * \brief Returns bitmap of property. - * By default, none of the property is set. Derived class can override this function and set its - * own property. - */ - virtual int GetPropertyMask() const { return 0b000; } - - /*! \brief Returns true if this module is 'DSO exportable'. */ - bool IsDSOExportable() const { - return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0; - } - - /*! \brief Returns true if this module is 'Binary Serializable'. */ - bool IsBinarySerializable() const { - return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0; - } - - /*! - * \brief Returns true if this module has a definition for a function of \p name. If - * \p query_imports is true, also search in any imported modules. - * - * Note that even if this function returns true the corresponding \p GetFunction result may be - * nullptr if the function is not yet callable without further compilation. - * - * The default implementation just checkis if \p GetFunction is non-null. - */ - virtual bool ImplementsFunction(const String& name, bool query_imports = false); - - // integration with the existing components. - static constexpr const uint32_t _type_index = ffi::TypeIndex::kTVMFFIModule; - static constexpr const char* _type_key = "runtime.Module"; - // NOTE: ModuleNode can still be sub-classed - // - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleNode, Object); - - protected: - friend class Module; - friend class ModuleInternal; - /*! \brief The modules this module depend on */ - std::vector imports_; - - private: - /*! \brief Cache used by GetImport */ - std::unordered_map> import_cache_; - std::mutex mutex_; -}; - /*! * \brief Check if runtime module is enabled for target. * \param target The target module name. @@ -279,19 +47,8 @@ class TVM_DLL ModuleNode : public Object { */ TVM_DLL bool RuntimeEnabled(const String& target); -// implementation of Module::GetFunction -inline ffi::Function Module::GetFunction(const String& name, bool query_imports) { - return (*this)->GetFunction(name, query_imports); -} - /*! \brief namespace for constant symbols */ namespace symbol { -/*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; -/*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; -/*! \brief Placeholder for the module's entry function. */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; /*! \brief global function to set device */ constexpr const char* tvm_set_device = "__tvm_set_device"; /*! \brief Auxiliary counter to global barrier. */ @@ -300,24 +57,6 @@ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier"; } // namespace symbol -// implementations of inline functions. - -inline void Module::Import(Module other) { return (*this)->Import(other); } - -inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } - -inline const ModuleNode* Module::operator->() const { - return static_cast(get()); -} - -inline std::ostream& operator<<(std::ostream& out, const Module& module) { - out << "Module(type_key= "; - out << module->type_key(); - out << ")"; - - return out; -} - namespace details { template @@ -366,12 +105,14 @@ struct ModuleVTableEntryHelper { } // namespace runtime } // namespace tvm -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* type_key() const final { return TypeKey; } \ - ffi::Function GetFunction(const String& _name, const ObjectPtr& _self) override { \ - using SelfPtr = std::remove_cv_t; -#define TVM_MODULE_VTABLE_END() \ - return ffi::Function(nullptr); \ +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* kind() const final { return TypeKey; } \ + ::tvm::ffi::Optional<::tvm::ffi::Function> GetFunction(const String& _name) override { \ + using SelfPtr = std::remove_cv_t; \ + ::tvm::ffi::ObjectPtr<::tvm::ffi::Object> _self = \ + ::tvm::ffi::GetObjectPtr<::tvm::ffi::Object>(this); +#define TVM_MODULE_VTABLE_END() \ + return std::nullopt; \ } #define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ { \ diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index e512710ea396..9f25b6775c13 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -539,8 +539,8 @@ String ShapeString(const std::vector& shape, DLDataType dtype); * and returns performance metrics as a `Map` where * values can be `CountNode`, `DurationNode`, `PercentNode`. */ -ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, - int warmup_iters, Array collectors); +ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, + int device_id, int warmup_iters, Array collectors); /*! * \brief Wrap a timer function to measure the time cost of a given packed function. diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index a84c902b6711..6dfc2b0c50be 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -85,10 +85,10 @@ struct VMFuncInfo { * The executable contains information (e.g. data in different memory regions) * to run in a virtual machine. */ -class VMExecutable : public runtime::ModuleNode { +class VMExecutable : public ffi::ModuleObj { public: /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; /*! * \brief Print the detailed statistics of the given code, i.e. number of @@ -121,25 +121,25 @@ class VMExecutable : public runtime::ModuleNode { String AsPython() const; /*! * \brief Write the VMExecutable to the binary stream in serialized form. - * \param stream The binary stream to save the executable to. + * \return The binary bytes that save the executable to. */ - void SaveToBinary(dmlc::Stream* stream) final; + ffi::Bytes SaveToBytes() const final; /*! * \brief Load VMExecutable from the binary stream in serialized form. - * \param stream The binary stream that load the executable from. + * \param bytes The binary bytes that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ - static Module LoadFromBinary(void* stream); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes); /*! * \brief Write the VMExecutable to the provided path as a file containing its serialized content. * \param file_name The name of the file to write the serialized data to. * \param format The target format of the saved file. */ - void SaveToFile(const String& file_name, const String& format) final; + void WriteToFile(const String& file_name, const String& format) const final; /*! \brief Create a Relax virtual machine and load `this` as the executable. */ - Module VMLoadExecutable() const; + ffi::Module VMLoadExecutable() const; /*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */ - Module VMProfilerLoadExecutable() const; + ffi::Module VMProfilerLoadExecutable() const; /*! \brief Check if the VMExecutable contains a specific function. */ bool HasFunction(const String& name) const; /*! @@ -147,7 +147,7 @@ class VMExecutable : public runtime::ModuleNode { * \param file_name The path of the file that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ - static Module LoadFromFile(const String& file_name); + static ffi::Module LoadFromFile(const String& file_name); /*! \brief The virtual machine's function table. */ std::vector func_table; @@ -176,22 +176,22 @@ class VMExecutable : public runtime::ModuleNode { * \brief Save the globals. * \param strm The input stream. */ - void SaveGlobalSection(dmlc::Stream* strm); + void SaveGlobalSection(dmlc::Stream* strm) const; /*! * \brief Save the constant pool. * \param strm The input stream. */ - void SaveConstantSection(dmlc::Stream* strm); + void SaveConstantSection(dmlc::Stream* strm) const; /*! * \brief Save the instructions. * \param strm The input stream. */ - void SaveCodeSection(dmlc::Stream* strm); + void SaveCodeSection(dmlc::Stream* strm) const; /*! * \brief Save the packed functions. * \param strm The input stream. */ - void SavePackedFuncNames(dmlc::Stream* strm); + void SavePackedFuncNames(dmlc::Stream* strm) const; /*! * \brief Load the globals. * \param strm The input stream. diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index ed74ba7b7b2a..3a0b7418b946 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -27,6 +27,8 @@ #define TVM_VM_ENABLE_PROFILER 1 #endif +#include + #include #include #include @@ -128,7 +130,7 @@ class VMExtension : public ObjectRef { * multiple threads, or serialize them to disk or over the * wire. */ -class VirtualMachine : public runtime::ModuleNode { +class VirtualMachine : public ffi::ModuleObj { public: /*! * \brief Initialize the virtual machine for a set of devices. diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 54f09a081b93..d92ef674f12e 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -45,7 +45,7 @@ using ffi::PackedArgs; * \param target The target to be built. * \return The result runtime::Module. */ -runtime::Module Build(IRModule mod, Target target); +ffi::Module Build(IRModule mod, Target target); /*! * \brief Serialize runtime module including its submodules @@ -53,14 +53,14 @@ runtime::Module Build(IRModule mod, Target target); * \param export_dso By default, include the info of DSOExportable modules. If disabled, an error * will be raised when encountering DSO modules. */ -std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true); +std::string SerializeModuleToBytes(const ffi::Module& mod, bool export_dso = true); /*! * \brief Deserialize runtime module including its submodules * \param blob byte stream, which are generated by `SerializeModuleToBytes`. * \return runtime::Module runtime module constructed from the given stream */ -runtime::Module DeserializeModuleFromBytes(std::string blob); +ffi::Module DeserializeModuleFromBytes(std::string blob); /*! * \brief Pack imported device library to a C file. @@ -73,7 +73,7 @@ runtime::Module DeserializeModuleFromBytes(std::string blob); * \param c_symbol_prefix Optional symbol prefix of the blob symbol. * \return cstr The C string representation of the file. */ -std::string PackImportsToC(const runtime::Module& m, bool system_lib, +std::string PackImportsToC(const ffi::Module& m, bool system_lib, const std::string& c_symbol_prefix = ""); /*! @@ -89,9 +89,9 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib, * * \return runtime::Module The generated LLVM module. */ -runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib, - const std::string& target_triple, - const std::string& c_symbol_prefix = ""); +ffi::Module PackImportsToLLVM(const ffi::Module& m, bool system_lib, + const std::string& target_triple, + const std::string& c_symbol_prefix = ""); } // namespace codegen } // namespace tvm diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 9fa65054f91f..46a74346760e 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -35,7 +35,7 @@ protected Map initialValue() { private static Function getApi(String name) { Function func = apiFuncs.get().get(name); if (func == null) { - func = Function.getFunction("runtime." + name); + func = Function.getFunction(name); apiFuncs.get().put(name, func); } return func; @@ -75,7 +75,7 @@ public Function entryFunc() { * @return The result function. */ public Function getFunction(String name, boolean queryImports) { - TVMValue ret = getApi("ModuleGetFunction") + TVMValue ret = getApi("ffi.ModuleGetFunction") .pushArg(this).pushArg(name).pushArg(queryImports ? 1 : 0).invoke(); return ret.asFunction(); } @@ -89,7 +89,7 @@ public Function getFunction(String name) { * @param module The other module. */ public void importModule(Module module) { - getApi("ModuleImport") + getApi("ffi.ModuleImportModule") .pushArg(this).pushArg(module).invoke(); } @@ -98,7 +98,7 @@ public void importModule(Module module) { * @return type key of the module. */ public String typeKey() { - return getApi("ModuleGetTypeKey").pushArg(this).invoke().asString(); + return getApi("ffi.ModuleGetTypeKind").pushArg(this).invoke().asString(); } /** @@ -109,7 +109,7 @@ public String typeKey() { * @return The loaded module */ public static Module load(String path, String fmt) { - TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); + TVMValue ret = getApi("ffi.ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); return ret.asModule(); } @@ -125,7 +125,7 @@ public static Module load(String path) { * @return Whether runtime is enabled. */ public static boolean enabled(String target) { - TVMValue ret = getApi("RuntimeEnabled").pushArg(target).invoke(); + TVMValue ret = getApi("runtime.RuntimeEnabled").pushArg(target).invoke(); return ret.asLong() != 0; } } diff --git a/python/tvm/contrib/hexagon/tools.py b/python/tvm/contrib/hexagon/tools.py index 5ee89713d9a5..f7f22db721ce 100644 --- a/python/tvm/contrib/hexagon/tools.py +++ b/python/tvm/contrib/hexagon/tools.py @@ -404,7 +404,7 @@ def pack_imports( def export_module(module, out_dir, binary_name="test_binary.so"): """Export Hexagon shared object to a file.""" binary_path = pathlib.Path(out_dir) / binary_name - module.save(str(binary_path)) + module.write_to_file(str(binary_path)) return binary_path diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index a38b31c9bb00..c7be2a7ba6f6 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1621,7 +1621,17 @@ def batch_norm( The computed result. """ return _ffi_api.batch_norm( # type: ignore - data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum, training + data, + gamma, + beta, + moving_mean, + moving_var, + axis, + epsilon, + center, + scale, + momentum, + training, ) diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index f31927e2f1f9..f6db61af61d2 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -153,7 +153,7 @@ def _vmlink( tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib) lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline) for ext_mod in ext_libs: - if ext_mod.is_device_module: + if ext_mod.is_device_module(): tir_ext_libs.append(ext_mod) else: relax_ext_libs.append(ext_mod) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index ea78b0d7d418..0bb4e8cb7d29 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -55,7 +55,7 @@ def system_lib(self): -------- tvm.runtime.system_lib """ - return self.get_function("runtime.SystemLib")() + return self.get_function("ffi.SystemLib")() def get_function(self, name): """Get function from the session. @@ -380,7 +380,12 @@ def text_summary(self): return res def request( - self, key, priority=1, session_timeout=0, max_retry=5, session_constructor_args=None + self, + key, + priority=1, + session_timeout=0, + max_retry=5, + session_constructor_args=None, ): """Request a new connection from the tracker. @@ -474,7 +479,12 @@ def request_and_run(self, key, func, priority=1, session_timeout=0, max_retry=2) def connect( - url, port, key="", session_timeout=0, session_constructor_args=None, enable_logging=False + url, + port, + key="", + session_timeout=0, + session_constructor_args=None, + enable_logging=False, ): """Connect to RPC Server diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index bd0d3d8ed869..49449a451a12 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -94,7 +94,7 @@ def __init__(self, dref: DRef, session: "Session") -> None: self.session = session def __getitem__(self, name: str) -> DPackedFunc: - func = self.session._get_cached_method("runtime.ModuleGetFunction") + func = self.session._get_cached_method("ffi.ModuleGetFunction") return DPackedFunc(func(self, name, False), self.session) @@ -328,7 +328,10 @@ def init_ccl(self, ccl: str, *device_ids): self._clear_ipc_memory_pool() def broadcast( - self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + self, + src: Union[np.ndarray, NDArray], + dst: Optional[DRef] = None, + in_group: bool = True, ) -> DRef: """Broadcast an array to all workers @@ -383,7 +386,10 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> func(src, in_group, dst) def scatter( - self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + self, + src: Union[np.ndarray, NDArray], + dst: Optional[DRef] = None, + in_group: bool = True, ) -> DRef: """Scatter an array across all workers @@ -540,7 +546,10 @@ class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" def __init__( - self, num_workers: int, num_groups: int = 1, entrypoint: str = "tvm.exec.disco_worker" + self, + num_workers: int, + num_groups: int = 1, + entrypoint: str = "tvm.exec.disco_worker", ) -> None: self.__init_handle_by_constructor__( _ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member @@ -585,7 +594,12 @@ class SocketSession(Session): """A Disco session backed by socket-based multi-node communication.""" def __init__( - self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int + self, + num_nodes: int, + num_workers_per_node: int, + num_groups: int, + host: str, + port: int, ) -> None: self.__init_handle_by_constructor__( _ffi_api.SocketSession, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py index b6e13a65a9f2..51f0a772e403 100644 --- a/python/tvm/runtime/executable.py +++ b/python/tvm/runtime/executable.py @@ -91,7 +91,7 @@ def jit( # TODO(tvm-team): Update runtime.Module interface # to query these properties as bitmask. def _not_runnable(x): - return x.type_key in ("c", "static_library") + return x.kind in ("c", "static_library") # pylint:disable = protected-access not_runnable_list = self.mod._collect_from_import_tree(_not_runnable) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index e645d3a2b6ce..3925c24365d5 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -28,6 +28,7 @@ from tvm.libinfo import find_include_path from . import _ffi_api +from ..ffi import _ffi_api as _mod_ffi_api class BenchmarkResult: @@ -94,10 +95,10 @@ class ModulePropertyMask(object): BINARY_SERIALIZABLE = 0b001 RUNNABLE = 0b010 - DSO_EXPORTABLE = 0b100 + COMPILATION_EXPORTABLE = 0b100 -@tvm.ffi.register_object("runtime.Module") +@tvm.ffi.register_object("ffi.Module") class Module(tvm.ffi.Object): """Runtime Module.""" @@ -121,6 +122,22 @@ def entry_func(self): self._entry = self.get_function("__tvm_ffi_main__") return self._entry + @property + def kind(self): + """Get type key of the module.""" + return _mod_ffi_api.ModuleGetKind(self) + + @property + def imports(self): + """Get imported modules + + Returns + ---------- + modules : list of Module + The module + """ + return self.imports_ + def implements_function(self, name, query_imports=False): """Returns True if the module has a definition for the global function with name. Note that has_function(name) does not imply get_function(name) is non-null since the module @@ -141,7 +158,7 @@ def implements_function(self, name, query_imports=False): b : Bool True if module (or one of its imports) has a definition for name. """ - return _ffi_api.ModuleImplementsFunction(self, name, query_imports) + return _mod_ffi_api.ModuleImplementsFunction(self, name, query_imports) def get_function(self, name, query_imports=False): """Get function from the module. @@ -159,7 +176,7 @@ def get_function(self, name, query_imports=False): f : tvm.runtime.PackedFunc The result function. """ - func = _ffi_api.ModuleGetFunction(self, name, query_imports) + func = _mod_ffi_api.ModuleGetFunction(self, name, query_imports) if func is None: raise AttributeError(f"Module has no function '{name}'") return func @@ -172,7 +189,7 @@ def import_module(self, module): module : tvm.runtime.Module The other module. """ - _ffi_api.ModuleImport(self, module) + _mod_ffi_api.ModuleImportModule(self, module) def __getitem__(self, name): if not isinstance(name, str): @@ -185,17 +202,7 @@ def __call__(self, *args): # pylint: disable=not-callable return self.entry_func(*args) - @property - def type_key(self): - """Get type key of the module.""" - return _ffi_api.ModuleGetTypeKey(self) - - @property - def format(self): - """Get the format of the module.""" - return _ffi_api.ModuleGetFormat(self) - - def get_source(self, fmt=""): + def inspect_source(self, fmt=""): """Get source code from module, if available. Parameters @@ -208,19 +215,11 @@ def get_source(self, fmt=""): source : str The result source code. """ - return _ffi_api.ModuleGetSource(self, fmt) + return _mod_ffi_api.ModuleInspectSource(self, fmt) - @property - def imported_modules(self): - """Get imported modules - - Returns - ---------- - modules : list of Module - The module - """ - nmod = _ffi_api.ModuleImportsSize(self) - return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)] + def get_write_formats(self): + """Get the format of the module.""" + return _mod_ffi_api.ModuleGetWriteFormats(self) def get_property_mask(self): """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. @@ -230,9 +229,8 @@ def get_property_mask(self): mask : int Bitmask of runtime module property """ - return _ffi_api.ModuleGetPropertyMask(self) + return _mod_ffi_api.ModuleGetPropertyMask(self) - @property def is_binary_serializable(self): """Returns true if module is 'binary serializable', ie can be serialzed into binary stream and loaded back to the runtime module. @@ -244,7 +242,6 @@ def is_binary_serializable(self): """ return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 - @property def is_runnable(self): """Returns true if module is 'runnable'. ie can be executed without any extra compilation/linking steps. @@ -256,31 +253,26 @@ def is_runnable(self): """ return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 - @property def is_device_module(self): - return self.type_key in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"] + return self.kind in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"] - @property - def is_dso_exportable(self): - """Returns true if module is 'DSO exportable', ie can be included in result of + def is_compilation_exportable(self): + """Returns true if module is 'compilation exportable', ie can be included in result of export_library by the external compiler directly. Returns ------- b : Bool - True if the module is DSO exportable. + True if the module is compilation exportable. """ - return (self.get_property_mask() & ModulePropertyMask.DSO_EXPORTABLE) != 0 + return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0 def clear_imports(self): """Remove all imports of the module.""" - _ffi_api.ModuleClearImports(self) + _mod_ffi_api.ModuleClearImports(self) - def save(self, file_name, fmt=""): - """Save the module to file. - - This do not save the dependent device modules. - See also export_shared + def write_to_file(self, file_name, fmt=""): + """Write the current module to file. Parameters ---------- @@ -293,7 +285,7 @@ def save(self, file_name, fmt=""): -------- runtime.Module.export_library : export the module to shared library. """ - _ffi_api.ModuleSaveToFile(self, file_name, fmt) + _mod_ffi_api.ModuleWriteToFile(self, file_name, fmt) def time_evaluator( self, @@ -414,19 +406,19 @@ def _collect_from_import_tree(self, filter_func): while stack: module = stack.pop() assert ( - module.is_dso_exportable or module.is_binary_serializable - ), f"Module {module.type_key} should be either dso exportable or binary serializable." + module.is_compilation_exportable() or module.is_binary_serializable() + ), f"Module {module.kind} should be either dso exportable or binary serializable." if filter_func(module): dso_modules.append(module) - for m in module.imported_modules: + for m in module.imports: if m not in visited: visited.add(m) stack.append(m) return dso_modules def _collect_dso_modules(self): - return self._collect_from_import_tree(lambda m: m.is_dso_exportable) + return self._collect_from_import_tree(lambda m: m.is_compilation_exportable()) def export_library( self, @@ -509,29 +501,24 @@ def export_library( system_lib_prefix = None llvm_target_string = None global_object_format = "o" + + def get_source_format_from_module(module): + for fmt in module.get_write_formats(): + if fmt in ["c", "cc", "cpp", "cu"]: + return fmt + raise ValueError(f"Module {module.kind} does not exporting to c, cc, cpp or cu.") + for index, module in enumerate(modules): if fcompile is not None and hasattr(fcompile, "object_format"): - if module.type_key == "c": - assert module.format in [ - "c", - "cc", - "cpp", - "cu", - ], "The module.format needs to be either c, cc, cpp or cu." - object_format = module.format + if module.kind == "c": + object_format = get_source_format_from_module(module) has_c_module = True else: global_object_format = object_format = fcompile.object_format else: - if module.type_key == "c": - if len(module.format) > 0: - assert module.format in [ - "c", - "cc", - "cpp", - "cu", - ], "The module.format needs to be either c, cc, cpp, or cu." - object_format = module.format + if module.kind == "c": + if len(module.get_write_formats()) > 0: + object_format = get_source_format_from_module(module) else: object_format = "c" if "cc" in kwargs: @@ -539,13 +526,13 @@ def export_library( object_format = "cu" has_c_module = True else: - assert module.is_dso_exportable + assert module.is_compilation_exportable() global_object_format = object_format = "o" path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}") - module.save(path_obj) + module.write_to_file(path_obj) files.append(path_obj) - if module.type_key == "llvm": + if module.kind == "llvm": is_system_lib = module.get_function("__tvm_is_system_module")() llvm_target_string = module.get_function("_get_target_string")() system_lib_prefix = module.get_function("__tvm_get_system_lib_prefix")() @@ -566,7 +553,7 @@ def export_library( if getattr(fcompile, "need_system_lib", False) and not is_system_lib: raise ValueError(f"{str(fcompile)} need --system-lib option") - if self.imported_modules: + if self.imports: pack_lib_prefix = system_lib_prefix if system_lib_prefix else "" if fpack_imports is not None: @@ -579,7 +566,7 @@ def export_library( m = _ffi_api.ModulePackImportsToLLVM( self, is_system_lib, llvm_target_string, pack_lib_prefix ) - m.save(path_obj) + m.write_to_file(path_obj) files.append(path_obj) else: path_cc = os.path.join(workspace_dir, f"{pack_lib_prefix}devc.c") @@ -625,10 +612,10 @@ def system_lib(symbol_prefix=""): module : runtime.Module The system-wide library module. """ - return _ffi_api.SystemLib(symbol_prefix) + return _mod_ffi_api.SystemLib(symbol_prefix) -def load_module(path, fmt=""): +def load_module(path): """Load module from file. Parameters @@ -636,10 +623,6 @@ def load_module(path, fmt=""): path : str The path to the module file. - fmt : str, optional - The format of the file, if not specified - it will be inferred from suffix of the file. - Returns ------- module : runtime.Module @@ -673,7 +656,7 @@ def load_module(path, fmt=""): _cc.create_shared(path + ".so", files) path += ".so" # Redirect to the load API - return _ffi_api.ModuleLoadFromFile(path, fmt) + return _mod_ffi_api.ModuleLoadFromFile(path) def load_static_library(path, func_names): diff --git a/python/tvm/testing/usmp.py b/python/tvm/testing/usmp.py deleted file mode 100644 index c35ac255c3b1..000000000000 --- a/python/tvm/testing/usmp.py +++ /dev/null @@ -1,39 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" This file contains USMP tests harnesses.""" - -import tvm - - -def is_tvm_backendallocworkspace_calls(mod: tvm.runtime.module) -> bool: - """TVMBackendAllocWorkspace call check. - - This checker checks whether any c-source produced has TVMBackendAllocWorkspace calls. - If USMP is invoked, none of them should have TVMBAW calls - """ - dso_modules = mod._collect_dso_modules() - for dso_mod in dso_modules: - if dso_mod.type_key not in ["c", "llvm"]: - assert ( - False - ), 'Current AoT codegen flow should only produce type "c" or "llvm" runtime modules' - - source = dso_mod.get_source() - if source.count("TVMBackendAllocWorkspace") != 0: - return True - - return False diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 684abbe38c17..7acd0f215502 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -599,10 +599,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array MSCTensorRTCompiler(Array functions, - Map target_option, - Map constant_names) { - Array compiled_functions; +Array MSCTensorRTCompiler(Array functions, + Map target_option, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; const auto& name_opt = func->GetAttr(msc_attr::kUnique); @@ -615,9 +615,9 @@ Array MSCTensorRTCompiler(Array functions, serializer.serialize(func); std::string graph_json = serializer.GetJSON(); const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.msc_tensorrt_runtime_create"); - VLOG(1) << "Creating msc_tensorrt runtime::Module for '" << func_name << "'"; + VLOG(1) << "Creating msc_tensorrt ffi::Module for '" << func_name << "'"; compiled_functions.push_back( - pf(func_name, graph_json, serializer.GetConstantNames()).cast()); + pf(func_name, graph_json, serializer.GetConstantNames()).cast()); } return compiled_functions; } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 7b0846051609..41a22e4d39d8 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -21,6 +21,7 @@ */ #include #include +#include #include #include #include @@ -46,16 +47,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](const Any& object, bool map_free_vars) -> int64_t { return ffi::StructuralHash::Hash(object, map_free_vars); }); - refl::TypeAttrDef() + refl::TypeAttrDef() .def("__data_to_json__", - [](const runtime::ModuleNode* node) { - std::string bytes = codegen::SerializeModuleToBytes(GetRef(node), + [](const ffi::ModuleObj* node) { + std::string bytes = codegen::SerializeModuleToBytes(GetRef(node), /*export_dso*/ false); return ffi::Base64Encode(ffi::Bytes(bytes)); }) .def("__data_from_json__", [](const String& base64_bytes) { Bytes bytes = ffi::Base64Decode(base64_bytes); - runtime::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string()); + ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string()); return rtmod; }); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 146f7b932f9c..b25bfbdb22a7 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -311,9 +311,9 @@ void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) * \param functions The extern functions to be compiled via OpenCLML * \return Runtime modules. */ -Array OpenCLMLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array OpenCLMLCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "OpenCLML partition:" << std::endl << func; OpenCLMLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -322,8 +322,8 @@ Array OpenCLMLCompiler(Array functions, Map()); + VLOG(1) << "Creating clml ffi::Module for '" << func_name << "'"; + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; } diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 3f132b024a1b..0cd0150970e6 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -109,9 +109,9 @@ class CublasJSONSerializer : public JSONSerializer { Map bindings_; }; -Array CublasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array CublasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -120,7 +120,7 @@ Array CublasCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index b529c6f79692..a0201ccfda77 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -133,9 +133,9 @@ class cuDNNJSONSerializer : public JSONSerializer { Map bindings_; }; -Array cuDNNCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array cuDNNCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { cuDNNJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -144,7 +144,7 @@ Array cuDNNCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index b6307af0237b..29ad2de412d8 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -55,7 +55,7 @@ std::string EmitSignature(const std::vector& out, const std::string& fun return code_stream_.str(); } -runtime::Module Finalize(const std::string& code, const Array& func_names) { +ffi::Module Finalize(const std::string& code, const Array& func_names) { ICHECK(!func_names.empty()) << "Should only create CUTLASS CSourceModule if there is at least one CUTLASS partition"; @@ -72,7 +72,7 @@ runtime::Module Finalize(const std::string& code, const Array& func_name VLOG(1) << "Generated CUTLASS code:" << std::endl << code; return pf(default_headers.str() + code, "cu", func_names, /*const_vars=*/Array()) - .cast(); + .cast(); } class CodegenResultNode : public Object { @@ -337,8 +337,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, class CutlassModuleCodegen { public: - runtime::Module CreateCSourceModule(Array functions, - const Map& options) { + ffi::Module CreateCSourceModule(Array functions, const Map& options) { std::string headers = ""; std::string code = ""; for (const auto& f : functions) { @@ -373,8 +372,8 @@ class CutlassModuleCodegen { Array func_names_; }; -Array CUTLASSCompiler(Array functions, Map options, - Map /*unused*/) { +Array CUTLASSCompiler(Array functions, Map options, + Map /*unused*/) { const auto tune_func = tvm::ffi::Function::GetGlobal("contrib.cutlass.tune_relax_function"); ICHECK(tune_func.has_value()) << "The packed function contrib.cutlass.tune_relax_function not found, " @@ -386,7 +385,7 @@ Array CUTLASSCompiler(Array functions, Map(); + ffi::Module cutlass_mod = (*pf)(source_mod, options).cast(); return {cutlass_mod}; } diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc index 83cbdd8e2bbc..efa4e1b685c7 100644 --- a/src/relax/backend/contrib/dnnl/codegen.cc +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -81,9 +81,9 @@ class DNNLJSONSerializer : public JSONSerializer { Map bindings_; }; -Array DNNLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array DNNLCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -92,7 +92,7 @@ Array DNNLCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index 761221c88bac..e1104ac3d6c7 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -86,9 +86,9 @@ class HipblasJSONSerializer : public JSONSerializer { Map bindings_; }; -Array HipblasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array HipblasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { HipblasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -97,7 +97,7 @@ Array HipblasCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index c62523f5392d..f045e5b9c2c0 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -247,11 +247,11 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { ExprVisitor::VisitExpr_(call_node); } -Array NNAPICompiler(Array functions, Map /*unused*/, - Map constant_names) { +Array NNAPICompiler(Array functions, Map /*unused*/, + Map constant_names) { VLOG(1) << "NNAPI Compiler"; - Array compiled_functions; + Array compiled_functions; for (const auto& func : functions) { NNAPIJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); @@ -260,7 +260,7 @@ Array NNAPICompiler(Array functions, Map(); + tvm::ffi::Module mod = result.cast(); compiled_functions.push_back(mod); } diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 02483abdc3dc..6dd8216469c2 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -225,9 +225,9 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array TensorRTCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array TensorRTCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "TensorRT partition:" << std::endl << func; TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -237,8 +237,8 @@ Array TensorRTCompiler(Array functions, Map()); + VLOG(1) << "Creating tensorrt ffi::Module for '" << func_name << "'"; + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; } diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 27165db34350..1f9e8c0378a7 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -310,10 +310,10 @@ class CodeGenVM : public ExprFunctor { String sym = op->global_symbol; String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); String code = opt_code.value(); - Module c_source_module = + ffi::Module c_source_module = codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym}, /*const_vars=*/{}); - builder_->exec()->Import(c_source_module); + builder_->exec()->ImportModule(c_source_module); } builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); return builder_->GetFunction(op->global_symbol); @@ -441,17 +441,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \return The created module. */ void LinkModules(ObjectPtr exec, const Map& params, - const tvm::runtime::Module& lib, const Array& ext_libs) { + const tvm::ffi::Module& lib, const Array& ext_libs) { // query if we need const loader for ext_modules // Wrap all submodules in the initialization wrapper. std::unordered_map> const_vars_by_symbol; - for (tvm::runtime::Module mod : ext_libs) { - auto pf_sym = mod.GetFunction("get_symbol"); - auto pf_var = mod.GetFunction("get_const_vars"); + for (tvm::ffi::Module mod : ext_libs) { + auto pf_sym = mod->GetFunction("get_symbol"); + auto pf_var = mod->GetFunction("get_const_vars"); std::vector symbol_const_vars; - if (pf_sym != nullptr && pf_var != nullptr) { - String symbol = pf_sym().cast(); - Array variables = pf_var().cast>(); + if (pf_sym.has_value() && pf_var.has_value()) { + String symbol = (*pf_sym)().cast(); + Array variables = (*pf_var)().cast>(); for (size_t i = 0; i < variables.size(); i++) { symbol_const_vars.push_back(variables[i].operator std::string()); } @@ -465,18 +465,18 @@ void LinkModules(ObjectPtr exec, const MapImportModule(lib); for (const auto& it : ext_libs) { - const_loader_mod.Import(it); + const_loader_mod->ImportModule(it); } - exec->Import(const_loader_mod); + exec->ImportModule(const_loader_mod); } else { // directly import the ext_modules as we don't need const loader - exec->Import(lib); + exec->ImportModule(lib); for (const auto& it : ext_libs) { - exec->Import(it); + exec->ImportModule(it); } } } @@ -484,14 +484,14 @@ void LinkModules(ObjectPtr exec, const Map lib, Array ext_libs, - Map params) { +ffi::Module VMLink(ExecBuilder builder, Target target, Optional lib, + Array ext_libs, Map params) { ObjectPtr executable = builder->Get(); if (!lib.defined()) { - lib = codegen::CSourceModuleCreate(";", "", Array{}); + lib = codegen::CSourceModuleCreate(";", "c", Array{}); } LinkModules(executable, params, lib.value(), ext_libs); - return Module(executable); + return ffi::Module(executable); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index bb43b4ef033d..8e229c4fe641 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -374,7 +374,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ExecBuilder builder, String value) { return builder->GetFunction(value).data(); }) .def("relax.ExecBuilderGet", [](ExecBuilder builder) { ObjectPtr p_exec = builder->Get(); - return runtime::Module(p_exec); + return ffi::Module(p_exec); }); }); diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index fc7c6b26df10..c1aee73cc258 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -119,8 +119,8 @@ class ConstantFolder : public ExprMutator { // TODO(Hongyi): further check and narrow the scope of foldable function const auto pf = tvm::ffi::Function::GetGlobalRequired("tir.build"); func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function")); - runtime::Module rt_module = pf(func, eval_cpu_target).cast(); - build_func = rt_module.GetFunction("tir_function"); + ffi::Module rt_module = pf(func, eval_cpu_target).cast(); + build_func = rt_module->GetFunction("tir_function"); } catch (const tvm::Error& err) { // build failure may happen in which case we skip DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what(); diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index a4a109cb0e22..0cc0a070aac5 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -80,7 +80,7 @@ class CodeGenRunner : ExprMutator { auto out_mod = builder_->GetContextIRModule(); if (ext_mods.size()) { - if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { + if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { auto old_ext_mods = opt_old_ext_mods.value(); ext_mods.insert(ext_mods.begin(), old_ext_mods.begin(), old_ext_mods.end()); } @@ -168,7 +168,7 @@ class CodeGenRunner : ExprMutator { } private: - Array InvokeCodegen(IRModule mod, Map target_options) { + Array InvokeCodegen(IRModule mod, Map target_options) { std::unordered_map> target_functions; for (const auto& entry : mod->functions) { @@ -186,7 +186,7 @@ class CodeGenRunner : ExprMutator { }); } - Array ext_mods; + Array ext_mods; for (const auto& [target, functions] : target_functions) { OptionMap options = target_options.Get(target).value_or(OptionMap()); @@ -196,8 +196,8 @@ class CodeGenRunner : ExprMutator { const auto codegen = tvm::ffi::Function::GetGlobal(codegen_name); ICHECK(codegen.has_value()) << "Codegen is not found: " << codegen_name << "\n"; - Array compiled_functions = - (*codegen)(functions, options, constant_names).cast>(); + Array compiled_functions = + (*codegen)(functions, options, constant_names).cast>(); ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end()); } diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 27daf7cc3e01..2c02fb556c73 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -27,12 +27,13 @@ * code and constants significantly reduces the efforts for handling external * codegen and runtimes. */ +#include #include #include +#include #include #include #include -#include #include #include @@ -44,9 +45,9 @@ namespace runtime { * \brief The const-loader module is designed to manage initialization of the * imported submodules for the C++ runtime. */ -class ConstLoaderModuleNode : public ModuleNode { +class ConstLoaderModuleObj : public ffi::ModuleObj { public: - ConstLoaderModuleNode( + ConstLoaderModuleObj( const std::unordered_map& const_var_ndarray, const std::unordered_map>& const_vars_by_symbol) : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { @@ -66,7 +67,7 @@ class ConstLoaderModuleNode : public ModuleNode { } } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Optional GetFunction(const String& name) final { VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")"; // Initialize and memoize the module. // Usually, we have some warmup runs. The module initialization should be @@ -75,9 +76,10 @@ class ConstLoaderModuleNode : public ModuleNode { this->InitSubModule(name); initialized_[name] = true; } + ObjectRef _self = ffi::GetRef(this); if (name == "get_const_var_ndarray") { - return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + return ffi::Function([_self, this](ffi::PackedArgs args, ffi::Any* rv) { Map ret_map; for (const auto& kv : const_var_ndarray_) { ret_map.Set(kv.first, kv.second); @@ -89,18 +91,18 @@ class ConstLoaderModuleNode : public ModuleNode { // Run the module. // Normally we would only have a limited number of submodules. The runtime // symobl lookup overhead should be minimal. - ICHECK(!this->imports().empty()); - for (Module it : this->imports()) { - ffi::Function pf = it.GetFunction(name); - if (pf != nullptr) return pf; + ICHECK(!this->imports_.empty()); + for (const Any& it : this->imports_) { + ffi::Optional pf = it.cast()->GetFunction(name); + if (pf.has_value()) return pf.value(); } - return ffi::Function(nullptr); + return std::nullopt; } - const char* type_key() const final { return "const_loader"; } + const char* kind() const final { return "const_loader"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; /*! * \brief Get the list of constants that is required by the given module. @@ -134,15 +136,14 @@ class ConstLoaderModuleNode : public ModuleNode { * found module accordingly by passing the needed constants into it. */ void InitSubModule(const std::string& symbol) { - ffi::Function init(nullptr); - for (Module it : this->imports()) { + for (const Any& it : this->imports_) { // Get the initialization function from the imported modules. std::string init_name = "__init_" + symbol; - init = it.GetFunction(init_name, false); - if (init != nullptr) { + Optional init = it.cast()->GetFunction(init_name, false); + if (init.has_value()) { auto md = GetRequiredConstants(symbol); // Initialize the module with constants. - int ret = init(md).cast(); + int ret = (*init)(md).cast(); // Report the error if initialization is failed. ICHECK_EQ(ret, 0); break; @@ -150,7 +151,11 @@ class ConstLoaderModuleNode : public ModuleNode { } } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string bytes_buffer; + dmlc::MemoryStringStream ms(&bytes_buffer); + dmlc::Stream* stream = &ms; + std::vector variables; std::vector const_var_ndarray; for (const auto& it : const_var_ndarray_) { @@ -182,10 +187,12 @@ class ConstLoaderModuleNode : public ModuleNode { for (uint64_t i = 0; i < sz; i++) { stream->Write(const_vars[i]); } + return ffi::Bytes(bytes_buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; // Load the variables. std::vector variables; @@ -225,8 +232,8 @@ class ConstLoaderModuleNode : public ModuleNode { const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(const_var_ndarray, const_vars_by_symbol); - return Module(n); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); + return ffi::Module(n); } private: @@ -241,17 +248,17 @@ class ConstLoaderModuleNode : public ModuleNode { std::unordered_map> const_vars_by_symbol_; }; -Module ConstLoaderModuleCreate( +ffi::Module ConstLoaderModuleCreate( const std::unordered_map& const_var_ndarray, const std::unordered_map>& const_vars_by_symbol) { - auto n = make_object(const_var_ndarray, const_vars_by_symbol); - return Module(n); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_const_loader", - ConstLoaderModuleNode::LoadFromBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.const_loader", + ConstLoaderModuleObj::LoadFromBytes); }); } // namespace runtime diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h index eb548dfcf370..c093818763d8 100644 --- a/src/runtime/const_loader_module.h +++ b/src/runtime/const_loader_module.h @@ -43,7 +43,7 @@ namespace runtime { * * \return The created ConstLoaderModule. */ -Module ConstLoaderModuleCreate( +ffi::Module ConstLoaderModuleCreate( const std::unordered_map& const_var_ndarray, const std::unordered_map>& const_vars_by_symbol); diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 183d4c0a5b27..3de9e85a57c5 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -69,7 +69,7 @@ class ACLRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const override { return "arm_compute_lib"; } + const char* kind() const override { return "arm_compute_lib"; } /*! * \brief Initialize runtime. Create ACL layer from JSON @@ -588,18 +588,18 @@ class ACLRuntime : public JSONRuntimeBase { } #endif }; -runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.arm_compute_lib_runtime_create", ACLRuntimeCreate) - .def("runtime.module.loadbinary_arm_compute_lib", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.arm_compute_lib", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 0bd961524e0c..9080eeb9bb34 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -91,7 +91,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { const Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - const char* type_key() const override { return "bnns_json"; } + const char* kind() const override { return "bnns_json"; } void Init(const Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) @@ -557,17 +557,17 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::vector tensors_eid_; }; -runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.BNNSJSONRuntimeCreate", BNNSJSONRuntimeCreate) - .def("runtime.module.loadbinary_bnns_json", BNNSJSONRuntime::LoadFromBinary); + .def("ffi.Module.load_from_bytes.bnns_json", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 6b96cbb41bec..9d13e427b24a 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -193,7 +193,7 @@ class CLMLRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const override { return "clml"; } + const char* kind() const override { return "clml"; } /*! * \brief Initialize runtime. Create CLML layer from JSON @@ -1826,17 +1826,17 @@ class CLMLRuntime : public JSONRuntimeBase { std::string clml_symbol; }; -runtime::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.clml_runtime_create", CLMLRuntimeCreate) - .def("runtime.module.loadbinary_clml", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.clml", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index 5f5eec1d03ca..257b624bbf2b 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -29,6 +29,7 @@ #import #include +#include #include #include @@ -95,7 +96,7 @@ class CoreMLModel { * This runtime can be accessed in various language via * TVM runtime ffi::Function API. */ -class CoreMLRuntime : public ModuleNode { +class CoreMLRuntime : public ffi::ModuleObj { public: /*! * \brief Get member function to front-end. @@ -103,11 +104,11 @@ class CoreMLRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual Optional GetFunction(const String& name); /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -115,12 +116,12 @@ class CoreMLRuntime : public ModuleNode { * binary stream. * \param stream The binary stream to save to. */ - void SaveToBinary(dmlc::Stream* stream) final; + ffi::Bytes SaveToBytes() const final; /*! * \return The type key of the executor. */ - const char* type_key() const { return "coreml"; } + const char* kind() const { return "coreml"; } /*! * \brief Initialize the coreml runtime with coreml model and context. diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 6dfd7a67e5b4..8e0b2542b443 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -129,8 +129,7 @@ model_ = std::unique_ptr(new CoreMLModel(url)); } -ffi::Function CoreMLRuntime::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional CoreMLRuntime::GetFunction(const String& name) { // Return member functions during query. if (name == "invoke" || name == "run") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); }); @@ -183,14 +182,14 @@ *rv = out; }); } else { - return ffi::Function(); + return std::nullopt; } } -Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_path) { +ffi::Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_path) { auto exec = make_object(); exec->Init(symbol, model_path); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -200,7 +199,10 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p }); }); -void CoreMLRuntime::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes CoreMLRuntime::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; NSURL* url = model_->url_; NSFileWrapper* dirWrapper = [[[NSFileWrapper alloc] initWithURL:url options:0 error:nil] autorelease]; @@ -209,6 +211,7 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p stream->Write((uint64_t)[dirData length]); stream->Write([dirData bytes], [dirData length]); DLOG(INFO) << "Save " << symbol_ << " (" << [dirData length] << " bytes)"; + return ffi::Bytes(buffer); } /*! @@ -218,8 +221,9 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p * * \return The created CoreML module. */ -Module CoreMLRuntimeLoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module CoreMLRuntimeLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; NSString* tempBaseDir = NSTemporaryDirectory(); if (tempBaseDir == nil) tempBaseDir = @"/tmp"; @@ -249,12 +253,12 @@ Module CoreMLRuntimeLoadFromBinary(void* strm) { auto exec = make_object(); exec->Init(symbol, [model_path UTF8String]); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_coreml", CoreMLRuntimeLoadFromBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.coreml", CoreMLRuntimeLoadFromBytes); }); } // namespace runtime diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 9f2cfaa50698..11fa3b0c4d49 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -49,21 +49,22 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Optional GetFunction(const String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call cuBLAS on the inputs from ffi::PackedArgs. + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + return JSONRuntimeBase::GetFunction(name); } } - const char* type_key() const override { return "cublas_json"; } // May be overridden + const char* kind() const override { return "cublas_json"; } // May be overridden void Run(ffi::PackedArgs args) { auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); @@ -148,18 +149,18 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.CublasJSONRuntimeCreate", CublasJSONRuntimeCreate) - .def("runtime.module.loadbinary_cublas_json", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.cublas_json", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index d54fad9d99ab..fd4fa68c783c 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -69,7 +69,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { } } - const char* type_key() const override { return "cudnn_json"; } // May be overridden + const char* kind() const override { return "cudnn_json"; } // May be overridden void Run() override { for (const auto& f : op_execs_) { @@ -232,18 +232,18 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::vector> op_execs_; }; -runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.cuDNNJSONRuntimeCreate", cuDNNJSONRuntimeCreate) - .def("runtime.module.loadbinary_cudnn_json", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.cudnn_json", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 138f41cb7751..686a8048c7b5 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -58,7 +58,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e)); } - const char* type_key() const override { return "dnnl_json"; } + const char* kind() const override { return "dnnl_json"; } void Init(const Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) @@ -100,7 +100,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Override GetFunction to reimplement Run method */ - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Optional GetFunction(const String& name) override { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; @@ -111,7 +112,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Run(args); }); } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + return JSONRuntimeBase::GetFunction(name); } } @@ -922,17 +923,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::vector run_arg_eid_; }; -runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.DNNLJSONRuntimeCreate", DNNLJSONRuntimeCreate) - .def("runtime.module.loadbinary_dnnl_json", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.dnnl_json", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 515eac9489b6..a52da2318b71 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -63,10 +63,10 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, Device dev) { device_ = dev; } -Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { +ffi::Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { auto exec = make_object(); exec->Init(tflite_model_bytes, dev); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index fbac6f12fea9..5750b91ab4ca 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -47,21 +47,22 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Optional GetFunction(const String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since HipblasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call hipBLAS on the inputs from ffi::PackedArgs. + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + return JSONRuntimeBase::GetFunction(name); } } - const char* type_key() const override { return "hipblas_json"; } // May be overridden + const char* kind() const override { return "hipblas_json"; } // May be overridden void Run(ffi::PackedArgs args) { auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(); @@ -134,18 +135,18 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.HipblasJSONRuntimeCreate", HipblasJSONRuntimeCreate) - .def("runtime.module.loadbinary_hipblas_json", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.hipblas_json", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 025e85263ebc..d9e5af60f299 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -25,7 +25,7 @@ #ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ -#include +#include #include #include @@ -47,7 +47,7 @@ namespace json { * \brief A json runtime that executes the serialized JSON format. This runtime * can be extended by user defined runtime for execution. */ -class JSONRuntimeBase : public ModuleNode { +class JSONRuntimeBase : public ffi::ModuleObj { public: JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json, const Array const_names) @@ -55,13 +55,11 @@ class JSONRuntimeBase : public ModuleNode { LoadGraph(graph_json_); } - ~JSONRuntimeBase() override = default; - - const char* type_key() const override { return "json"; } // May be overridden + const char* kind() const override { return "json"; } // May be overridden /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const override { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! \brief Initialize a specific json runtime. */ @@ -95,7 +93,8 @@ class JSONRuntimeBase : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + Optional GetFunction(const String& name) override { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); @@ -148,11 +147,14 @@ class JSONRuntimeBase : public ModuleNode { *rv = 0; }); } else { - return ffi::Function(nullptr); + return std::nullopt; } } - void SaveToBinary(dmlc::Stream* stream) override { + ffi::Bytes SaveToBytes() const override { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; // Save the symbol stream->Write(symbol_name_); // Save the graph @@ -163,12 +165,14 @@ class JSONRuntimeBase : public ModuleNode { consts.push_back(it); } stream->Write(consts); + return ffi::Bytes(buffer); } template ::value>::type> - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string symbol; std::string graph_json; std::vector consts; @@ -181,7 +185,7 @@ class JSONRuntimeBase : public ModuleNode { const_names.push_back(it); } auto n = make_object(symbol, graph_json, const_names); - return Module(n); + return ffi::Module(n); } /*! @@ -190,7 +194,7 @@ class JSONRuntimeBase : public ModuleNode { * \param format the format to return. * \return A string of JSON. */ - String GetSource(const String& format = "json") override { return graph_json_; } + String InspectSource(const String& format) const override { return graph_json_; } protected: /*! diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index d531963ec822..bc1eb77ea18c 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -155,7 +155,7 @@ hardware and then runs the generated binary on the target hardware. * */ -class MarvellHardwareModuleNode : public ModuleNode { +class MarvellHardwareModuleNode : public ffi::ModuleObj { public: MarvellHardwareModuleNode(const std::string& symbol_name, const std::string& nodes_json, const std::string& bin_code, const int input_count, @@ -200,10 +200,10 @@ class MarvellHardwareModuleNode : public ModuleNode { } } - const char* type_key() const { return "mrvl_hw"; } + const char* kind() const { return "mrvl_hw"; } int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -212,7 +212,8 @@ class MarvellHardwareModuleNode : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) { + virtual Optional GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); @@ -240,10 +241,13 @@ class MarvellHardwareModuleNode : public ModuleNode { *rv = 0; }); } - return ffi::Function(nullptr); + return std::nullopt; } - virtual void SaveToBinary(dmlc::Stream* stream) { + virtual ffi::Bytes SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; // Save the symbol name and other data and serialize them to // binary format. stream->Write(symbol_name_); @@ -252,10 +256,12 @@ class MarvellHardwareModuleNode : public ModuleNode { stream->Write(num_inputs_); stream->Write(num_outputs_); stream->Write(run_arg.num_batches); + return ffi::Bytes(buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string symbol_name; std::string nodes_json; std::string bin_code; @@ -270,7 +276,7 @@ class MarvellHardwareModuleNode : public ModuleNode { ICHECK(stream->Read(&batch_size)) << "Loading batch_size failed"; auto n = make_object(symbol_name, nodes_json, bin_code, num_inputs, num_outputs, batch_size); - return Module(n); + return ffi::Module(n); } /*! @@ -279,7 +285,7 @@ class MarvellHardwareModuleNode : public ModuleNode { * \param format the format to return. * \return A string of JSON. */ - String GetSource(const String& format = "json") override { return nodes_json_; } + String InspectSource(const String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -463,12 +469,12 @@ class MarvellHardwareModuleNode : public ModuleNode { } }; -runtime::Module MarvellHardwareModuleRuntimeCreate(const String& symbol_name, - const String& nodes_json, const String& bin_code, - int num_input, int num_output, int batch_size) { +ffi::Module MarvellHardwareModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, + const String& bin_code, int num_input, + int num_output, int batch_size) { auto n = make_object(symbol_name, nodes_json, bin_code, num_input, num_output, batch_size); - return runtime::Module(n); + return ffi::Module(n); } bool MarvellHardwareModuleNode::initialized_model = false; @@ -481,7 +487,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.mrvl_hw_runtime_create", MarvellHardwareModuleRuntimeCreate) - .def("runtime.module.loadbinary_mrvl_hw", MarvellHardwareModuleNode::LoadFromBinary); + .def("ffi.Module.load_from_bytes.mrvl_hw", MarvellHardwareModuleNode::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index b9f9bc960c04..974ca4a69a1f 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -24,9 +24,9 @@ #include #include +#include #include #include -#include #include #include @@ -49,7 +49,7 @@ hardware and then runs the generated binary using the Marvell software simulator * \param bin_code The binary code generated by the Marvell compiler for the subgraph */ -class MarvellSimulatorModuleNode : public ModuleNode { +class MarvellSimulatorModuleNode : public ffi::ModuleObj { public: MarvellSimulatorModuleNode(const std::string& symbol_name, const std::string& nodes_json, const std::string& bin_code) @@ -57,11 +57,11 @@ class MarvellSimulatorModuleNode : public ModuleNode { set_num_inputs_outputs(); } - const char* type_key() const { return "mrvl_sim"; } + const char* kind() const { return "mrvl_sim"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -70,7 +70,8 @@ class MarvellSimulatorModuleNode : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) { + virtual Optional GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); @@ -83,19 +84,24 @@ class MarvellSimulatorModuleNode : public ModuleNode { *rv = 0; }); } - return ffi::Function(nullptr); + return std::nullopt; } - virtual void SaveToBinary(dmlc::Stream* stream) { + virtual ffi::Bytes SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; // Save the symbol name and other data and serialize them to // binary format. stream->Write(symbol_name_); stream->Write(nodes_json_); stream->Write(bin_code_); + return ffi::Bytes(buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string symbol_name; std::string nodes_json; std::string bin_code; @@ -106,7 +112,7 @@ class MarvellSimulatorModuleNode : public ModuleNode { << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed"; ICHECK(stream->Read(&bin_code)) << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; auto n = make_object(symbol_name, nodes_json, bin_code); - return Module(n); + return ffi::Module(n); } /*! @@ -115,7 +121,7 @@ class MarvellSimulatorModuleNode : public ModuleNode { * \param format the format to return. * \return A string of JSON. */ - String GetSource(const String& format = "json") override { return nodes_json_; } + String InspectSource(const String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -143,18 +149,17 @@ class MarvellSimulatorModuleNode : public ModuleNode { } }; -runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, - const String& nodes_json, - const String& bin_code) { +ffi::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, + const String& bin_code) { auto n = make_object(symbol_name, nodes_json, bin_code); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.mrvl_runtime_create", MarvellSimulatorModuleRuntimeCreate) - .def("runtime.module.loadbinary_mrvl_sim", MarvellSimulatorModuleNode::LoadFromBinary); + .def("ffi.Module.load_from_bytes.mrvl_sim", MarvellSimulatorModuleNode::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index fd65175d6f8e..e19c03d4fda5 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -65,7 +65,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - ~MSCTensorRTRuntime() override { + ~MSCTensorRTRuntime() { VLOG(1) << "Destroying MSC TensorRT runtime"; DestroyEngine(); } @@ -75,11 +75,11 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const final { return "msc_tensorrt"; } + const char* kind() const final { return "msc_tensorrt"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -343,18 +343,18 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { #endif }; -runtime::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.msc_tensorrt_runtime_create", MSCTensorRTRuntimeCreate) - .def("runtime.module.loadbinary_msc_tensorrt", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.msc_tensorrt", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index 52b4a4711837..71335f3ee287 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -235,17 +235,17 @@ class NNAPIRuntime : public JSONRuntimeBase { #endif // ifdef TVM_GRAPH_EXECUTOR_NNAPI }; -runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.nnapi_runtime_create", NNAPIRuntimeCreate) - .def("runtime.module.loadbinary_nnapi", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.nnapi", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index ba9725d9bb10..ff565444e2b5 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -96,11 +96,11 @@ class TensorRTRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const final { return "tensorrt"; } + const char* kind() const final { return "tensorrt"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -519,17 +519,17 @@ class TensorRTRuntime : public JSONRuntimeBase { bool use_fp16_; }; -runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.tensorrt_runtime_create", TensorRTRuntimeCreate) - .def("runtime.module.loadbinary_tensorrt", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.tensorrt", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 0fa3bc2fe64c..c35af35eae13 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -87,6 +87,7 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { return DataType::Float(16); default: LOG(FATAL) << "tflite data type not support yet: " << dtype; + TVM_FFI_UNREACHABLE(); } } @@ -151,8 +152,8 @@ NDArray TFLiteRuntime::GetOutput(int index) const { return ret; } -ffi::Function TFLiteRuntime::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Optional TFLiteRuntime::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Return member functions during query. if (name == "set_input") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -174,14 +175,14 @@ ffi::Function TFLiteRuntime::GetFunction(const String& name, this->SetNumThreads(num_threads); }); } else { - return ffi::Function(); + return std::nullopt; } } -Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { +ffi::Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { auto exec = make_object(); exec->Init(tflite_model_bytes, dev); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 5e8751a01281..396bd01104d5 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -27,8 +27,8 @@ #include #include +#include #include -#include #include #include @@ -46,7 +46,7 @@ namespace runtime { * This runtime can be accessed in various language via * TVM runtime ffi::Function API. */ -class TFLiteRuntime : public ModuleNode { +class TFLiteRuntime : public ffi::ModuleObj { public: /*! * \brief Get member function to front-end. @@ -54,15 +54,15 @@ class TFLiteRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual Optional GetFunction(const String& name); /*! * \return The type key of the executor. */ - const char* type_key() const { return "TFLiteRuntime"; } + const char* kind() const { return "TFLiteRuntime"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; }; + int GetPropertyMask() const final { return ffi::Module::kRunnable; }; /*! * \brief Invoke the internal tflite interpreter and run the whole model in diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 6b71df928d23..5a4e682da8da 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -46,7 +47,7 @@ namespace runtime { // cuModule is a per-GPU module // The runtime will contain a per-device module table // The modules will be lazily loaded -class CUDAModuleNode : public runtime::ModuleNode { +class CUDAModuleNode : public ffi::ModuleObj { public: explicit CUDAModuleNode(std::string data, std::string fmt, std::unordered_map fmap, @@ -64,16 +65,16 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { return "cuda"; } + const char* kind() const final { return "cuda"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -87,13 +88,17 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (format == fmt_) return data_; if (cuda_source_.length() != 0) { return cuda_source_; @@ -205,7 +210,7 @@ class CUDAWrappedFunc { << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n"; - std::string cuda = m_->GetSource(""); + std::string cuda = m_->InspectSource(""); if (cuda.length() != 0) { os << "// func_name=" << func_name_ << "\n" << "// CUDA Source\n" @@ -255,8 +260,8 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -ffi::Function CUDAModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional CUDAModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == symbol::tvm_prepare_global_barrier) { return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self)); @@ -269,15 +274,15 @@ ffi::Function CUDAModuleNode::GetFunction(const String& name, return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags); } -Module CUDAModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +ffi::Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { auto n = make_object(data, fmt, fmap, cuda_source); - return Module(n); + return ffi::Module(n); } // Load module from module. -Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -287,8 +292,9 @@ Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -Module CUDAModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string data; std::unordered_map fmap; std::string fmt; @@ -301,9 +307,9 @@ Module CUDAModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadfile_cubin", CUDAModuleLoadFile) - .def("runtime.module.loadfile_ptx", CUDAModuleLoadFile) - .def("runtime.module.loadbinary_cuda", CUDAModuleLoadBinary); + .def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile) + .def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile) + .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h index e65c5fe60811..b92dbe1cc683 100644 --- a/src/runtime/cuda/cuda_module.h +++ b/src/runtime/cuda/cuda_module.h @@ -47,9 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param cuda_source Optional, cuda source file */ -Module CUDAModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string cuda_source); +ffi::Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 947d8884a59c..ae85f9ce5384 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -22,6 +22,7 @@ * \brief Device specific implementations */ #include +#include #include #include #include @@ -235,10 +236,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ using namespace tvm::runtime; int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) { - TVM_FFI_SAFE_CALL_BEGIN(); - *func = const_cast( - static_cast(mod_node)->GetFuncFromEnv(func_name)->get()); - TVM_FFI_SAFE_CALL_END(); + return TVMFFIEnvLookupFromImports(mod_node, func_name, func); } void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 126a593e5173..b650b143e401 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -34,36 +34,39 @@ namespace runtime { class DSOLibraryCache { public: - Module Open(const std::string& library_path) { + ffi::Module Open(const std::string& library_path) { std::lock_guard lock(mutex_); - Module& lib = cache_[library_path]; - if (!lib.defined()) { - lib = Module::LoadFromFile(library_path, ""); + auto it = cache_.find(library_path); + if (it == cache_.end()) { + ffi::Module lib = ffi::Module::LoadFromFile(library_path); + cache_.emplace(library_path, lib); + return lib; } - return lib; + return it->second; } - std::unordered_map cache_; + std::unordered_map cache_; std::mutex mutex_; }; -Module LoadVMModule(std::string path, Optional device) { +ffi::Module LoadVMModule(std::string path, Optional device) { static DSOLibraryCache cache; - Module dso_mod = cache.Open(path); + ffi::Module dso_mod = cache.Open(path); Device dev = UseDefaultDeviceIfNone(device); - ffi::Function vm_load_executable = dso_mod.GetFunction("vm_load_executable"); - if (vm_load_executable == nullptr) { + Optional vm_load_executable = dso_mod->GetFunction("vm_load_executable"); + if (!vm_load_executable.has_value()) { // not built by RelaxVM, return the dso_mod directly return dso_mod; } - auto mod = vm_load_executable().cast(); - ffi::Function vm_initialization = mod.GetFunction("vm_initialization"); - CHECK(vm_initialization != nullptr) - << "ValueError: File `" << path - << "` is not built by RelaxVM, because `vm_initialization` does not exist"; - vm_initialization(static_cast(dev.device_type), static_cast(dev.device_id), - static_cast(AllocatorType::kPooled), static_cast(kDLCPU), 0, - static_cast(AllocatorType::kPooled)); + auto mod = (*vm_load_executable)().cast(); + Optional vm_initialization = mod->GetFunction("vm_initialization"); + if (!vm_initialization.has_value()) { + LOG(FATAL) << "ValueError: File `" << path + << "` is not built by RelaxVM, because `vm_initialization` does not exist"; + } + (*vm_initialization)(static_cast(dev.device_type), static_cast(dev.device_id), + static_cast(AllocatorType::kPooled), static_cast(kDLCPU), 0, + static_cast(AllocatorType::kPooled)); return mod; } diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index ec302661bd0e..97af8bc9d3de 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -117,7 +117,7 @@ class ShardLoaderObj : public Object { public: /*! \brief Create a shard loader. */ static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Module mod); + std::string shard_info, Optional mod); /*! \brief Load the i-th parameter */ NDArray Load(int weight_index) const; @@ -175,11 +175,10 @@ class ShardLoaderObj : public Object { }; ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Module mod) { - if (shard_info.empty() && mod.defined()) { - if (ffi::Function get_shard_info = mod->GetFunction("get_shard_info"); - get_shard_info != nullptr) { - shard_info = get_shard_info().cast(); + std::string shard_info, Optional mod) { + if (shard_info.empty() && mod.has_value()) { + if (auto get_shard_info = (*mod)->GetFunction("get_shard_info")) { + shard_info = (*get_shard_info)().cast(); } } ObjectPtr n = make_object(); @@ -195,9 +194,9 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: ShardInfo& shard_info = shards[name]; for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) { const std::string& name = shard_func.name; - if (ffi::Function f = mod.defined() ? mod->GetFunction(name, true) : nullptr; - f != nullptr) { - n->shard_funcs_[name] = f; + if (Optional f = + mod.has_value() ? (*mod)->GetFunction(name, true) : std::nullopt) { + n->shard_funcs_[name] = *f; } else if (const auto f = tvm::ffi::Function::GetGlobal(name)) { n->shard_funcs_[name] = *f; } else { diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index c5e62f39ac5e..491ded5730e6 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -32,7 +32,6 @@ #include #include -#include "../library_module.h" #include "HAP_debug.h" #include "HAP_perf.h" #include "hexagon_buffer.h" @@ -93,9 +92,9 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( - "runtime.module.loadfile_hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { - ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); - *rv = CreateModuleFromLibrary(n); + "ffi.Module.load_from_file.hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { + auto floader = tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); + *rv = floader(args[0].cast(), "so"); }); }); diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index a5a8de45357a..9db6a6680b06 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -24,8 +24,8 @@ #include "hexagon_module.h" #include +#include #include -#include #include #include @@ -42,12 +42,11 @@ HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, std::string bc_str) : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} -ffi::Function HexagonModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional HexagonModuleNode::GetFunction(const String& name) { LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; } -String HexagonModuleNode::GetSource(const String& format) { +String HexagonModuleNode::InspectSource(const String& format) const { if (format == "s" || format == "asm") { return asm_; } @@ -57,7 +56,7 @@ String HexagonModuleNode::GetSource(const String& format) { return ""; } -void HexagonModuleNode::SaveToFile(const String& file_name, const String& format) { +void HexagonModuleNode::WriteToFile(const String& file_name, const String& format) const { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -80,17 +79,22 @@ void HexagonModuleNode::SaveToFile(const String& file_name, const String& format } } -void HexagonModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes HexagonModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str) { +ffi::Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str) { auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); - return Module(n); + return ffi::Module(n); } } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index b8a830bc7c29..ae7174236622 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -20,8 +20,8 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ +#include #include -#include #include #include @@ -44,9 +44,10 @@ namespace runtime { * \param ir_str String with the disassembled LLVM IR source. * \param bc_str String with the bitcode LLVM IR. */ -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str); +ffi::Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str); /*! \brief Module implementation for compiled Hexagon binaries. It is suitable @@ -54,21 +55,21 @@ Module HexagonModuleCreate(std::string data, std::string fmt, See docstring for HexagonModuleCreate for construction parameter details. */ -class HexagonModuleNode : public runtime::ModuleNode { +class HexagonModuleNode : public ffi::ModuleObj { public: HexagonModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str); - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; - String GetSource(const String& format) override; - const char* type_key() const final { return "hexagon"; } + Optional GetFunction(const String& name) final; + String InspectSource(const String& format) const final; + const char* kind() const final { return "hexagon"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const override { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable | - ModulePropertyMask::kRunnable; + int GetPropertyMask() const final { + return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable | + ffi::Module::kRunnable; } - void SaveToFile(const String& file_name, const String& format) override; - void SaveToBinary(dmlc::Stream* stream) override; + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; protected: std::string data_; diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index fd5fc9ee2bc1..96c45bfdf0d1 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -36,7 +36,6 @@ extern "C" { #include #include -#include "../../../library_module.h" #include "../../../minrpc/minrpc_server.h" #include "../../hexagon/hexagon_common.h" #include "../../hexagon/hexagon_device_api.h" @@ -335,9 +334,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("tvm.hexagon.load_module", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); - tvm::ObjectPtr n = - tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n); + auto floader = + tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); + *rv = floader(soname, "so"); }) .def_packed( "tvm.hexagon.get_profile_output", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index 448cd0db9442..d511b0038f21 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -28,7 +28,6 @@ #include #include -#include "../../../library_module.h" #include "../../../minrpc/minrpc_server.h" #include "../../hexagon_common.h" #include "../../profiler/prof_utils.h" @@ -339,9 +338,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("tvm.hexagon.load_module", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); - tvm::ObjectPtr n = - tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n); + auto floader = + tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); + *rv = floader(soname, "so"); }) .def_packed( "tvm.hexagon.get_profile_output", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc deleted file mode 100644 index 24fc7518d6ad..000000000000 --- a/src/runtime/library_module.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file module_util.cc - * \brief Utilities for module. - */ -#include "library_module.h" - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -// Library module that exposes symbols from a library. -class LibraryModuleNode final : public ModuleNode { - public: - explicit LibraryModuleNode(ObjectPtr lib, FFIFunctionWrapper wrapper) - : lib_(lib), packed_func_wrapper_(wrapper) {} - - const char* type_key() const final { return "library"; } - - /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; - }; - - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { - TVMFFISafeCallType faddr; - faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); - if (faddr == nullptr) return ffi::Function(); - return packed_func_wrapper_(faddr, sptr_to_self); - } - - private: - ObjectPtr lib_; - FFIFunctionWrapper packed_func_wrapper_; -}; - -ffi::Function WrapFFIFunction(TVMFFISafeCallType faddr, const ObjectPtr& sptr_to_self) { - return ffi::Function::FromPacked([faddr, sptr_to_self](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); - TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), - args.size(), reinterpret_cast(rv))); - }); -} - -void InitContextFunctions(std::function fgetsymbol) { -#define TVM_INIT_CONTEXT_FUNC(FuncName) \ - if (auto* fp = reinterpret_cast(fgetsymbol("__" #FuncName))) { \ - *fp = FuncName; \ - } - // Initialize the functions - TVM_INIT_CONTEXT_FUNC(TVMFFIFunctionCall); - TVM_INIT_CONTEXT_FUNC(TVMFFIErrorSetRaisedFromCStr); - TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); - TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); - TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); - TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); - TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); - -#undef TVM_INIT_CONTEXT_FUNC -} - -Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { - std::string loadkey = "runtime.module.loadbinary_"; - std::string fkey = loadkey + type_key; - const auto f = tvm::ffi::Function::GetGlobal(fkey); - if (!f.has_value()) { - LOG(FATAL) << "Binary was created using {" << type_key - << "} but a loader of that name is not registered." - << "Perhaps you need to recompile with this runtime enabled."; - } - - return (*f)(static_cast(stream)).cast(); -} - -/*! - * \brief Load and append module blob to module list - * \param mblob The module blob. - * \param lib The library. - * \param root_module the output root module - * \param dso_ctx_addr the output dso module - */ -void ProcessLibraryBin(const char* mblob, ObjectPtr lib, - FFIFunctionWrapper packed_func_wrapper, runtime::Module* root_module, - runtime::ModuleNode** dso_ctx_addr = nullptr) { - ICHECK(mblob != nullptr); - uint64_t nbytes = 0; - for (size_t i = 0; i < sizeof(nbytes); ++i) { - uint64_t c = mblob[i]; - nbytes |= (c & 0xffUL) << (i * 8); - } - dmlc::MemoryFixedSizeStream fs(const_cast(mblob + sizeof(nbytes)), - static_cast(nbytes)); - dmlc::Stream* stream = &fs; - uint64_t size; - ICHECK(stream->Read(&size)); - std::vector modules; - std::vector import_tree_row_ptr; - std::vector import_tree_child_indices; - int num_dso_module = 0; - - for (uint64_t i = 0; i < size; ++i) { - std::string tkey; - ICHECK(stream->Read(&tkey)); - // "_lib" serves as a placeholder in the module import tree to indicate where - // to place the DSOModule - if (tkey == "_lib") { - auto dso_module = Module(make_object(lib, packed_func_wrapper)); - *dso_ctx_addr = dso_module.operator->(); - ++num_dso_module; - modules.emplace_back(dso_module); - ICHECK_EQ(num_dso_module, 1U) << "Multiple dso module detected, please upgrade tvm " - << " to the latest before exporting the module"; - } else if (tkey == "_import_tree") { - ICHECK(stream->Read(&import_tree_row_ptr)); - ICHECK(stream->Read(&import_tree_child_indices)); - } else { - auto m = LoadModuleFromBinary(tkey, stream); - modules.emplace_back(m); - } - } - - // if we are using old dll, we don't have import tree - // so that we can't reconstruct module relationship using import tree - if (import_tree_row_ptr.empty()) { - auto n = make_object(lib, packed_func_wrapper); - auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); - for (const auto& m : modules) { - module_import_addr->emplace_back(m); - } - *dso_ctx_addr = n.get(); - *root_module = Module(n); - } else { - for (size_t i = 0; i < modules.size(); ++i) { - for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { - auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->()); - auto child_index = import_tree_child_indices[j]; - ICHECK(child_index < modules.size()); - module_import_addr->emplace_back(modules[child_index]); - } - } - - ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; - // invariance: root module is always at location 0. - // The module order is collected via DFS - *root_module = modules[0]; - } -} - -Module CreateModuleFromLibrary(ObjectPtr lib, FFIFunctionWrapper packed_func_wrapper) { - InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); - auto n = make_object(lib, packed_func_wrapper); - // Load the imported modules - const char* library_bin = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_ffi_library_bin)); - - Module root_mod; - runtime::ModuleNode* dso_ctx_addr = nullptr; - if (library_bin != nullptr) { - ProcessLibraryBin(library_bin, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); - } else { - // Only have one single DSO Module - root_mod = Module(n); - dso_ctx_addr = root_mod.operator->(); - } - - // allow lookup of symbol from root (so all symbols are visible). - if (auto* ctx_addr = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_ffi_library_ctx))) { - *ctx_addr = dso_ctx_addr; - } - - return root_mod; -} -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h deleted file mode 100644 index 60ce95e2369b..000000000000 --- a/src/runtime/library_module.h +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file library_module.h - * \brief Module that builds from a libary of symbols. - */ -#ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ -#define TVM_RUNTIME_LIBRARY_MODULE_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! \brief Load a module with the given type key directly from the stream. - * This function wraps the registry mechanism used to store type based deserializers - * for each runtime::Module sub-class. - * - * \param type_key The type key of the serialized module. - * \param stream A pointer to the stream containing the serialized module. - * \return module The deserialized module. - */ -Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream); - -/*! - * \brief Library is the common interface - * for storing data in the form of shared libaries. - * - * \sa dso_library.cc - * \sa system_library.cc - */ -class Library : public Object { - public: - // destructor. - virtual ~Library() {} - /*! - * \brief Get the symbol address for a given name. - * \param name The name of the symbol. - * \return The symbol. - */ - virtual void* GetSymbol(const char* name) = 0; - // NOTE: we do not explicitly create an type index and type_key here for libary. - // This is because we do not need dynamic type downcasting. -}; - -/*! - * \brief Wrap a TVMFFISafeCallType to packed function. - * \param faddr The function address - * \param mptr The module pointer node. - */ -ffi::Function WrapFFIFunction(TVMFFISafeCallType faddr, const ObjectPtr& mptr); - -/*! - * \brief Utility to initialize conext function symbols during startup - * \param fgetsymbol A symbol lookup function. - */ -void InitContextFunctions(std::function fgetsymbol); - -/*! - * \brief Helper classes to get into internal of a module. - */ -class ModuleInternal { - public: - // Get mutable reference of imports. - static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } -}; - -/*! - * \brief Type alias for function to wrap a TVMFFISafeCallType. - * \param The function address imported from a module. - * \param mptr The module pointer node. - * \return Packed function that wraps the invocation of the function at faddr. - */ -using FFIFunctionWrapper = - std::function& mptr)>; - -/*! \brief Return a library object interface over dynamic shared - * libraries in Windows and Linux providing support for - * loading/unloading and symbol lookup. - * \param Full path to shared library. - * \return Returns pointer to the Library providing symbol lookup. - */ -ObjectPtr CreateDSOLibraryObject(std::string library_path); - -/*! - * \brief Create a module from a library. - * - * \param lib The library. - * \param wrapper Optional function used to wrap a TVMBackendPackedCFunc, - * by default WrapFFIFunction is used. - * \param symbol_prefix Optional symbol prefix that can be used to search alternative symbols. - * - * \return The corresponding loaded module. - * - * \note This function can create multiple linked modules - * by parsing the binary blob section of the library. - */ -Module CreateModuleFromLibrary(ObjectPtr lib, - FFIFunctionWrapper wrapper = WrapFFIFunction); -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index e2705a7a806b..213b6580b4e4 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ -#include +#include #include #include @@ -46,9 +46,9 @@ static constexpr const int kMetalMaxNumDevice = 32; * \param fmt The format of the source, can be "metal" or "metallib" * \param source Optional, source file, concatenaed for debug dump */ -Module MetalModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string fmt, - std::string source); +ffi::Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 33bb1705c8e4..71c46504c4d4 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -22,9 +22,9 @@ */ #include "metal_module.h" #include +#include #include #include -#include #include #include #include @@ -45,33 +45,37 @@ // Module to support thread-safe multi-GPU execution. // The runtime will contain a per-device module table // The modules will be lazily loaded -class MetalModuleNode final : public runtime::ModuleNode { +class MetalModuleNode final : public ffi::ModuleObj { public: explicit MetalModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string fmt, std::string source) : smap_(smap), fmap_(fmap), fmt_(fmt), source_(source) {} - const char* type_key() const final { return "metal"; } + const char* kind() const final { return "metal"; } /*! \brief Get the property of the runtime module. */ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; std::string version = kMetalModuleVersion; stream->Write(version); stream->Write(smap_); stream->Write(fmap_); stream->Write(fmt_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { // return text source if available. return source_; } @@ -259,15 +263,14 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) LaunchParamConfig launch_param_config_; }; -ffi::Function MetalModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional MetalModuleNode::GetFunction(const String& name) { ffi::Function ret; AUTORELEASEPOOL { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); if (it == fmap_.end()) { - ret = ffi::Function(); - return; + return std::nullopt; } const FunctionInfo& info = it->second; MetalWrappedFunc f; @@ -279,12 +282,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) return ret; } -Module MetalModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string fmt, - std::string source) { +ffi::Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { ObjectPtr n; AUTORELEASEPOOL { n = make_object(smap, fmap, fmt, source); }; - return Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -303,8 +306,9 @@ Module MetalModuleCreate(std::unordered_map smap, }); }); -Module MetalModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; // version is reserved for future changes and // is discarded for now std::string ver; @@ -322,7 +326,7 @@ Module MetalModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_metal", MetalModuleLoadBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/module.cc b/src/runtime/module.cc index cf19ff147f0c..12b58da2df2a 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -21,8 +21,10 @@ * \file module.cc * \brief TVM module system */ +#include #include #include +#include #include #include @@ -33,106 +35,6 @@ namespace tvm { namespace runtime { -void ModuleNode::Import(Module other) { - // specially handle rpc - if (!std::strcmp(this->type_key(), "rpc")) { - static auto fimport_ = tvm::ffi::Function::GetGlobalRequired("rpc.ImportRemoteModule"); - fimport_(GetRef(this), other); - return; - } - // cyclic detection. - std::unordered_set visited{other.operator->()}; - std::vector stack{other.operator->()}; - while (!stack.empty()) { - const ModuleNode* n = stack.back(); - stack.pop_back(); - for (const Module& m : n->imports_) { - const ModuleNode* next = m.operator->(); - if (visited.count(next)) continue; - visited.insert(next); - stack.push_back(next); - } - } - ICHECK(!visited.count(this)) << "Cyclic dependency detected during import"; - this->imports_.emplace_back(std::move(other)); -} - -ffi::Function ModuleNode::GetFunction(const String& name, bool query_imports) { - ModuleNode* self = this; - ffi::Function pf = self->GetFunction(name, GetObjectPtr(this)); - if (pf != nullptr) return pf; - if (query_imports) { - for (Module& m : self->imports_) { - pf = m.operator->()->GetFunction(name, query_imports); - if (pf != nullptr) { - return pf; - } - } - } - return pf; -} - -Module Module::LoadFromFile(const String& file_name, const String& format) { - std::string fmt = GetFileFormat(file_name, format); - ICHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name; - if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { - fmt = "so"; - } - std::string load_f_name = "runtime.module.loadfile_" + fmt; - VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'"; - const auto f = tvm::ffi::Function::GetGlobal(load_f_name); - ICHECK(f.has_value()) << "Loader for `." << format << "` files is not registered," - << " resolved to (" << load_f_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - Module m = (*f)(file_name, format).cast(); - return m; -} - -void ModuleNode::SaveToFile(const String& file_name, const String& format) { - LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; -} - -void ModuleNode::SaveToBinary(dmlc::Stream* stream) { - LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary"; -} - -String ModuleNode::GetSource(const String& format) { - LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource"; -} - -const ffi::Function* ModuleNode::GetFuncFromEnv(const String& name) { - std::lock_guard lock(mutex_); - auto it = import_cache_.find(name); - if (it != import_cache_.end()) return it->second.get(); - ffi::Function pf; - for (Module& m : this->imports_) { - pf = m.GetFunction(name, true); - if (pf != nullptr) break; - } - if (pf == nullptr) { - const auto f = tvm::ffi::Function::GetGlobal(name); - ICHECK(f.has_value()) << "Cannot find function " << name - << " in the imported modules or global registry." - << " If this involves ops from a contrib library like" - << " cuDNN, ensure TVM was built with the relevant" - << " library."; - import_cache_.insert(std::make_pair(name, std::make_shared(*f))); - return import_cache_.at(name).get(); - } else { - import_cache_.insert(std::make_pair(name, std::make_shared(pf))); - return import_cache_.at(name).get(); - } -} - -String ModuleNode::GetFormat() { - LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat"; -} - -bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) { - return GetFunction(name, query_imports) != nullptr; -} - bool RuntimeEnabled(const String& target_str) { std::string target = target_str; std::string f_name; @@ -166,33 +68,26 @@ bool RuntimeEnabled(const String& target_str) { return tvm::ffi::Function::GetGlobal(f_name).has_value(); } +#define TVM_INIT_CONTEXT_FUNC(FuncName) \ + TVM_FFI_CHECK_SAFE_CALL( \ + TVMFFIEnvRegisterContextSymbol("__" #FuncName, reinterpret_cast(FuncName))) + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("runtime.RuntimeEnabled", RuntimeEnabled) - .def("runtime.ModuleGetSource", - [](Module mod, std::string fmt) { return mod->GetSource(fmt); }) - .def("runtime.ModuleImportsSize", - [](Module mod) { return static_cast(mod->imports().size()); }) - .def("runtime.ModuleGetImport", - [](Module mod, int index) { return mod->imports().at(index); }) - .def("runtime.ModuleClearImports", [](Module mod) { mod->ClearImports(); }) - .def("runtime.ModuleGetTypeKey", [](Module mod) { return std::string(mod->type_key()); }) - .def("runtime.ModuleGetFormat", [](Module mod) { return mod->GetFormat(); }) - .def("runtime.ModuleLoadFromFile", Module::LoadFromFile) - .def("runtime.ModuleSaveToFile", - [](Module mod, String name, String fmt) { mod->SaveToFile(name, fmt); }) - .def("runtime.ModuleGetPropertyMask", [](Module mod) { return mod->GetPropertyMask(); }) - .def("runtime.ModuleImplementsFunction", - [](Module mod, String name, bool query_imports) { - return mod->ImplementsFunction(std::move(name), query_imports); - }) - .def("runtime.ModuleGetFunction", - [](Module mod, String name, bool query_imports) { - return mod->GetFunction(name, query_imports); - }) - .def("runtime.ModuleImport", [](Module mod, Module other) { mod->Import(other); }); + + // Initialize the functions + TVM_INIT_CONTEXT_FUNC(TVMFFIFunctionCall); + TVM_INIT_CONTEXT_FUNC(TVMFFIErrorSetRaisedFromCStr); + TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); + TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); + TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); + TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); + TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); + + refl::GlobalDef().def("runtime.RuntimeEnabled", RuntimeEnabled); }); +#undef TVM_INIT_CONTEXT_FUNC + } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 3fefae597f21..3e0981146afc 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -456,7 +456,7 @@ struct BufferDescriptor { // To make the call thread-safe, we create a thread-local kernel table // and lazily install new kernels into the kernel table when the kernel is called. // The kernels are recycled when the module get destructed. -class OpenCLModuleNodeBase : public ModuleNode { +class OpenCLModuleNodeBase : public ffi::ModuleObj { public: // Kernel table reference entry. struct KTRefEntry { @@ -472,14 +472,14 @@ class OpenCLModuleNodeBase : public ModuleNode { */ virtual cl::OpenCLWorkspace* GetGlobalWorkspace(); - const char* type_key() const final { return workspace_->type_key.c_str(); } + const char* kind() const final { return workspace_->type_key.c_str(); } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; + Optional GetFunction(const String& name) override; // Initialize the programs virtual void Init() = 0; @@ -509,14 +509,14 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap, std::string source) : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; // Return true if OpenCL program for the requested function and device was created bool IsProgramCreated(const std::string& func_name, int device_id); - void SaveToFile(const String& file_name, const String& format) final; - void SaveToBinary(dmlc::Stream* stream) final; + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; void SetPreCompiledPrograms(const std::string& bytes); std::string GetPreCompiledPrograms(); - String GetSource(const String& format) final; + String InspectSource(const String& format) const final; // Initialize the programs void Init() override; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 1c61eeb59635..a8e3b6fc20b6 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -135,11 +135,11 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional OpenCLModuleNodeBase::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); - if (it == fmap_.end()) return ffi::Function(); + if (it == fmap_.end()) return std::nullopt; const FunctionInfo& info = it->second; OpenCLWrappedFunc f; std::vector arg_size(info.arg_types.size()); @@ -160,7 +160,7 @@ ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, return PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::SaveToFile(const String& file_name, const String& format) { +void OpenCLModuleNode::WriteToFile(const String& file_name, const String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -168,13 +168,17 @@ void OpenCLModuleNode::SaveToFile(const String& file_name, const String& format) SaveBinaryToFile(file_name, data_); } -void OpenCLModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes OpenCLModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } -String OpenCLModuleNode::GetSource(const String& format) { +String OpenCLModuleNode::InspectSource(const String& format) const { if (format == fmt_) return data_; if (fmt_ == "cl") { return data_; @@ -201,7 +205,7 @@ void OpenCLModuleNode::Init() { } // split into source artifacts for each kernel - parsed_kernels_ = SplitKernels(GetSource("cl")); + parsed_kernels_ = SplitKernels(InspectSource("cl")); ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " << "source from code generation, but no kernel " << "delimiter was found."; @@ -345,8 +349,8 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } -ffi::Function OpenCLModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional OpenCLModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -357,18 +361,19 @@ ffi::Function OpenCLModuleNode::GetFunction(const String& name, this->SetPreCompiledPrograms(args[0].cast()); }); } - return OpenCLModuleNodeBase::GetFunction(name, sptr_to_self); + return OpenCLModuleNodeBase::GetFunction(name); } -Module OpenCLModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); - return Module(n); + return ffi::Module(n); } // Load module from module. -Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -378,8 +383,9 @@ Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -Module OpenCLModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string data; std::unordered_map fmap; std::string fmt; @@ -392,9 +398,9 @@ Module OpenCLModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadfile_cl", OpenCLModuleLoadFile) - .def("runtime.module.loadfile_clbin", OpenCLModuleLoadFile) - .def("runtime.module.loadbinary_opencl", OpenCLModuleLoadBinary); + .def("ffi.Module.load_from_file.cl", OpenCLModuleLoadFile) + .def("ffi.Module.load_from_file.clbin", OpenCLModuleLoadFile) + .def("ffi.Module.load_from_bytes.opencl", OpenCLModuleLoadFromBytes); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 198adc6cb216..18afad56a0c8 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -44,8 +44,9 @@ namespace runtime { * \param fmap The map function information map of each function. * \param source Generated OpenCL kernels. */ -Module OpenCLModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source); +ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string source); /*! * \brief Create a opencl module from SPIRV. @@ -54,9 +55,9 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, * \param spirv_text The concatenated text representation of SPIRV modules. * \param fmap The map function information map of each function. */ -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap); +ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 7d281694decb..5b90e0b566c7 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -39,9 +39,9 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap) : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} - void SaveToFile(const String& file_name, const String& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - String GetSource(const String&) final { return spirv_text_; } + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; + String InspectSource(const String& format) const final { return spirv_text_; } void Init() override; cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -52,14 +52,18 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::string spirv_text_; }; -void OpenCLSPIRVModuleNode::SaveToFile(const String& file_name, const String& format) { +void OpenCLSPIRVModuleNode::WriteToFile(const String& file_name, const String& format) const { // TODO(masahi): How SPIRV binaries should be save to a file? LOG(FATAL) << "Not implemented."; } -void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes OpenCLSPIRVModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmap_); stream->Write(shaders_); + return ffi::Bytes(buffer); } void OpenCLSPIRVModuleNode::Init() { @@ -125,12 +129,12 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC return kernel; } -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap) { +ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { auto n = make_object(shaders, spirv_text, fmap); n->Init(); - return Module(n); + return ffi::Module(n); } } // namespace runtime diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 4cce0d40d168..9d4c01d62366 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -792,59 +792,60 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.profiling.DeviceWrapper", [](Device dev) { return DeviceWrapper(dev); }); }); -ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, - int warmup_iters, Array collectors) { +ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, + int device_id, int warmup_iters, Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable - return ffi::Function::FromPacked( - [=](const ffi::AnyView* args, int32_t num_args, ffi::Any* ret) mutable { - ffi::Function f = mod.GetFunction(func_name); - CHECK(f.defined()) << "There is no function called \"" << func_name << "\" in the module"; - Device dev{static_cast(device_type), device_id}; - - // warmup - for (int i = 0; i < warmup_iters; i++) { - f.CallPacked(args, num_args, ret); - } - - for (auto& collector : collectors) { - collector->Init({DeviceWrapper(dev)}); - } - std::vector> results; - results.reserve(collectors.size()); - std::vector> collector_data; - collector_data.reserve(collectors.size()); - for (auto& collector : collectors) { - ObjectRef o = collector->Start(dev); - // If not defined, then the collector cannot time this device. - if (o.defined()) { - collector_data.push_back({collector, o}); - } - } + return ffi::Function::FromPacked([=](const ffi::AnyView* args, int32_t num_args, + ffi::Any* ret) mutable { + auto optf = mod->GetFunction(func_name); + CHECK(optf.has_value()) << "There is no function called \"" << func_name << "\" in the module"; + auto f = *optf; + Device dev{static_cast(device_type), device_id}; + + // warmup + for (int i = 0; i < warmup_iters; i++) { + f.CallPacked(args, num_args, ret); + } + + for (auto& collector : collectors) { + collector->Init({DeviceWrapper(dev)}); + } + std::vector> results; + results.reserve(collectors.size()); + std::vector> collector_data; + collector_data.reserve(collectors.size()); + for (auto& collector : collectors) { + ObjectRef o = collector->Start(dev); + // If not defined, then the collector cannot time this device. + if (o.defined()) { + collector_data.push_back({collector, o}); + } + } - // TODO(tkonolige): repeated calls if the runtime is small? - f.CallPacked(args, num_args, ret); + // TODO(tkonolige): repeated calls if the runtime is small? + f.CallPacked(args, num_args, ret); - for (auto& kv : collector_data) { - results.push_back(kv.first->Stop(kv.second)); - } - Map combined_results; - for (auto m : results) { - for (auto p : m) { - // assume that there is no shared metric name between collectors - combined_results.Set(p.first, p.second); - } - } - *ret = combined_results; - }); + for (auto& kv : collector_data) { + results.push_back(kv.first->Stop(kv.second)); + } + Map combined_results; + for (auto m : results) { + for (auto p : m) { + // assume that there is no shared metric name between collectors + combined_results.Set(p.first, p.second); + } + } + *ret = combined_results; + }); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "runtime.profiling.ProfileFunction", - [](Module mod, String func_name, int device_type, int device_id, int warmup_iters, + [](ffi::Module mod, String func_name, int device_type, int device_id, int warmup_iters, Array collectors) { - if (mod->type_key() == std::string("rpc")) { + if (mod->kind() == std::string("rpc")) { LOG(FATAL) << "Profiling a module over RPC is not yet supported"; // because we can't send // MetricCollectors over rpc. diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index a871a41f0f86..13b14e13e0e7 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -22,6 +22,7 @@ */ #include "rocm_module.h" +#include #include #include #include @@ -45,7 +46,7 @@ namespace runtime { // hipModule_t is a per-GPU module // The runtime will contain a per-device module table // The modules will be lazily loaded -class ROCMModuleNode : public runtime::ModuleNode { +class ROCMModuleNode : public ffi::ModuleObj { public: explicit ROCMModuleNode(std::string data, std::string fmt, std::unordered_map fmap, @@ -63,13 +64,13 @@ class ROCMModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { return "hip"; } + const char* kind() const final { return "hip"; } int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them @@ -78,13 +79,17 @@ class ROCMModuleNode : public runtime::ModuleNode { SaveBinaryToFile(file_name, data_); } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (format == fmt_) { return data_; } @@ -192,25 +197,25 @@ class ROCMWrappedFunc { LaunchParamConfig launch_param_config_; }; -ffi::Function ROCMModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional ROCMModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); - if (it == fmap_.end()) return ffi::Function(); + if (it == fmap_.end()) return std::nullopt; const FunctionInfo& info = it->second; ROCMWrappedFunc f; f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncPackedArgAligned(f, info.arg_types); } -Module ROCMModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string hip_source, - std::string assembly) { +ffi::Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string hip_source, std::string assembly) { auto n = make_object(data, fmt, fmap, hip_source, assembly); - return Module(n); + return ffi::Module(n); } -Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { +ffi::Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -220,8 +225,9 @@ Module ROCMModuleLoadFile(const std::string& file_name, const std::string& forma return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } -Module ROCMModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string data; std::unordered_map fmap; std::string fmt; @@ -234,10 +240,10 @@ Module ROCMModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadbinary_hsaco", ROCMModuleLoadBinary) - .def("runtime.module.loadbinary_hip", ROCMModuleLoadBinary) - .def("runtime.module.loadfile_hsaco", ROCMModuleLoadFile) - .def("runtime.module.loadfile_hip", ROCMModuleLoadFile); + .def("ffi.Module.load_from_bytes.hsaco", ROCMModuleLoadFromBytes) + .def("ffi.Module.load_from_bytes.hip", ROCMModuleLoadFromBytes) + .def("ffi.Module.load_from_file.hsaco", ROCMModuleLoadFile) + .def("ffi.Module.load_from_file.hip", ROCMModuleLoadFile); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h index c17e123c1a12..ee6f29f43edb 100644 --- a/src/runtime/rocm/rocm_module.h +++ b/src/runtime/rocm/rocm_module.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_ROCM_ROCM_MODULE_H_ #define TVM_RUNTIME_ROCM_ROCM_MODULE_H_ -#include +#include #include #include @@ -47,9 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param rocm_source Optional, rocm source file */ -Module ROCMModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string rocm_source, - std::string assembly); +ffi::Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string rocm_source, std::string assembly); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_ diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 3dea9dc82239..e1282c17878a 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -593,13 +593,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { << " Error caught from session constructor " << constructor_name << ":\n" << e.what(); } - auto opt_con_ret = con_ret.as(); + auto opt_con_ret = con_ret.as(); // Legacy ABI translation ICHECK(opt_con_ret.has_value()) << "Server[" << name_ << "]:" << " Constructor " << constructor_name << " need to return an RPCModule"; - Module mod = opt_con_ret.value(); - std::string tkey = mod->type_key(); + ffi::Module mod = opt_con_ret.value(); + std::string tkey = mod->kind(); ICHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; serving_session_ = RPCModuleGetSession(mod); this->ReturnVoid(); diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 3094c9ca13a2..bcf661960f06 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -170,7 +170,7 @@ class RPCWrappedFunc : public Object { }; // RPC that represents a remote module session. -class RPCModuleNode final : public ModuleNode { +class RPCModuleNode final : public ffi::ModuleObj { public: RPCModuleNode(void* module_handle, std::shared_ptr sess) : module_handle_(module_handle), sess_(sess) {} @@ -186,11 +186,11 @@ class RPCModuleNode final : public ModuleNode { } } - const char* type_key() const final { return "rpc"; } + const char* kind() const final { return "rpc"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ffi::Module::ModulePropertyMask::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { if (name == "CloseRPCConnection") { return ffi::Function([this](ffi::PackedArgs, ffi::Any*) { sess_->Shutdown(); }); } @@ -199,15 +199,10 @@ class RPCModuleNode final : public ModuleNode { return WrapRemoteFunc(sess_->GetFunction(name)); } else { InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); - return remote_mod_get_function_(GetRef(this), name, true); + return remote_mod_get_function_(GetRef(this), name, true); } } - String GetSource(const String& format) final { - LOG(FATAL) << "GetSource for rpc Module is not supported"; - throw; - } - ffi::Function GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, @@ -220,25 +215,25 @@ class RPCModuleNode final : public ModuleNode { if (module_handle_ != nullptr) { return remote_get_time_evaluator_( - GetRef(this), name, static_cast(dev.device_type), dev.device_id, number, + GetRef(this), name, static_cast(dev.device_type), dev.device_id, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } else { return remote_get_time_evaluator_( - Optional(std::nullopt), name, static_cast(dev.device_type), dev.device_id, - number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, - repeats_to_cooldown, cache_flush_bytes, f_preproc_name); + Optional(std::nullopt), name, static_cast(dev.device_type), + dev.device_id, number, repeat, min_repeat_ms, limit_zero_time_iterations, + cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } } - Module LoadModule(std::string name) { + ffi::Module LoadModule(std::string name) { InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module"); return remote_load_module_(name); } - void ImportModule(Module other) { + void ImportModule(ffi::Module other) { InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); - remote_import_module_(GetRef(this), other); + remote_import_module_(GetRef(this), other); } const std::shared_ptr& sess() { return sess_; } @@ -266,22 +261,22 @@ class RPCModuleNode final : public ModuleNode { // The local channel std::shared_ptr sess_; // remote function to get time evaluator - ffi::TypedFunction, std::string, int, int, int, int, int, int, int, - int, int, std::string)> + ffi::TypedFunction, std::string, int, int, int, int, int, int, + int, int, int, std::string)> remote_get_time_evaluator_; // remote function getter for modules. - ffi::TypedFunction remote_mod_get_function_; + ffi::TypedFunction remote_mod_get_function_; // remote function getter for load module - ffi::TypedFunction remote_load_module_; + ffi::TypedFunction remote_load_module_; // remote function getter for load module - ffi::TypedFunction remote_import_module_; + ffi::TypedFunction remote_import_module_; }; void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const AnyView& arg) const { // TODO(tqchen): only support Module unwrapping for now. if (arg.type_index() == ffi::TypeIndex::kTVMFFIModule) { - Module mod = arg.cast(); - std::string tkey = mod->type_key(); + ffi::Module mod = arg.cast(); + std::string tkey = mod->kind(); ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); ICHECK(rmod->sess() == sess_) @@ -309,7 +304,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto n = make_object(handle, sess_); - *rv = Module(n); + *rv = ffi::Module(n); } else if (type_index == ffi::TypeIndex::kTVMFFINDArray || type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) { ICHECK_EQ(args.size(), 3); @@ -335,14 +330,14 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } } -Module CreateRPCSessionModule(std::shared_ptr sess) { +ffi::Module CreateRPCSessionModule(std::shared_ptr sess) { auto n = make_object(nullptr, sess); RPCSession::InsertToSessionTable(sess); - return Module(n); + return ffi::Module(n); } -std::shared_ptr RPCModuleGetSession(Module mod) { - std::string tkey = mod->type_key(); +std::shared_ptr RPCModuleGetSession(ffi::Module mod) { + std::string tkey = mod->kind(); ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); return rmod->sess(); @@ -402,7 +397,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.RPCTimeEvaluator", - [](Optional opt_mod, std::string name, int device_type, int device_id, + [](Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, std::string f_preproc_name) { @@ -410,8 +405,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ dev.device_type = static_cast(device_type); dev.device_id = device_id; if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); + ffi::Module m = opt_mod.value(); + std::string tkey = m->kind(); if (tkey == "rpc") { return static_cast(m.operator->()) ->GetTimeEvaluator(name, dev, number, repeat, min_repeat_ms, @@ -425,10 +420,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - ffi::Function pf = m.GetFunction(name, true); - CHECK(pf != nullptr) << "Cannot find " << name << "` in the global registry"; + Optional pf = m->GetFunction(name); + CHECK(pf.has_value()) << "Cannot find " << name << "` in the global registry"; return profiling::WrapTimeEvaluator( - pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, + *pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc); } } else { @@ -455,9 +450,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tvm.rpc.server.ImportModule", - [](Module parent, Module child) { parent->Import(child); }) + [](ffi::Module parent, ffi::Module child) { parent->ImportModule(child); }) .def("tvm.rpc.server.ModuleGetFunction", - [](Module parent, std::string name, bool query_imports) { + [](ffi::Module parent, std::string name, bool query_imports) { return parent->GetFunction(name, query_imports); }); }); @@ -467,26 +462,26 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("rpc.LoadRemoteModule", - [](Module sess, std::string name) { - std::string tkey = sess->type_key(); + [](ffi::Module sess, std::string name) { + std::string tkey = sess->kind(); ICHECK_EQ(tkey, "rpc"); return static_cast(sess.operator->())->LoadModule(name); }) .def("rpc.ImportRemoteModule", - [](Module parent, Module child) { - std::string tkey = parent->type_key(); + [](ffi::Module parent, ffi::Module child) { + std::string tkey = parent->kind(); ICHECK_EQ(tkey, "rpc"); static_cast(parent.operator->())->ImportModule(child); }) .def_packed("rpc.SessTableIndex", [](ffi::PackedArgs args, ffi::Any* rv) { - Module m = args[0].cast(); - std::string tkey = m->type_key(); + ffi::Module m = args[0].cast(); + std::string tkey = m->kind(); ICHECK_EQ(tkey, "rpc"); *rv = static_cast(m.operator->())->sess()->table_index(); }) .def("tvm.rpc.NDArrayFromRemoteOpaqueHandle", - [](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, + [](ffi::Module mod, void* remote_array, DLTensor* template_tensor, Device dev, void* ndarray_handle) -> NDArray { return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor, dev, ndarray_handle); diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index e50c6a456eaf..22619289d053 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -76,7 +76,7 @@ class PipeChannel final : public RPCChannel { pid_t child_pid_; }; -Module CreatePipeClient(std::vector cmd) { +ffi::Module CreatePipeClient(std::vector cmd) { int parent2child[2]; int child2parent[2]; ICHECK_EQ(pipe(parent2child), 0); diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index c0ec2067eb5f..c0e09ec004ba 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -267,7 +267,7 @@ class RPCSession { /*! \brief Insert the current session to the session table.*/ static void InsertToSessionTable(std::shared_ptr sess); // friend declaration - friend Module CreateRPCSessionModule(std::shared_ptr sess); + friend ffi::Module CreateRPCSessionModule(std::shared_ptr sess); }; /*! @@ -341,14 +341,14 @@ class RPCObjectRef : public ObjectRef { * \param sess The RPC session of the global module. * \return The created module. */ -Module CreateRPCSessionModule(std::shared_ptr sess); +ffi::Module CreateRPCSessionModule(std::shared_ptr sess); /*! * \brief Get the session module from a RPC session Module. * \param mod The input module(must be an RPCModule). * \return The internal RPCSession. */ -std::shared_ptr RPCModuleGetSession(Module mod); +std::shared_ptr RPCModuleGetSession(ffi::Module mod); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 5ed34051cf55..d2f141ee21e0 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -105,8 +105,8 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string k return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, - ffi::PackedArgs init_seq) { +ffi::Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, + ffi::PackedArgs init_seq) { auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq); return CreateRPCSessionModule(CreateClientSession(endpt)); } diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 5ad331d27d0a..b816fb600e1e 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -24,6 +24,7 @@ */ #include "./static_library.h" +#include #include #include #include @@ -42,30 +43,34 @@ namespace { * \brief A '.o' library which can be linked into the final output library by export_library. * Can be used by external codegen tools which can produce a ready-to-link artifact. */ -class StaticLibraryNode final : public runtime::ModuleNode { +class StaticLibraryNode final : public ffi::ModuleObj { public: - ~StaticLibraryNode() override = default; + const char* kind() const final { return "static_library"; } - const char* type_key() const final { return "static_library"; } - - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { + const ObjectPtr& sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_func_names") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = func_names_; }); } else { - return {}; + return std::nullopt; } } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(data_); std::vector func_names; for (const auto func_name : func_names_) func_names.push_back(func_name); stream->Write(func_names); + return Bytes(buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(ffi::Bytes bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; auto n = make_object(); // load data std::string data; @@ -77,10 +82,10 @@ class StaticLibraryNode final : public runtime::ModuleNode { ICHECK(stream->Read(&func_names)) << "Loading func names failed"; for (auto func_name : func_names) n->func_names_.push_back(String(func_name)); - return Module(n); + return ffi::Module(n); } - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames() << " to '" << file_name << "'"; SaveBinaryToFile(file_name, data_); @@ -88,14 +93,14 @@ class StaticLibraryNode final : public runtime::ModuleNode { /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const override { - return runtime::ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable; + return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name, bool query_imports) final { + bool ImplementsFunction(const String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } - std::string FuncNames() { + std::string FuncNames() const { std::ostringstream os; os << "["; bool first = true; @@ -119,19 +124,19 @@ class StaticLibraryNode final : public runtime::ModuleNode { } // namespace -Module LoadStaticLibrary(const std::string& filename, Array func_names) { +ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names) { auto node = make_object(); LoadBinaryFromFile(filename, &node->data_); node->func_names_ = std::move(func_names); VLOG(0) << "Loaded static library from '" << filename << "' implementing " << node->FuncNames(); - return Module(node); + return ffi::Module(node); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.ModuleLoadStaticLibrary", LoadStaticLibrary) - .def("runtime.module.loadbinary_static_library", StaticLibraryNode::LoadFromBinary); + .def("ffi.Module.load_from_bytes.static_library", StaticLibraryNode::LoadFromBytes); }); } // namespace runtime diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 196d2448b93f..8a5600fc0588 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -43,7 +43,7 @@ namespace runtime { * \brief Returns a static library with the contents loaded from filename which exports * func_names with the usual packed-func calling convention. */ -Module LoadStaticLibrary(const std::string& filename, Array func_names); +ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names); } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 1bec6f5016eb..ef6fbe6373af 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -161,7 +161,7 @@ void LoadHeader(dmlc::Stream* strm) { STREAM_CHECK(version == VM_VERSION, "version"); } -void VMExecutable::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes VMExecutable::SaveToBytes() const { std::string code; // Initialize the stream object. dmlc::MemoryStringStream strm(&code); @@ -178,21 +178,16 @@ void VMExecutable::SaveToBinary(dmlc::Stream* stream) { // Code section. SaveCodeSection(&strm); - stream->Write(code); + return ffi::Bytes(code); } -void VMExecutable::SaveToFile(const String& file_name, const String& format) { - std::string data; - dmlc::MemoryStringStream writer(&data); - dmlc::SeekStream* strm = &writer; - VMExecutable::SaveToBinary(strm); - runtime::SaveBinaryToFile(file_name, data); +void VMExecutable::WriteToFile(const String& file_name, const String& format) const { + runtime::SaveBinaryToFile(file_name, VMExecutable::SaveToBytes()); } -Module VMExecutable::LoadFromBinary(void* stream) { +ffi::Module VMExecutable::LoadFromBytes(const ffi::Bytes& bytes) { std::string code; - static_cast(stream)->Read(&code); - dmlc::MemoryStringStream strm(&code); + dmlc::MemoryFixedSizeStream strm(const_cast(bytes.data()), bytes.size()); ObjectPtr exec = make_object(); @@ -208,26 +203,20 @@ Module VMExecutable::LoadFromBinary(void* stream) { // Code section. exec->LoadCodeSection(&strm); - return Module(exec); + return ffi::Module(exec); } -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_relax.VMExecutable", - VMExecutable::LoadFromBinary); -}); - -Module VMExecutable::LoadFromFile(const String& file_name) { +ffi::Module VMExecutable::LoadFromFile(const String& file_name) { std::string data; runtime::LoadBinaryFromFile(file_name, &data); - dmlc::MemoryStringStream reader(&data); - dmlc::Stream* strm = &reader; - return VMExecutable::LoadFromBinary(reinterpret_cast(strm)); + return VMExecutable::LoadFromBytes(ffi::Bytes(data)); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadfile_relax.VMExecutable", VMExecutable::LoadFromFile); + refl::GlobalDef() + .def("ffi.Module.load_from_file.relax.VMExecutable", VMExecutable::LoadFromFile) + .def("ffi.Module.load_from_bytes.relax.VMExecutable", VMExecutable::LoadFromBytes); }); void VMFuncInfo::Save(dmlc::Stream* strm) const { @@ -254,9 +243,9 @@ bool VMFuncInfo::Load(dmlc::Stream* strm) { return true; } -void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(func_table); } +void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) const { strm->Write(func_table); } -void VMExecutable::SaveConstantSection(dmlc::Stream* strm) { +void VMExecutable::SaveConstantSection(dmlc::Stream* strm) const { strm->Write(static_cast(this->constants.size())); for (const auto& it : this->constants) { if (auto opt_nd = it.as()) { @@ -291,7 +280,7 @@ void VMExecutable::SaveConstantSection(dmlc::Stream* strm) { } } -void VMExecutable::SaveCodeSection(dmlc::Stream* strm) { +void VMExecutable::SaveCodeSection(dmlc::Stream* strm) const { strm->Write(instr_offset); strm->Write(instr_data); } @@ -394,16 +383,16 @@ std::string RegNameToStr(RegName reg) { return "%" + std::to_string(reg); } -Module VMExecutable::VMLoadExecutable() const { +ffi::Module VMExecutable::VMLoadExecutable() const { ObjectPtr vm = VirtualMachine::Create(); vm->LoadExecutable(GetObjectPtr(const_cast(this))); - return Module(vm); + return ffi::Module(vm); } -Module VMExecutable::VMProfilerLoadExecutable() const { +ffi::Module VMExecutable::VMProfilerLoadExecutable() const { ObjectPtr vm = VirtualMachine::CreateProfiler(); vm->LoadExecutable(GetObjectPtr(const_cast(this))); - return Module(vm); + return ffi::Module(vm); } bool VMExecutable::HasFunction(const String& name) const { return func_map.count(name); } diff --git a/src/runtime/vm/ndarray_cache_support.cc b/src/runtime/vm/ndarray_cache_support.cc index d91669016d78..cfd979cc6f24 100644 --- a/src/runtime/vm/ndarray_cache_support.cc +++ b/src/runtime/vm/ndarray_cache_support.cc @@ -302,11 +302,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ // This param module node can be useful to get param dict in RPC mode // when the remote already have loaded parameters from file. -class ParamModuleNode : public runtime::ModuleNode { +class ParamModuleNode : public ffi::ModuleObj { public: - const char* type_key() const final { return "param_module"; } + const char* kind() const final { return "param_module"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { if (name == "get_params") { auto params = params_; return ffi::Function([params](ffi::PackedArgs args, ffi::Any* rv) { *rv = params; }); @@ -343,16 +343,16 @@ class ParamModuleNode : public runtime::ModuleNode { return result; } - static Module Create(const std::string& prefix, int num_params) { + static ffi::Module Create(const std::string& prefix, int num_params) { auto n = make_object(); n->params_ = GetParams(prefix, num_params); - return Module(n); + return ffi::Module(n); } - static Module CreateByName(const Array& names) { + static ffi::Module CreateByName(const Array& names) { auto n = make_object(); n->params_ = GetParamByName(names); - return Module(n); + return ffi::Module(n); } private: diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c28e30084fc1..c4fdedd815a9 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -300,12 +300,13 @@ class VirtualMachineImpl : public VirtualMachine { * \param name The name of the function. * \return The result function, can return ffi::Function(nullptr) if nothing is found. */ - ffi::Function GetFuncFromImports(const String& name) { + Optional GetFuncFromImports(const String& name) { for (auto& lib : this->imports_) { - ffi::Function func = lib->GetFunction(name, true); - if (func.defined()) return func; + if (auto opt_func = lib.cast()->GetFunction(name, true)) { + return *opt_func; + } } - return ffi::Function(nullptr); + return std::nullopt; } /*! * \brief Initialize function pool. @@ -452,7 +453,7 @@ class VirtualMachineImpl : public VirtualMachine { void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { this->exec_ = exec; - this->imports_ = exec_->imports(); + this->imports_ = exec->imports(); } void VirtualMachineImpl::Init(const std::vector& devices, @@ -508,7 +509,7 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, for (int i = 0; i < args.size(); ++i) { if (with_param_module && i == args.size() - 1) { // call param func to get the arguments(usually corresponds to param pack.) - func_args[i] = (args[i].cast()).GetFunction("get_params")(); + func_args[i] = (args[i].cast())->GetFunction("get_params").value()(); } else { func_args[i] = ConvertArgToDevice(args[i], devices[0], allocators[0]); } @@ -620,9 +621,9 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na } else { ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast(finfo.kind); - ffi::Function tir_func = GetFuncFromImports("__vmtir__" + finfo.name); - ICHECK(tir_func != nullptr) << "Cannot find underlying compiled tir function of VMTIRFunc " - << finfo.name; + Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ICHECK(tir_func.has_value()) << "Cannot find underlying compiled tir function of VMTIRFunc " + << finfo.name; auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].cast()); @@ -637,8 +638,8 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na void* reg_anylist_handle = reg_file.data(); void* const_anylist_handle = this->const_pool_.data(); void* func_anylist_handle = this->func_pool_.data(); - tir_func(static_cast(ctx_ptr), reg_anylist_handle, const_anylist_handle, - func_anylist_handle); + (*tir_func)(static_cast(ctx_ptr), reg_anylist_handle, const_anylist_handle, + func_anylist_handle); // Return value always stored after inputs. *rv = reg_file[finfo.num_args]; }); @@ -696,16 +697,16 @@ void VirtualMachineImpl::InitFuncPool() { const VMFuncInfo& info = exec_->func_table[func_index]; if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { // only look through imports first - ffi::Function func = GetFuncFromImports(info.name); - if (!func.defined()) { + Optional func = GetFuncFromImports(info.name); + if (!func.has_value()) { const auto p_func = tvm::ffi::Function::GetGlobal(info.name); - if (p_func.has_value()) func = *(p_func); + if (p_func.has_value()) func = *p_func; } - ICHECK(func.defined()) + ICHECK(func.has_value()) << "Error: Cannot find ffi::Function " << info.name << " in either Relax VM kernel library, or in TVM runtime ffi::Function registry, or in " "global Relax functions of the VM executable"; - func_pool_[func_index] = func; + func_pool_[func_index] = *func; } else { ICHECK(info.kind == VMFuncInfo::FuncKind::kVMFunc || @@ -951,8 +952,8 @@ std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { if (Optional opt = this->GetClosureInternal(name, true)) { - return ffi::Function([clo = opt.value(), _self = GetRef(this)](ffi::PackedArgs args, - ffi::Any* rv) -> void { + return ffi::Function([clo = opt.value(), _self = GetRef(this)]( + ffi::PackedArgs args, ffi::Any* rv) -> void { auto* self = const_cast(_self.as()); ICHECK(self); self->InvokeClosurePacked(clo, args, rv); @@ -972,7 +973,8 @@ ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { */ class VirtualMachineProfiler : public VirtualMachineImpl { public: - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + Optional GetFunction(const String& name) override { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "profile") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { std::string f_name = args[0].cast(); @@ -1017,7 +1019,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } }); } else { - return VirtualMachineImpl::GetFunction(name, sptr_to_self); + return VirtualMachineImpl::GetFunction(name); } } diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 81136c28cd3c..a5fb6c2293fa 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -30,13 +30,14 @@ namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string source) { +ffi::Module VulkanModuleCreate(std::unordered_map smap, + std::unordered_map fmap, + std::string source) { auto n = make_object(smap, fmap, source); - return Module(n); + return ffi::Module(n); } -Module VulkanModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module VulkanModuleLoadFile(const std::string& file_name, const String& format) { std::string data; std::unordered_map smap; std::unordered_map fmap; @@ -53,8 +54,9 @@ Module VulkanModuleLoadFile(const std::string& file_name, const String& format) return VulkanModuleCreate(smap, fmap, ""); } -Module VulkanModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::unordered_map smap; std::unordered_map fmap; @@ -68,8 +70,8 @@ Module VulkanModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadfile_vulkan", VulkanModuleLoadFile) - .def("runtime.module.loadbinary_vulkan", VulkanModuleLoadBinary); + .def("ffi.Module.load_from_file.vulkan", VulkanModuleLoadFile) + .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes); }); } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index 878e096f5ac1..ea853721bfa2 100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -20,6 +20,8 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_ #define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_ +#include + #include #include @@ -29,8 +31,9 @@ namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string source); +ffi::Module VulkanModuleCreate(std::unordered_map smap, + std::unordered_map fmap, + std::string source); } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index db81c959dccd..2f50a0154658 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -205,11 +205,11 @@ VulkanModuleNode::~VulkanModuleNode() { } } -ffi::Function VulkanModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional VulkanModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); - if (it == fmap_.end()) return ffi::Function(); + if (it == fmap_.end()) return std::nullopt; const FunctionInfo& info = it->second; VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); @@ -403,7 +403,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, return pe; } -void VulkanModuleNode::SaveToFile(const String& file_name, const String& format) { +void VulkanModuleNode::WriteToFile(const String& file_name, const String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; std::string meta_file = GetMetaFilePath(file_name); @@ -417,13 +417,17 @@ void VulkanModuleNode::SaveToFile(const String& file_name, const String& format) SaveBinaryToFile(file_name, data_bin); } -void VulkanModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes VulkanModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(smap_); + return ffi::Bytes(buffer); } -String VulkanModuleNode::GetSource(const String& format) { +String VulkanModuleNode::InspectSource(const String& format) const { // can only return disassembly code. return source_; } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 9b6f3703f34f..2ff90568de9d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -80,29 +80,29 @@ class VulkanWrappedFunc { mutable std::array, kVulkanMaxNumDevice> scache_; }; -class VulkanModuleNode final : public runtime::ModuleNode { +class VulkanModuleNode final : public ffi::ModuleObj { public: explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} ~VulkanModuleNode(); - const char* type_key() const final { return "vulkan"; } + const char* kind() const final { return "vulkan"; } /*! \brief Get the property of the runtime module. */ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args); - void SaveToFile(const String& file_name, const String& format) final; + void WriteToFile(const String& file_name, const String& format) const final; - void SaveToBinary(dmlc::Stream* stream) final; - String GetSource(const String& format) final; + ffi::Bytes SaveToBytes() const final; + String InspectSource(const String& format) const final; private: // function information table. diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index e9d35c4496e7..70c23c546bbb 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -123,13 +123,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("testing.ErrorTest", ErrorTest); }); -class FrontendTestModuleNode : public runtime::ModuleNode { +class FrontendTestModuleNode : public ffi::ModuleObj { public: - const char* type_key() const final { return "frontend_test"; } + const char* kind() const final { return "frontend_test"; } static constexpr const char* kAddFunctionName = "__add_function"; - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual ffi::Optional GetFunction(const String& name); private: std::unordered_map functions_; @@ -137,11 +137,11 @@ class FrontendTestModuleNode : public runtime::ModuleNode { constexpr const char* FrontendTestModuleNode::kAddFunctionName; -ffi::Function FrontendTestModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Optional FrontendTestModuleNode::GetFunction(const String& name) { + ffi::Module self_strong_ref = GetRef(this); if (name == kAddFunctionName) { - return ffi::TypedFunction( - [this, sptr_to_self](std::string func_name, ffi::Function pf) { + return ffi::Function::FromTyped( + [this, self_strong_ref](std::string func_name, ffi::Function pf) { CHECK_NE(func_name, kAddFunctionName) << "func_name: cannot be special function " << kAddFunctionName; functions_[func_name] = pf; @@ -150,15 +150,15 @@ ffi::Function FrontendTestModuleNode::GetFunction(const String& name, auto it = functions_.find(name); if (it == functions_.end()) { - return ffi::Function(); + return std::nullopt; } return it->second; } -runtime::Module NewFrontendTestModule() { +ffi::Module NewFrontendTestModule() { auto n = make_object(); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 9b650c9aaa43..96075450183c 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -38,12 +38,10 @@ #include #include -#include "../runtime/library_module.h" - namespace tvm { namespace codegen { -runtime::Module Build(IRModule mod, Target target) { +ffi::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) .value()) { @@ -54,66 +52,42 @@ runtime::Module Build(IRModule mod, Target target) { std::string build_f_name = "target.build." + target->kind->name; const auto bf = tvm::ffi::Function::GetGlobal(build_f_name); ICHECK(bf.has_value()) << build_f_name << " is not enabled"; - return (*bf)(mod, target).cast(); + return (*bf)(mod, target).cast(); } /*! \brief Helper class to serialize module */ class ModuleSerializer { public: - explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } + explicit ModuleSerializer(ffi::Module mod) : mod_(mod) { Init(); } void SerializeModuleToBytes(dmlc::Stream* stream, bool export_dso) { - // Only have one DSO module and it is in the root, then - // we will not produce import_tree_. - bool has_import_tree = true; - - if (export_dso) { - has_import_tree = !mod_->imports().empty(); - } - - uint64_t sz = 0; - if (has_import_tree) { - // we will append one key for _import_tree - // The layout is the same as before: binary_size, key, logic, key, logic... - sz = mod_group_vec_.size() + 1; - } else { - // Keep the old behaviour - sz = mod_->imports().size(); - } - stream->Write(sz); - + // Always _import_tree + stream->Write(import_tree_row_ptr_); + stream->Write(import_tree_child_indices_); for (const auto& group : mod_group_vec_) { ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module"; // we prioritize export dso when a module is both serializable and exportable if (export_dso) { - if (group[0]->IsDSOExportable()) { - if (has_import_tree) { - std::string mod_type_key = "_lib"; - stream->Write(mod_type_key); - } - } else if (group[0]->IsBinarySerializable()) { + if (group[0]->GetPropertyMask() & ffi::Module::kCompilationExportable) { + std::string mod_type_key = "_lib"; + stream->Write(mod_type_key); + } else if (group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) { ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; - std::string mod_type_key = group[0]->type_key(); + std::string mod_type_key = group[0]->kind(); stream->Write(mod_type_key); - group[0]->SaveToBinary(stream); + std::string bytes = group[0]->SaveToBytes(); + stream->Write(bytes); } } else { - ICHECK(group[0]->IsBinarySerializable()) - << group[0]->type_key() << " is not binary serializable."; + ICHECK(group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) + << group[0]->kind() << " is not binary serializable."; ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; - std::string mod_type_key = group[0]->type_key(); + std::string mod_type_key = group[0]->kind(); stream->Write(mod_type_key); - group[0]->SaveToBinary(stream); + std::string bytes = group[0]->SaveToBytes(); + stream->Write(bytes); } } - - // Write _import_tree key if we have - if (has_import_tree) { - std::string import_key = "_import_tree"; - stream->Write(import_key); - stream->Write(import_tree_row_ptr_); - stream->Write(import_tree_child_indices_); - } } private: @@ -127,13 +101,13 @@ class ModuleSerializer { // This function merges all the DSO exportable module into // a single one as this is also what happens in the final hierachy void CreateModuleIndex() { - std::unordered_set visited{mod_.operator->()}; - std::vector stack{mod_.operator->()}; + std::unordered_set visited{mod_.operator->()}; + std::vector stack{mod_.operator->()}; uint64_t module_index = 0; - auto fpush_imports_to_stack = [&](runtime::ModuleNode* node) { - for (runtime::Module m : node->imports()) { - runtime::ModuleNode* next = m.operator->(); + auto fpush_imports_to_stack = [&](ffi::ModuleObj* node) { + for (Any m : node->imports()) { + ffi::ModuleObj* next = m.cast().operator->(); if (visited.count(next) == 0) { visited.insert(next); stack.push_back(next); @@ -141,7 +115,7 @@ class ModuleSerializer { } }; - std::vector dso_exportable_boundary; + std::vector dso_exportable_boundary; // Create module index that merges all dso module into a single group. // @@ -154,16 +128,16 @@ class ModuleSerializer { // Phase 0: only expand non-dso-module and record the boundary. while (!stack.empty()) { - runtime::ModuleNode* n = stack.back(); + ffi::ModuleObj* n = stack.back(); stack.pop_back(); - if (n->IsDSOExportable()) { + if (n->GetPropertyMask() & ffi::Module::kCompilationExportable) { // do not recursively expand dso modules // we will expand in phase 1 dso_exportable_boundary.emplace_back(n); } else { // expand the non-dso modules mod2index_[n] = module_index++; - mod_group_vec_.emplace_back(std::vector({n})); + mod_group_vec_.emplace_back(std::vector({n})); fpush_imports_to_stack(n); } } @@ -173,22 +147,22 @@ class ModuleSerializer { // This index is chosen so that all the DSO's parents are // allocated before this index, and children will be allocated after uint64_t dso_module_index = module_index++; - mod_group_vec_.emplace_back(std::vector()); + mod_group_vec_.emplace_back(std::vector()); // restart visiting the stack using elements in dso exportable boundary stack = std::move(dso_exportable_boundary); // Phase 1: expand the children of dso modules. while (!stack.empty()) { - runtime::ModuleNode* n = stack.back(); + ffi::ModuleObj* n = stack.back(); stack.pop_back(); - if (n->IsDSOExportable()) { + if (n->GetPropertyMask() & ffi::Module::kCompilationExportable) { mod_group_vec_[dso_module_index].emplace_back(n); mod2index_[n] = dso_module_index; } else { mod2index_[n] = module_index++; - mod_group_vec_.emplace_back(std::vector({n})); + mod_group_vec_.emplace_back(std::vector({n})); } fpush_imports_to_stack(n); } @@ -200,8 +174,8 @@ class ModuleSerializer { for (size_t parent_index = 0; parent_index < mod_group_vec_.size(); ++parent_index) { child_indices.clear(); for (const auto* m : mod_group_vec_[parent_index]) { - for (runtime::Module im : m->imports()) { - uint64_t mod_index = mod2index_.at(im.operator->()); + for (Any im : m->imports()) { + uint64_t mod_index = mod2index_.at(im.cast().operator->()); // skip cycle when dso modules are merged together if (mod_index != parent_index) { child_indices.emplace_back(mod_index); @@ -218,8 +192,8 @@ class ModuleSerializer { CHECK_LT(parent_index, child_indices[0]) << "RuntimeError: Cannot export due to multiple dso-exportables " << "that cannot be merged without creating a cycle in the import tree. " - << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->type_key() - << ", child=" << mod_group_vec_[child_indices[0]][0]->type_key(); + << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->kind() + << ", child=" << mod_group_vec_[child_indices[0]][0]->kind(); } // insert the child indices import_tree_child_indices_.insert(import_tree_child_indices_.end(), child_indices.begin(), @@ -228,16 +202,16 @@ class ModuleSerializer { } } - runtime::Module mod_; + ffi::Module mod_; // construct module to index - std::unordered_map mod2index_; + std::unordered_map mod2index_; // index -> module group - std::vector> mod_group_vec_; + std::vector> mod_group_vec_; std::vector import_tree_row_ptr_{0}; std::vector import_tree_child_indices_; }; -std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) { +std::string SerializeModuleToBytes(const ffi::Module& mod, bool export_dso) { std::string bin; dmlc::MemoryStringStream ms(&bin); dmlc::Stream* stream = &ms; @@ -247,16 +221,18 @@ std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) return bin; } -runtime::Module DeserializeModuleFromBytes(std::string blob) { +ffi::Module DeserializeModuleFromBytes(std::string blob) { dmlc::MemoryStringStream ms(&blob); dmlc::Stream* stream = &ms; - uint64_t size; - ICHECK(stream->Read(&size)); - std::vector modules; + std::vector modules; std::vector import_tree_row_ptr; std::vector import_tree_child_indices; + stream->Read(&import_tree_row_ptr); + stream->Read(&import_tree_child_indices); + + uint64_t size = import_tree_row_ptr.size() - 1; for (uint64_t i = 0; i < size; ++i) { std::string tkey; ICHECK(stream->Read(&tkey)); @@ -267,29 +243,32 @@ runtime::Module DeserializeModuleFromBytes(std::string blob) { ICHECK(stream->Read(&import_tree_row_ptr)); ICHECK(stream->Read(&import_tree_child_indices)); } else { - auto m = runtime::LoadModuleFromBinary(tkey, stream); + std::string bytes; + ICHECK(stream->Read(&bytes)); + auto loader = ffi::Function::GetGlobal("ffi.Module.load_from_bytes." + tkey); + ICHECK(loader.has_value()) << "ffi.Module.load_from_bytes." << tkey << " is not enabled"; + auto m = (*loader)(ffi::Bytes(bytes)).cast(); modules.emplace_back(m); } } for (size_t i = 0; i < modules.size(); ++i) { for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { - auto module_import_addr = runtime::ModuleInternal::GetImportsAddr(modules[i].operator->()); auto child_index = import_tree_child_indices[j]; ICHECK(child_index < modules.size()); - module_import_addr->emplace_back(modules[child_index]); + modules[i]->ImportModule(modules[child_index]); } } ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; // invariance: root module is always at location 0. // The module order is collected via DFS - runtime::Module root_mod = modules[0]; + ffi::Module root_mod = modules[0]; return root_mod; } -std::string PackImportsToBytes(const runtime::Module& mod) { - std::string bin = SerializeModuleToBytes(mod); +std::string PackImportsToBytes(const ffi::Module& mod) { + std::string bin = SerializeModuleToBytes(mod, /*export_dso*/ true); uint64_t nbytes = bin.length(); std::string header; @@ -299,14 +278,14 @@ std::string PackImportsToBytes(const runtime::Module& mod) { return header + bin; } -std::string PackImportsToC(const runtime::Module& mod, bool system_lib, +std::string PackImportsToC(const ffi::Module& mod, bool system_lib, const std::string& c_symbol_prefix) { if (c_symbol_prefix.length() != 0) { CHECK(system_lib) << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; } - std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin; + std::string mdev_blob_name = c_symbol_prefix + ffi::symbol::tvm_ffi_library_bin; std::string blob = PackImportsToBytes(mod); // translate to C program @@ -332,10 +311,10 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib, } os << "\n};\n"; if (system_lib) { - os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n"; + os << "extern int TVMFFIEnvRegisterSystemLibSymbol(const char*, void*);\n"; os << "static int " << mdev_blob_name << "_reg_ = " - << "TVMBackendRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)" - << mdev_blob_name << ");\n"; + << "TVMFFIEnvRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)" << mdev_blob_name + << ");\n"; } os << "#ifdef __cplusplus\n" << "}\n" @@ -343,9 +322,9 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib, return os.str(); } -runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, - const std::string& llvm_target_string, - const std::string& c_symbol_prefix) { +ffi::Module PackImportsToLLVM(const ffi::Module& mod, bool system_lib, + const std::string& llvm_target_string, + const std::string& c_symbol_prefix) { if (c_symbol_prefix.length() != 0) { CHECK(system_lib) << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; @@ -359,7 +338,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, const auto codegen_f = tvm::ffi::Function::GetGlobal(codegen_f_name); ICHECK(codegen_f.has_value()) << "codegen.codegen_blob is not presented."; return (*codegen_f)(ffi::Bytes(blob), system_lib, llvm_target_string, c_symbol_prefix) - .cast(); + .cast(); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -372,9 +351,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.ModuleImportsBlobName", - []() -> std::string { return runtime::symbol::tvm_ffi_library_bin; }) + []() -> std::string { return ffi::symbol::tvm_ffi_library_bin; }) .def("runtime.ModulePackImportsToNDArray", - [](const runtime::Module& mod) { + [](const ffi::Module& mod) { std::string buffer = PackImportsToBytes(mod); ffi::Shape::index_type size = buffer.size(); DLDataType uchar; diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 0cf218054320..9439af440b82 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -263,7 +263,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } }; -runtime::Module BuildAMDGPU(IRModule mod, Target target) { +ffi::Module BuildAMDGPU(IRModule mod, Target target) { LLVMInstance llvm_instance; With llvm_target(llvm_instance, target); diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 3d48f57513d0..fc2acfddfb81 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -77,7 +77,7 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l llvm_target->SetTargetMetadata(module.get()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); - std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin; + std::string mdev_blob_name = c_symbol_prefix + ffi::symbol::tvm_ffi_library_bin; auto* tvm_ffi_library_bin = new llvm::GlobalVariable( *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value, @@ -151,11 +151,11 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, llvm::Twine("__cxx_global_var_init"), module.get()); - // Create TVMBackendRegisterSystemLibSymbol function + // Create TVMFFIEnvRegisterSystemLibSymbol function llvm::Function* tvm_backend_fn = llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), llvm::GlobalValue::ExternalLinkage, - llvm::Twine("TVMBackendRegisterSystemLibSymbol"), module.get()); + llvm::Twine("TVMFFIEnvRegisterSystemLibSymbol"), module.get()); // Set necessary fn sections auto get_static_init_section_specifier = [&triple]() -> std::string { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6271d4edbe30..eebbd5b64fd4 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -146,10 +146,10 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, if (system_lib_prefix_.has_value() && !target_c_runtime) { // We will need this in environment for backward registration. // Defined in include/tvm/runtime/c_backend_api.h: - // int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); + // int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr); f_tvm_register_system_symbol_ = llvm::Function::Create( llvm::FunctionType::get(t_int_, {llvmGetPointerTo(t_char_, 0), t_void_p_}, false), - llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get()); + llvm::Function::ExternalLinkage, "TVMFFIEnvRegisterSystemLibSymbol", module_.get()); } else { f_tvm_register_system_symbol_ = nullptr; } @@ -236,11 +236,11 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { // Create wrapper function llvm::Function* wrapper_func = llvm::Function::Create(target_func->getFunctionType(), llvm::Function::WeakAnyLinkage, - runtime::symbol::tvm_ffi_main, module_.get()); + ffi::symbol::tvm_ffi_main, module_.get()); // Set attributes (Windows comdat, DLL export, etc.) if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { - llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main); + llvm::Comdat* comdat = module_->getOrInsertComdat(ffi::symbol::tvm_ffi_main); comdat->setSelectionKind(llvm::Comdat::Any); wrapper_func->setComdat(comdat); } @@ -454,8 +454,7 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { } void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { - std::string ctx_symbol = - system_lib_prefix_.value_or("") + tvm::runtime::symbol::tvm_ffi_library_ctx; + std::string ctx_symbol = system_lib_prefix_.value_or("") + ffi::symbol::tvm_ffi_library_ctx; // Module context gv_mod_ctx_ = InitContextPtr(t_void_p_, ctx_symbol); // Register back the locations. diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 6f90da3d8aea..67fccd8b073a 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -440,7 +440,7 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } } // namespace -runtime::Module BuildHexagon(IRModule mod, Target target) { +ffi::Module BuildHexagon(IRModule mod, Target target) { LLVMInstance llvm_instance; With llvm_target(llvm_instance, target); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a6b70ad39a32..a1c967e644cb 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -322,7 +322,7 @@ int GetCUDAComputeVersion(const Target& target) { return std::stoi(sm_version.substr(3)); } -runtime::Module BuildNVPTX(IRModule mod, Target target) { +ffi::Module BuildNVPTX(IRModule mod, Target target) { LLVMInstance llvm_instance; With llvm_target(llvm_instance, target); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index dd9622999bd2..f90729a45f06 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -57,11 +57,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include @@ -77,7 +77,6 @@ #include #include "../../runtime/file_utils.h" -#include "../../runtime/library_module.h" #include "codegen_blob.h" #include "codegen_cpu.h" #include "codegen_llvm.h" @@ -90,29 +89,29 @@ using ffi::Any; using ffi::Function; using ffi::PackedArgs; -class LLVMModuleNode final : public runtime::ModuleNode { +class LLVMModuleNode final : public ffi::ModuleObj { public: ~LLVMModuleNode(); - const char* type_key() const final { return "llvm"; } + const char* kind() const final { return "llvm"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; /*! \brief Get the property of the runtime module .*/ // TODO(tvm-team): Make it serializable int GetPropertyMask() const override { - return runtime::ModulePropertyMask::kRunnable | runtime::ModulePropertyMask::kDSOExportable; + return ffi::Module::kRunnable | ffi::Module::kCompilationExportable; } - void SaveToFile(const String& file_name, const String& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - String GetSource(const String& format) final; + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; + String InspectSource(const String& format) const final; void Init(const IRModule& mod, const Target& target); void Init(std::unique_ptr module, std::unique_ptr llvm_instance); void LoadIR(const std::string& file_name); - bool ImplementsFunction(const String& name, bool query_imports) final; + bool ImplementsFunction(const String& name) final; void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; } @@ -156,8 +155,8 @@ LLVMModuleNode::~LLVMModuleNode() { module_owning_ptr_.reset(); } -ffi::Function LLVMModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional LLVMModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "__tvm_is_system_module") { bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); return ffi::Function([flag](ffi::PackedArgs args, ffi::Any* rv) { *rv = flag; }); @@ -174,9 +173,9 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name, return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->function_names_; }); } else if (name == "get_symbol") { - return ffi::Function(nullptr); + return std::nullopt; } else if (name == "get_const_vars") { - return ffi::Function(nullptr); + return std::nullopt; } else if (name == "_get_target_string") { std::string target_string = LLVMTarget::GetTargetMetadata(*module_); return ffi::Function( @@ -191,8 +190,13 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name, TVMFFISafeCallType faddr; With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); - if (faddr == nullptr) return ffi::Function(); - return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self); + if (faddr == nullptr) return std::nullopt; + ffi::Module self_strong_ref = GetRef(this); + return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) { + TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); + TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), + args.size(), reinterpret_cast(rv))); + }); } namespace { @@ -231,7 +235,7 @@ bool LLVMAddPassesToEmitFile(llvm::TargetMachine* tm, llvm::legacy::PassManager* } // namespace -void LLVMModuleNode::SaveToFile(const String& file_name_str, const String& format) { +void LLVMModuleNode::WriteToFile(const String& file_name_str, const String& format) const { // CHECK(imports_.empty()) << "SaveToFile does not handle imported modules"; std::string file_name = file_name_str; std::string fmt = runtime::GetFileFormat(file_name, format); @@ -266,11 +270,11 @@ void LLVMModuleNode::SaveToFile(const String& file_name_str, const String& forma dest.close(); } -void LLVMModuleNode::SaveToBinary(dmlc::Stream* stream) { - LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; +ffi::Bytes LLVMModuleNode::SaveToBytes() const { + LOG(FATAL) << "LLVMModule: SaveToBytes not supported"; } -String LLVMModuleNode::GetSource(const String& format) { +String LLVMModuleNode::InspectSource(const String& format) const { std::string fmt = runtime::GetFileFormat("", format); std::string type_str; llvm::SmallString<256> str; @@ -381,7 +385,7 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) { Init(std::move(module), std::move(llvm_instance)); } -bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) { +bool LLVMModuleNode::ImplementsFunction(const String& name) { return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); } @@ -434,12 +438,16 @@ void LLVMModuleNode::InitMCJIT() { // run ctors mcjit_ee_->runStaticConstructorsDestructors(false); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } - runtime::InitContextFunctions( - [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); + + ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { + *ctx_addr = symbol; + } + }); // There is a problem when a JITed function contains a call to a runtime function. // The runtime function (e.g. __truncsfhf2) may not be resolved, and calling it will // lead to a runtime crash. @@ -575,12 +583,15 @@ void LLVMModuleNode::InitORCJIT() { err = ctorRunner.run(); ICHECK(!err) << llvm::toString(std::move(err)); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } - runtime::InitContextFunctions( - [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); + ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { + *ctx_addr = symbol; + } + }); } bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { @@ -638,13 +649,13 @@ static void LLVMReflectionRegister() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.llvm", - [](IRModule mod, Target target) -> runtime::Module { + [](IRModule mod, Target target) -> ffi::Module { auto n = make_object(); n->Init(mod, target); - return runtime::Module(n); + return ffi::Module(n); }) .def("codegen.LLVMModuleCreate", - [](std::string target_str, std::string module_name) -> runtime::Module { + [](std::string target_str, std::string module_name) -> ffi::Module { auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, target_str); auto n = make_object(); @@ -659,7 +670,7 @@ static void LLVMReflectionRegister() { module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); n->Init(std::move(module), std::move(llvm_instance)); n->SetJITEngine(llvm_target->GetJITEngine()); - return runtime::Module(n); + return ffi::Module(n); }) .def("target.llvm_lookup_intrinsic_id", [](std::string name) -> int64_t { @@ -765,12 +776,12 @@ static void LLVMReflectionRegister() { return llvm_target.TargetHasCPUFeature(feature); }) .def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; }) - .def("runtime.module.loadfile_ll", - [](std::string filename, std::string fmt) -> runtime::Module { + .def("ffi.Module.load_from_file.ll", + [](std::string filename, std::string fmt) -> ffi::Module { auto n = make_object(); n->SetJITEngine("orcjit"); n->LoadIR(filename); - return runtime::Module(n); + return ffi::Module(n); }) .def("codegen.llvm_target_enabled", [](std::string target_str) -> bool { @@ -781,7 +792,7 @@ static void LLVMReflectionRegister() { }) .def("codegen.codegen_blob", [](std::string data, bool system_lib, std::string llvm_target_string, - std::string c_symbol_prefix) -> runtime::Module { + std::string c_symbol_prefix) -> ffi::Module { auto n = make_object(); auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, llvm_target_string); @@ -789,7 +800,7 @@ static void LLVMReflectionRegister() { CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix); n->Init(std::move(blob), std::move(llvm_instance)); n->SetJITEngine(llvm_target->GetJITEngine()); - return runtime::Module(n); + return ffi::Module(n); }); } diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 2070d7da3e0c..75897e539f85 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -35,7 +35,7 @@ namespace tvm { namespace codegen { -runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target); +ffi::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target); } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index 893eb67a268f..c0b494ff619c 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -24,11 +24,11 @@ namespace tvm { namespace runtime { -Module CUDAModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +ffi::Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { LOG(FATAL) << "CUDA is not enabled"; - return Module(); + TVM_FFI_UNREACHABLE(); } } // namespace runtime } // namespace tvm diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index da8896bf4826..6072a483877c 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -126,7 +126,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { return ptx; } -runtime::Module BuildCUDA(IRModule mod, Target target) { +ffi::Module BuildCUDA(IRModule mod, Target target) { bool output_ssa = false; CodeGenCUDA cg; cg.Init(output_ssa); diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index 2ce5cdb51f5d..696ca6399560 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -22,9 +22,10 @@ namespace tvm { namespace runtime { -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str) { +ffi::Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str) { LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); } diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index 555aa5002f98..4200a35cbb58 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -26,9 +26,9 @@ namespace tvm { namespace runtime { -Module MetalModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string fmt, - std::string source) { +ffi::Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal"); } diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 9e368d5599cf..797aa3ef8d38 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -26,16 +26,17 @@ namespace tvm { namespace runtime { -Module OpenCLModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string source) { return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap) { +ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { LOG(FATAL) << "OpenCLModuleCreate is called but OpenCL is not enabled."; - return Module(); + TVM_FFI_UNREACHABLE(); } } // namespace runtime diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc index 476e5a88fc6f..f161faa9f648 100644 --- a/src/target/opt/build_rocm_off.cc +++ b/src/target/opt/build_rocm_off.cc @@ -26,9 +26,9 @@ namespace tvm { namespace runtime { -Module ROCMModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string rocm_source, - std::string assembly) { +ffi::Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string rocm_source, std::string assembly) { LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; auto fget_source = [rocm_source, assembly](const std::string& format) { if (format.length() == 0) return assembly; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 020054b3e1fc..e18ba0128d6b 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,8 +22,8 @@ */ #include "codegen_c_host.h" +#include #include -#include #include #include @@ -54,7 +54,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d } void CodeGenCHost::InitGlobalContext() { - decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx << " = NULL;\n"; + decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n"; } void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } @@ -77,11 +77,11 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; - function_names_.push_back(runtime::symbol::tvm_ffi_main); + function_names_.push_back(ffi::symbol::tvm_ffi_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); PrintType(func->ret_type, stream); - stream << " " << tvm::runtime::symbol::tvm_ffi_main + stream << " " << ffi::symbol::tvm_ffi_main << "(void* self, void* args,int num_args, void* result) {\n"; stream << " return " << global_symbol.value() << "(self, args, num_args, result);\n"; stream << "}\n"; @@ -355,7 +355,7 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, << "? (" << a_id << ") : (" << b_id << "))"; } -runtime::Module BuildCHost(IRModule mod, Target target) { +ffi::Module BuildCHost(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; bool emit_fwd_func_decl = true; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ffb1737a7063..dc019c28a7a0 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -431,7 +431,7 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO os << temp.str(); } -runtime::Module BuildMetal(IRModule mod, Target target) { +ffi::Module BuildMetal(IRModule mod, Target target) { bool output_ssa = false; mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 2645423affe3..1342464665f3 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -633,7 +633,7 @@ void CodeGenOpenCL::SetTextureScope( } } -runtime::Module BuildOpenCL(IRModule mod, Target target) { +ffi::Module BuildOpenCL(IRModule mod, Target target) { #if TVM_ENABLE_SPIRV Optional device = target->GetAttr("device"); if (device && device.value() == "spirv") { diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index a416e3fcae31..f077f8c3a83b 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -140,7 +140,7 @@ class CodeGenSourceBase { * \param code The code to be viewed. * \param fmt The code. format. */ -runtime::Module SourceModuleCreate(std::string code, std::string fmt); +ffi::Module SourceModuleCreate(std::string code, std::string fmt); /*! * \brief Create a C source module for viewing and compiling GCC code. @@ -150,9 +150,9 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt); * \param const_vars. The constant variables that the c source module needs. * \return The created module. */ -runtime::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, - const Array& const_vars = {}); +ffi::Module CSourceModuleCreate(const String& code, const String& fmt, + const Array& func_names, + const Array& const_vars = {}); /*! * \brief Wrap the submodules in a metadata module. @@ -163,9 +163,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, * \param target The target that all the modules are compiled for * \return The wrapped module. */ -runtime::Module CreateMetadataModule( - const std::unordered_map& params, runtime::Module target_module, - const Array& ext_modules, Target target); +ffi::Module CreateMetadataModule(const std::unordered_map& params, + ffi::Module target_module, const Array& ext_modules, + Target target); /*! * \brief Create a source module for viewing and limited saving for device. @@ -175,7 +175,7 @@ runtime::Module CreateMetadataModule( * \param type_key The type_key of the runtime module of this source code * \param fget_source a closure to replace default get source behavior. */ -runtime::Module DeviceSourceModuleCreate( +ffi::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source = nullptr); diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index f5bfd80fee25..28d158c3c21e 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -22,6 +22,7 @@ */ #include "codegen_webgpu.h" +#include #include #include #include @@ -705,27 +706,30 @@ void CodeGenWebGPU::VisitStmt_(const WhileNode* op) { //------------------------------------------------- // WebGPUSourceModule to enable export //------------------------------------------------- -class WebGPUSourceModuleNode final : public runtime::ModuleNode { +class WebGPUSourceModuleNode final : public ffi::ModuleObj { public: explicit WebGPUSourceModuleNode(std::unordered_map smap, std::unordered_map fmap) : smap_(smap), fmap_(fmap) {} - const char* type_key() const final { return "webgpu"; } + const char* kind() const final { return "webgpu"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; - return ffi::Function(nullptr); } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmap_); stream->Write(smap_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (format == "func_info") { std::ostringstream stream; dmlc::JSONWriter(&stream).Write(fmap_); @@ -749,7 +753,7 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { //------------------------------------------------- // Build logic. //------------------------------------------------- -runtime::Module BuildWebGPU(IRModule mod, Target target) { +ffi::Module BuildWebGPU(IRModule mod, Target target) { mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; bool skip_readonly_decl = false; @@ -777,7 +781,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { } auto n = make_object(smap, fmap); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index b3c5ff311a3c..1350357d866c 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -51,40 +51,43 @@ using runtime::GetMetaFilePath; using runtime::SaveBinaryToFile; // Simulator function -class SourceModuleNode : public runtime::ModuleNode { +class SourceModuleNode : public ffi::ModuleObj { public: SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} - const char* type_key() const final { return "source"; } + const char* kind() const final { return "source"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; - return ffi::Function(); } - String GetSource(const String& format) final { return code_; } + String InspectSource(const String& format) const final { return code_; } - String GetFormat() override { return fmt_; } + Array GetWriteFormats() const override { return {fmt_}; } protected: std::string code_; std::string fmt_; }; -runtime::Module SourceModuleCreate(std::string code, std::string fmt) { +ffi::Module SourceModuleCreate(std::string code, std::string fmt) { auto n = make_object(code, fmt); - return runtime::Module(n); + return ffi::Module(n); } // Simulator function -class CSourceModuleNode : public runtime::ModuleNode { +class CSourceModuleNode : public ffi::ModuleObj { public: CSourceModuleNode(const std::string& code, const std::string& fmt, const Array& func_names, const Array& const_vars) - : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {} - const char* type_key() const final { return "c"; } + : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) { + if (fmt_.empty()) fmt_ = "c"; + } + + const char* kind() const final { return "c"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Currently c-source module is used as demonstration purposes with binary metadata module // that expects get_symbol interface. When c-source module is used as external module, it // will only contain one function. However, when its used as an internal module (e.g., target @@ -103,11 +106,14 @@ class CSourceModuleNode : public runtime::ModuleNode { } } - String GetSource(const String& format) final { return code_; } + String InspectSource(const String& format) const final { return code_; } - String GetFormat() override { return fmt_; } + Array GetWriteFormats() const override { return {fmt_}; } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(code_); stream->Write(fmt_); @@ -117,10 +123,12 @@ class CSourceModuleNode : public runtime::ModuleNode { for (auto const_var : const_vars_) const_vars.push_back(const_var); stream->Write(func_names); stream->Write(const_vars); + return ffi::Bytes(buffer); } - static runtime::Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string code, fmt; ICHECK(stream->Read(&code)) << "Loading code failed"; @@ -137,10 +145,10 @@ class CSourceModuleNode : public runtime::ModuleNode { for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var)); auto n = make_object(code, fmt, func_names, const_vars); - return runtime::Module(n); + return ffi::Module(n); } - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") { @@ -152,11 +160,10 @@ class CSourceModuleNode : public runtime::ModuleNode { } int GetPropertyMask() const override { - return runtime::ModulePropertyMask::kBinarySerializable | - runtime::ModulePropertyMask::kDSOExportable; + return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name, bool query_imports) final { + bool ImplementsFunction(const String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } @@ -167,17 +174,16 @@ class CSourceModuleNode : public runtime::ModuleNode { Array func_names_; }; -runtime::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, - const Array& const_vars) { +ffi::Module CSourceModuleCreate(const String& code, const String& fmt, + const Array& func_names, const Array& const_vars) { auto n = make_object(code.operator std::string(), fmt.operator std::string(), func_names, const_vars); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_c", CSourceModuleNode::LoadFromBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.c", CSourceModuleNode::LoadFromBytes); }); /*! @@ -197,20 +203,19 @@ class ConcreteCodegenSourceBase : public CodeGenSourceBase { }; // supports limited save without cross compile -class DeviceSourceModuleNode final : public runtime::ModuleNode { +class DeviceSourceModuleNode final : public ffi::ModuleObj { public: DeviceSourceModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source) : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; - return ffi::Function(); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (fget_source_ != nullptr) { return fget_source_(format); } else { @@ -218,11 +223,11 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const final { return type_key_.c_str(); } + const char* kind() const final { return type_key_.c_str(); } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -230,10 +235,14 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { SaveBinaryToFile(file_name, data_); } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } private: @@ -244,11 +253,12 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { std::function fget_source_; }; -runtime::Module DeviceSourceModuleCreate( - std::string data, std::string fmt, std::unordered_map fmap, - std::string type_key, std::function fget_source) { +ffi::Module DeviceSourceModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string type_key, + std::function fget_source) { auto n = make_object(data, fmt, fmap, type_key, fget_source); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 26ecffcc6bd3..bd44607a98eb 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -32,7 +32,7 @@ namespace tvm { namespace codegen { -runtime::Module BuildSPIRV(IRModule mod, Target target) { +ffi::Module BuildSPIRV(IRModule mod, Target target) { auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } diff --git a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc index b8e5b90ece7c..1097a21128e1 100644 --- a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc +++ b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc @@ -184,9 +184,8 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { module.InstallKernel(m_workspace, m_workspace->GetThreadEntry(), m_kernelNames[i], e); } Timestamp comp_end = std::chrono::high_resolution_clock::now(); - auto get_pre_compiled_f = module.GetFunction("opencl.GetPreCompiledPrograms", - tvm::ffi::GetObjectPtr(&module)); - bytes = get_pre_compiled_f().cast(); + auto get_pre_compiled_f = module.GetFunction("opencl.GetPreCompiledPrograms").value(); + bytes = get_pre_compiled_f().cast(); std::chrono::duration duration = std::chrono::duration_cast(comp_end - comp_start); compileFromSourceTimeMS = duration.count() * 1e-6; @@ -195,8 +194,7 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { { OpenCLModuleNode module(m_dataSrc, "cl", m_fmap, std::string()); module.Init(); - module.GetFunction("opencl.SetPreCompiledPrograms", - GetObjectPtr(&module))(tvm::String(bytes)); + module.GetFunction("opencl.SetPreCompiledPrograms").value()(tvm::String(bytes)); Timestamp comp_start = std::chrono::high_resolution_clock::now(); for (size_t i = 0; i < m_kernelNames.size(); ++i) { OpenCLModuleNode::KTRefEntry e = {i, 1}; diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 2c8f185d8ecd..90ad1d65c7aa 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -48,7 +48,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mul instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"mul\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -80,7 +80,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and add instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"add\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -112,7 +112,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and sub instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"sub\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -145,7 +145,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C, D])) # Verify we see SVE load instructions and either mad or mla instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"mad|mla\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -177,7 +177,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) compare = re.findall( r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -213,7 +213,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) compare = re.findall( r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -249,7 +249,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and div instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"div\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -280,7 +280,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mls instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"mls\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -312,7 +312,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"cm(p)?eq\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -344,7 +344,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"cm(p)?(gt|ne)\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -375,7 +375,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and orr instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"orr\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -406,7 +406,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and and instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"and\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -436,7 +436,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, C])) # Verify we see SVE load instructions and eor instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"eor\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -471,7 +471,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see gather instructions in the assembly - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) assert len(loads) > 0 @@ -503,7 +503,7 @@ def test_vscale_range_function_attribute(mattr, expect_attr): f = tvm.tir.build(te.create_prim_func([A, C])) # Check if the vscale_range() attribute exists - ll = f.get_source("ll") + ll = f.inspect_source("ll") attr = re.findall(rf".*vscale_range\(\d+,\d+\)*.", ll) if expect_attr: diff --git a/tests/python/codegen/test_target_codegen_arm.py b/tests/python/codegen/test_target_codegen_arm.py index d22e528770b3..e6d0c70f8734 100644 --- a/tests/python/codegen/test_target_codegen_arm.py +++ b/tests/python/codegen/test_target_codegen_arm.py @@ -30,7 +30,7 @@ def check_correct_assembly(type, elements, counts): sch.vectorize(sch.get_loops("B")[0]) f = tvm.tir.build(sch.mod, target=target) # Verify we see the correct number of vpaddl and vcnt instructions in the assembly - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") matches = re.findall("vpaddl", assembly) assert len(matches) == counts matches = re.findall("vcnt", assembly) @@ -61,7 +61,7 @@ def check_correct_assembly(N): f = tvm.tir.build(sch.mod, target=target) # Verify we see the correct number of vmlal.s16 instructions - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") matches = re.findall("vmlal.s16", assembly) assert len(matches) == N // 4 @@ -85,7 +85,7 @@ def check_broadcast_correct_assembly(N): f = tvm.tir.build(sch.mod, target=target) # Verify we see the correct number of vmlal.s16 instructions - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") matches = re.findall("vmlal.s16", assembly) assert len(matches) == N // 4 diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index af94cae71f1c..3c80cfbeb0b4 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ b/tests/python/codegen/test_target_codegen_c_host.py @@ -192,7 +192,7 @@ def subroutine(A_data: T.handle("float32")): "subroutine" not in func_names ), "Internal function should not be listed in available functions." - source = built.get_source() + source = built.inspect_source() assert ( source.count("main(void*") == 2 ), "Expected two occurrences, for forward-declaration and definition" diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py b/tests/python/codegen/test_target_codegen_cross_llvm.py index c126e531090e..9ae516c7de30 100644 --- a/tests/python/codegen/test_target_codegen_cross_llvm.py +++ b/tests/python/codegen/test_target_codegen_cross_llvm.py @@ -51,7 +51,7 @@ def build_i386(): target = "llvm -mtriple=i386-pc-linux-gnu" f = tvm.tir.build(sch.mod, target=target) path = temp.relpath("myadd.o") - f.save(path) + f.write_to_file(path) verify_elf(path, 0x03) def build_arm(): @@ -62,10 +62,10 @@ def build_arm(): temp = utils.tempdir() f = tvm.tir.build(sch.mod, target=target) path = temp.relpath("myadd.o") - f.save(path) + f.write_to_file(path) verify_elf(path, 0x28) asm_path = temp.relpath("myadd.asm") - f.save(asm_path) + f.write_to_file(asm_path) # Do a RPC verification, launch kernel on Arm Board if available. host = os.environ.get("TVM_RPC_ARM_HOST", None) remote = None diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index a304cb1e41c7..fb9c47410fea 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -663,7 +663,7 @@ def build(A, C, N, C_N): f = tvm.tir.build(sch.mod, target="cuda") - kernel_source = f.imported_modules[0].get_source() + kernel_source = f.imports[0].inspect_source() dev = tvm.cuda() a_data = np.arange(0, N).astype(A.dtype) a = tvm.nd.array(a_data, dev) @@ -774,7 +774,7 @@ def main(A_ptr: T.handle): A[0] = ((float)(*(double *)(&(A_map)))); } }""".strip() - in mod.mod.imported_modules[0].get_source() + in mod.mod.imports[0].inspect_source() ) @@ -797,7 +797,7 @@ def main( C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) lib = tvm.compile(Module, target="cuda") - cuda_code = lib.mod.imported_modules[0].get_source() + cuda_code = lib.mod.imports[0].inspect_source() assert 'extern "C" __device__ float add(float a, float b) {\n return (a + b);\n}' in cuda_code @@ -827,7 +827,7 @@ def main( # in order to avoid checking a function is host or device based on the "cpu" substring. target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c") lib = tvm.compile(Module, target=target) - cuda_code = lib.mod.imported_modules[0].get_source() + cuda_code = lib.mod.imports[0].inspect_source() assert 'extern "C" __device__ int add(int a, int b) {\n return (a + b);\n}' in cuda_code # Run a simple test @@ -854,7 +854,7 @@ def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): B[bx, tx] = A[bx, tx] lib = tvm.compile(Module, target="cuda") - cuda_code = lib.mod.imported_modules[0].get_source() + cuda_code = lib.mod.imports[0].inspect_source() assert "return;" in cuda_code diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index aa9080a48882..c0b6130bcb80 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -71,7 +71,7 @@ def add( target = "cuda" fadd = tvm.tir.build(sch.mod, target=target) - cuda_src = fadd.imported_modules[0].get_source() + cuda_src = fadd.imports[0].inspect_source() assert nv_dtype in cuda_src, f"{nv_dtype} datatype not found in generated CUDA" dev = tvm.device(target, 0) @@ -190,7 +190,7 @@ def add( target = "cuda" fadd = tvm.tir.build(sch.mod, target=target) - cuda_src = fadd.imported_modules[0].get_source() + cuda_src = fadd.imports[0].inspect_source() dev = tvm.device(target, 0) if "x" in native_dtype: @@ -710,7 +710,7 @@ def print_cuda(target, mod, name=None): if name: mod = mod[name] f = tvm.tir.build(mod, target=target) - cuda_src = f.imported_modules[0].get_source() + cuda_src = f.imports[0].inspect_source() print(cuda_src) print_cuda(target, dequant_mod, name="dequant") diff --git a/tests/python/codegen/test_target_codegen_hexagon.py b/tests/python/codegen/test_target_codegen_hexagon.py index c0665ce316ad..f14005ad9d0b 100644 --- a/tests/python/codegen/test_target_codegen_hexagon.py +++ b/tests/python/codegen/test_target_codegen_hexagon.py @@ -46,7 +46,7 @@ def check_add(): C = tvm.te.compute((128,), lambda i: A[i] + B[i], name="C") mod = tvm.IRModule.from_expr(te.create_prim_func([C, A, B])) hexm = tvm.compile(mod, target=tvm.target.Target(target, target)) - asm = hexm.get_source("s") + asm = hexm.inspect_source("s") vadds = re.findall(r"v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)", asm) assert vadds # Check that it's non-empty @@ -61,7 +61,7 @@ def test_llvm_target_features(): C = tvm.te.compute((128,), lambda i: A[i] + 1, name="C") mod = tvm.IRModule.from_expr(te.create_prim_func([C, A]).with_attr("global_symbol", "add_one")) m = tvm.compile(mod, target=tvm.target.Target(target, target)) - llvm_ir = m.get_source("ll") + llvm_ir = m.inspect_source("ll") # Make sure we find +hvx-length128b in "attributes". fs = re.findall(r"attributes.*\+hvx-length128b", llvm_ir) assert fs # Check that it's non-empty diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 15c030aeacf2..953adf78b342 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -454,7 +454,7 @@ def test_alignment(): # Build with name f = tvm.tir.build(sch.mod, target="llvm") - lines = f.get_source().split("\n") + lines = f.inspect_source().split("\n") # Check alignment on load/store. for l in lines: @@ -702,7 +702,7 @@ def check_llvm_object(): m = tvm.compile(mod, target="llvm") temp = utils.tempdir() o_path = temp.relpath("temp.o") - m.save(o_path) + m.write_to_file(o_path) import shutil import subprocess import sys @@ -738,7 +738,7 @@ def check_llvm_ir(): } ) m = tvm.tir.build(mod, target="llvm -mtriple=aarch64-linux-gnu") - ll = m.get_source("ll") + ll = m.inspect_source("ll") # On non-Darwin OS, don't explicitly specify DWARF version. import re @@ -748,7 +748,7 @@ def check_llvm_ir(): # Try Darwin, require DWARF-2 m = tvm.tir.build(mod, target="llvm -mtriple=x86_64-apple-darwin-macho") - ll = m.get_source("ll") + ll = m.inspect_source("ll") assert re.search(r"""i32 4, !"Dwarf Version", i32 2""", ll) assert re.search(r"""llvm.dbg.value""", ll) @@ -802,9 +802,9 @@ def test_llvm_crt_static_lib(): mod.with_attr("system_lib_prefix", ""), target=tvm.target.Target("llvm"), ) - module.get_source() + module.inspect_source() with utils.tempdir() as temp: - module.save(temp.relpath("test.o")) + module.write_to_file(temp.relpath("test.o")) @tvm.testing.requires_llvm @@ -829,7 +829,7 @@ def make_call_extern(caller, callee): "Kirby": make_call_extern("Kirby", "Fred"), } mod = tvm.IRModule(functions=functions) - ir_text = tvm.tir.build(mod, target="llvm").get_source("ll") + ir_text = tvm.tir.build(mod, target="llvm").inspect_source("ll") # Skip functions whose names start with _. matches = re.findall(r"^define[^@]*@([a-zA-Z][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) assert matches == sorted(matches) @@ -930,7 +930,7 @@ def test_llvm_target_attributes(): target = tvm.target.Target(target_llvm, host=target_llvm) module = tvm.tir.build(sch.mod, target=target) - llvm_ir = module.get_source() + llvm_ir = module.inspect_source() llvm_ir_lines = llvm_ir.split("\n") attribute_definitions = dict() diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py index 7ca3083dd5e3..8930159481cb 100644 --- a/tests/python/codegen/test_target_codegen_llvm_vla.py +++ b/tests/python/codegen/test_target_codegen_llvm_vla.py @@ -46,7 +46,7 @@ def main(A: T.Buffer((5,), "int32")): with tvm.target.Target(target): build_mod = tvm.tir.build(main) - llvm = build_mod.get_source() + llvm = build_mod.inspect_source() assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." @@ -68,7 +68,7 @@ def my_func(a: T.handle, b: T.handle): with tvm.target.Target(target): mod = tvm.tir.build(my_func) - llvm = mod.get_source("ll") + llvm = mod.inspect_source("ll") assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." @@ -90,7 +90,7 @@ def my_func(a: T.handle): with tvm.target.Target(target): mod = tvm.tir.build(my_func) - llvm = mod.get_source("ll") + llvm = mod.inspect_source("ll") assert re.findall( r"shufflevector \( insertelement \(", llvm ), "No scalable broadcast in generated LLVM." @@ -114,7 +114,7 @@ def before(a: T.handle): with tvm.target.Target(target): out = tvm.tir.build(before) - ll = out.get_source("ll") + ll = out.inspect_source("ll") assert "get.active.lane.mask" in ll @@ -139,7 +139,7 @@ def before(a: T.handle, b: T.handle): with tvm.target.Target(target): out = tvm.tir.build(before) - ll = out.get_source("ll") + ll = out.inspect_source("ll") assert "get.active.lane.mask" in ll assert "llvm.masked.load" in ll assert "llvm.masked.store" in ll diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index 2d669081e347..6b413d532371 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -187,7 +187,7 @@ def compile_metal(src, target): mod = tvm.IRModule({"main": func}) f = tvm.compile(mod, target="metal") - src: str = f.imported_modules[0].get_source() + src: str = f.imports[0].inspect_source() occurrences = src.count("struct func_kernel_args_t") assert occurrences == 1, occurrences diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index cbdb60477b06..4eb96747bcee 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -140,7 +140,7 @@ def check_erf(dev, n, dtype): sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - source_str = fun.imported_modules[0].get_source() + source_str = fun.imports[0].inspect_source() matches = re.findall("erf", source_str) error_matches = re.findall("erff", source_str) assert len(matches) == 1 and len(error_matches) == 0 @@ -180,7 +180,7 @@ def check_type_casting(ctx, n, dtype): fun = tvm.tir.build(sch.mod, target=target) c = tvm.nd.empty((n,), dtype, ctx) - assembly = fun.imported_modules[0].get_source() + assembly = fun.imports[0].inspect_source() lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))" rcond = "(convert_uint4(((((int4)(((convert_int(get_local_id(0))))+(1*0), ((convert_int(get_local_id(0))))+(1*1), ((convert_int(get_local_id(0))))+(1*2), ((convert_int(get_local_id(0))))+(1*3))) % ((int4)(3, 3, 3, 3))) == ((int4)(1, 1, 1, 1))))))))" pattern_cond = "({} && {})".format(lcond, rcond) @@ -211,7 +211,7 @@ def _check(target, n, dtype): sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - assembly = fun.imported_modules[0].get_source() + assembly = fun.imports[0].inspect_source() if "adreno" in target: pattern = "convert_float" else: diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index b06aeb4ced06..1a30ab203f04 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -36,7 +36,7 @@ def load_vec(A: T.Buffer((N,), "int8")): f = tvm.tir.build(load_vec, target) # Check RVV `vsetvli` prensence - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") if target_has_features("v"): assert "vsetvli" in assembly else: diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index 89acf598d6e3..a523ae037794 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -86,7 +86,7 @@ def test_vector_comparison(target, dev, dtype): # Verify we generate the boolx4 type declaration and the OpSelect # v4{float,half,int} instruction - assembly = f.imported_modules[0].get_source() + assembly = f.imports[0].inspect_source() matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly) assert len(matches) == 1 matches = re.findall("OpSelect %v4.*", assembly) diff --git a/tests/python/codegen/test_target_codegen_x86.py b/tests/python/codegen/test_target_codegen_x86.py index 51d648f2c4a9..8664d5ceb732 100644 --- a/tests/python/codegen/test_target_codegen_x86.py +++ b/tests/python/codegen/test_target_codegen_x86.py @@ -41,7 +41,7 @@ def fp16_to_fp32(target, width, match=None, not_match=None): sch.vectorize(sch.get_loops("B")[1]) f = tvm.tir.build(sch.mod, target=target) - assembly = f.get_source("asm").splitlines() + assembly = f.inspect_source("asm").splitlines() if match: matches = [l for l in assembly if re.search(match, l)] assert matches diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py index 1b75dd5bc915..d3adbc12c922 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py @@ -218,7 +218,7 @@ def _benchmark_hexagon_elementwise_add_kernel( # Create an actual Hexagon-native shared object file, initially stored on the # host's file system... host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") - built_module.save(host_dso_binary_path) + built_module.write_to_file(host_dso_binary_path) print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") # Upload the .so to the Android device's file system (or wherever is appropriate diff --git a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py index b6b1f8fa73d6..7d556e8bae73 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py @@ -261,7 +261,7 @@ def test_maxpool2d_nhwc( # Save a local copy of the Hexagon object code (in the form of a .so file) # to allow post-mortem inspection. host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") - built_module.save(host_dso_binary_path) + built_module.write_to_file(host_dso_binary_path) print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") hexagon_mod = hexagon_session.load_module(built_module) diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py index c6196ce42517..b873f606e619 100644 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -94,9 +94,9 @@ def test_sigmoid( with tvm.transform.PassContext(opt_level=3): runtime_module = tvm.compile(tir_s.mod, target=get_hexagon_target("v69")) - assert "hvx_sigmoid" in runtime_module.get_source("asm") - assert "vmin" in runtime_module.get_source("asm") - assert "vmax" in runtime_module.get_source("asm") + assert "hvx_sigmoid" in runtime_module.inspect_source("asm") + assert "vmin" in runtime_module.inspect_source("asm") + assert "vmax" in runtime_module.inspect_source("asm") mod = hexagon_session.load_module(runtime_module) mod(input_data, output_data) diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index 2795f5630163..eec48a972ea2 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -50,7 +50,7 @@ def test_vtcm_building(): sch = get_scale_by_two_schedule() target = get_hexagon_target("v68") built = tvm.compile(sch.mod, target=target) - assert "global.vtcm" in built.get_source("asm") + assert "global.vtcm" in built.inspect_source("asm") @tvm.testing.requires_hexagon diff --git a/tests/python/ir/test_roundtrip_runtime_module.py b/tests/python/ir/test_roundtrip_runtime_module.py index 3723cc6c112c..e6fca273a025 100644 --- a/tests/python/ir/test_roundtrip_runtime_module.py +++ b/tests/python/ir/test_roundtrip_runtime_module.py @@ -25,11 +25,11 @@ def test_csource_module(): mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], []) - assert mod.type_key == "c" - assert mod.is_binary_serializable + assert mod.kind == "c" + assert mod.is_binary_serializable() new_mod = tvm.ir.load_json(tvm.ir.save_json(mod)) - assert new_mod.type_key == "c" - assert new_mod.is_binary_serializable + assert new_mod.kind == "c" + assert new_mod.is_binary_serializable() if __name__ == "__main__": diff --git a/tests/python/relax/backend/clml/test_clml_codegen.py b/tests/python/relax/backend/clml/test_clml_codegen.py index b03d6afa1c9b..29448774d69b 100644 --- a/tests/python/relax/backend/clml/test_clml_codegen.py +++ b/tests/python/relax/backend/clml/test_clml_codegen.py @@ -52,7 +52,7 @@ def compare_codegen(clml_mod, clml_codegen): - source = clml_mod.attrs["external_mods"][0].get_source() + source = clml_mod.attrs["external_mods"][0].inspect_source() codegen = json.loads(source)["nodes"] for node in range(len(codegen)): if codegen[node]["op"] == "input" or codegen[node]["op"] == "const": diff --git a/tests/python/relax/test_vm_instrument.py b/tests/python/relax/test_vm_instrument.py index c5f293114f3c..8c4d728da18b 100644 --- a/tests/python/relax/test_vm_instrument.py +++ b/tests/python/relax/test_vm_instrument.py @@ -93,7 +93,7 @@ def test_lib_comparator(): ex = get_exec_int32(data_np.shape) vm = relax.VirtualMachine(ex, tvm.cpu()) # compare against library module - cmp = LibCompareVMInstrument(vm.module.imported_modules[0], tvm.cpu(), verbose=False) + cmp = LibCompareVMInstrument(vm.module.imports[0], tvm.cpu(), verbose=False) vm.set_instrument(cmp) vm["main"](tvm.nd.array(data_np)) diff --git a/tests/python/runtime/test_runtime_module_export.py b/tests/python/runtime/test_runtime_module_export.py index 8897837a26af..0db1fa93dc2a 100644 --- a/tests/python/runtime/test_runtime_module_export.py +++ b/tests/python/runtime/test_runtime_module_export.py @@ -40,16 +40,16 @@ def test_import_static_library(): assert mod0.implements_function("myadd0") assert mod1.implements_function("myadd1") - assert mod1.is_dso_exportable + assert mod1.is_compilation_exportable() # mod1 is currently an 'llvm' module. # Save and reload it as a vanilla 'static_library'. temp = utils.tempdir() mod1_o_path = temp.relpath("mod1.o") - mod1.save(mod1_o_path) + mod1.write_to_file(mod1_o_path) mod1_o = tvm.runtime.load_static_library(mod1_o_path, ["myadd1"]) assert mod1_o.implements_function("myadd1") - assert mod1_o.is_dso_exportable + assert mod1_o.is_compilation_exportable() # Import mod1 as a static library into mod0 and compile to its own DSO. mod0.import_module(mod1_o) @@ -58,13 +58,13 @@ def test_import_static_library(): # The imported mod1 is statically linked into mod0. loaded_lib = tvm.runtime.load_module(mod0_dso_path) - assert loaded_lib.type_key == "library" - assert len(loaded_lib.imported_modules) == 0 + assert loaded_lib.kind == "library" + assert len(loaded_lib.imports) == 0 assert loaded_lib.implements_function("myadd0") assert loaded_lib.get_function("myadd0") assert loaded_lib.implements_function("myadd1") assert loaded_lib.get_function("myadd1") - assert not loaded_lib.is_dso_exportable + assert not loaded_lib.is_compilation_exportable() if __name__ == "__main__": diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py index 79b95256f9fa..d22d40f6f2b1 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -64,7 +64,7 @@ def save_object(names): ) m = tvm.tir.build(mod, target=target) for name in names: - m.save(name) + m.write_to_file(name) path_obj = temp.relpath("test.o") path_ll = temp.relpath("test.ll") @@ -169,8 +169,8 @@ def check_llvm(): path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") - fadd1.save(path1) - fadd2.save(path2) + fadd1.write_to_file(path1) + fadd2.write_to_file(path2) # create shared library with multiple functions cc.create_shared(path_dso, [path1, path2]) m = tvm.runtime.load_module(path_dso) @@ -195,8 +195,8 @@ def check_system_lib(): path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") - fadd1.save(path1) - fadd2.save(path2) + fadd1.write_to_file(path1) + fadd2.write_to_file(path2) cc.create_shared(path_dso, [path1, path2]) def popen_check(): diff --git a/tests/python/runtime/test_runtime_module_property.py b/tests/python/runtime/test_runtime_module_property.py index 83e535e1ac83..a071f9774323 100644 --- a/tests/python/runtime/test_runtime_module_property.py +++ b/tests/python/runtime/test_runtime_module_property.py @@ -21,9 +21,9 @@ def checker(mod, expected): - assert mod.is_binary_serializable == expected["is_binary_serializable"] - assert mod.is_runnable == expected["is_runnable"] - assert mod.is_dso_exportable == expected["is_dso_exportable"] + assert mod.is_binary_serializable() == expected["is_binary_serializable()"] + assert mod.is_runnable() == expected["is_runnable"] + assert mod.is_compilation_exportable() == expected["is_compilation_exportable()"] def create_csource_module(): @@ -39,12 +39,20 @@ def create_llvm_module(): def test_property(): checker( create_csource_module(), - expected={"is_binary_serializable": True, "is_runnable": False, "is_dso_exportable": True}, + expected={ + "is_binary_serializable()": True, + "is_runnable": False, + "is_compilation_exportable()": True, + }, ) checker( create_llvm_module(), - expected={"is_binary_serializable": False, "is_runnable": True, "is_dso_exportable": True}, + expected={ + "is_binary_serializable()": False, + "is_runnable": True, + "is_compilation_exportable()": True, + }, ) diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index e696cbcf086c..ac8653012ace 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -80,7 +80,7 @@ def verify_rpc(remote, target, shape, dtype): b = tvm.nd.array(np.zeros(shape).astype(A.dtype), device=dev) temp = utils.tempdir() path_dso = temp.relpath("dev_lib.o") - f.save(path_dso) + f.write_to_file(path_dso) remote.upload(path_dso) f = remote.load_module("dev_lib.o") f(a, b) diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index 1858c00e8662..d4c93bb24ae9 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -393,7 +393,7 @@ def postproc_if_missing_async_support(): # way, even though the generated code doesn't compile on platforms # that do not support async, the comparison against an expected # output can still be performed. We cannot use - # `mod.get_source()`, as that contains the source after all + # `mod.inspect_source()`, as that contains the source after all # post-processing. original_code = None diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 4cf45075edcb..b33724c722d7 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -169,9 +169,9 @@ class AsyncLocalSession : public LocalSession { // special handle time evaluator. try { ffi::Function retfunc = this->GetTimeEvaluator( - args[0].cast>(), args[1].cast(), args[2].cast(), - args[3].cast(), args[4].cast(), args[5].cast(), args[6].cast(), - args[7].cast(), args[8].cast(), args[9].cast()); + args[0].cast>(), args[1].cast(), + args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast(), + args[6].cast(), args[7].cast(), args[8].cast(), args[9].cast()); ffi::Any rv; rv = retfunc; this->EncodeReturn(std::move(rv), [&](ffi::PackedArgs encoded_args) { @@ -252,7 +252,7 @@ class AsyncLocalSession : public LocalSession { std::optional async_wait_; // time evaluator - ffi::Function GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, + ffi::Function GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown) { @@ -261,10 +261,10 @@ class AsyncLocalSession : public LocalSession { dev.device_id = device_id; if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); - return WrapWasmTimeEvaluator(m.GetFunction(name, false), dev, number, repeat, min_repeat_ms, - limit_zero_time_iterations, cooldown_interval_ms, + ffi::Module m = opt_mod.value(); + std::string tkey = m->kind(); + return WrapWasmTimeEvaluator(m->GetFunction(name, false).value(), dev, number, repeat, + min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown); } else { auto pf = tvm::ffi::Function::GetGlobal(name); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index ae8bea5524f6..6e2664a93bff 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -37,9 +37,7 @@ #include "src/runtime/cpu_device_api.cc" #include "src/runtime/device_api.cc" #include "src/runtime/file_utils.cc" -#include "src/runtime/library_module.cc" #include "src/runtime/logging.cc" -#include "src/runtime/module.cc" #include "src/runtime/ndarray.cc" #include "src/runtime/profiling.cc" #include "src/runtime/rpc/rpc_channel.cc" @@ -48,12 +46,14 @@ #include "src/runtime/rpc/rpc_local_session.cc" #include "src/runtime/rpc/rpc_module.cc" #include "src/runtime/rpc/rpc_session.cc" -#include "src/runtime/system_library.cc" #include "src/runtime/workspace_pool.cc" // relax setup #include "ffi/src/ffi/container.cc" #include "ffi/src/ffi/dtype.cc" #include "ffi/src/ffi/error.cc" +#include "ffi/src/ffi/extra/library_module.cc" +#include "ffi/src/ffi/extra/library_module_system_lib.cc" +#include "ffi/src/ffi/extra/module.cc" #include "ffi/src/ffi/function.cc" #include "ffi/src/ffi/ndarray.cc" #include "ffi/src/ffi/object.cc" diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 7e9f7c0f45ab..cd50bc067983 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -29,6 +29,7 @@ #define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY +#include #include #include #include @@ -156,7 +157,7 @@ WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return &inst; } -class WebGPUModuleNode final : public runtime::ModuleNode { +class WebGPUModuleNode final : public ffi::ModuleObj { public: explicit WebGPUModuleNode(std::unordered_map smap, std::unordered_map fmap) @@ -166,9 +167,9 @@ class WebGPUModuleNode final : public runtime::ModuleNode { create_shader_ = *fp; } - const char* type_key() const final { return "webgpu"; } + const char* kind() const final { return "webgpu"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { // special function if (name == "webgpu.get_fmap") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { @@ -206,15 +207,15 @@ class WebGPUModuleNode final : public runtime::ModuleNode { info.Save(&writer); return create_shader_(os.str(), it->second); } else { - return ffi::Function(nullptr); + return std::nullopt; } } - int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; - void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; } + ffi::Bytes SaveToBytes() const final { LOG(FATAL) << "Not implemented"; } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { // can only return source code. return source_; } @@ -232,21 +233,22 @@ class WebGPUModuleNode final : public runtime::ModuleNode { ffi::TypedFunction create_shader_; }; -Module WebGPUModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module WebGPUModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::unordered_map smap; std::unordered_map fmap; stream->Read(&fmap); stream->Read(&smap); - return Module(make_object(smap, fmap)); + return ffi::Module(make_object(smap, fmap)); } // for now webgpu is hosted via a vulkan module. TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadbinary_webgpu", WebGPUModuleLoadBinary) + .def("ffi.Module.load_from_bytes.webgpu", WebGPUModuleLoadFromBytes) .def_packed("device_api.webgpu", [](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global(); *rv = static_cast(ptr); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 75f4de855581..071b2eed68e4 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -189,8 +189,8 @@ class RuntimeContext implements Disposable { this.functionListGlobalNamesFunctor = getGlobalFunc( "ffi.FunctionListGlobalNamesFunctor" ); - this.moduleGetFunction = getGlobalFunc("runtime.ModuleGetFunction"); - this.moduleImport = getGlobalFunc("runtime.ModuleImport"); + this.moduleGetFunction = getGlobalFunc("ffi.ModuleGetFunction"); + this.moduleImport = getGlobalFunc("ffi.ModuleImportModule"); this.ndarrayEmpty = getGlobalFunc("runtime.TVMArrayAllocWithScope"); this.ndarrayCopyFromTo = getGlobalFunc("runtime.TVMArrayCopyFromTo"); this.ndarrayCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyFromBytes"); @@ -199,7 +199,7 @@ class RuntimeContext implements Disposable { this.arrayGetSize = getGlobalFunc("ffi.ArraySize"); this.arrayMake = getGlobalFunc("ffi.Array"); this.arrayConcat = getGlobalFunc("tvmjs.runtime.ArrayConcat"); - this.getSysLib = getGlobalFunc("runtime.SystemLib"); + this.getSysLib = getGlobalFunc("ffi.SystemLib"); this.arrayCacheGet = getGlobalFunc("vm.builtin.ndarray_cache.get"); this.arrayCacheRemove = getGlobalFunc("vm.builtin.ndarray_cache.remove"); this.arrayCacheUpdate = getGlobalFunc("vm.builtin.ndarray_cache.update"); @@ -1900,7 +1900,7 @@ export class Instance implements Disposable { (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMArray(handle, lib, ctx); }); - this.registerObjectConstructor("runtime.Module", + this.registerObjectConstructor("ffi.Module", (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new Module(handle, lib, ctx); });