Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,14 +290,14 @@ namespace symbol {
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
/*! \brief Global variable to store binary data alongside a library module. */
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_ffi_main = "__tvm_ffi_main__";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol

// implementations of inline functions.
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private static Function getApi(String name) {
}

private Function entry = null;
private final String entryName = "__tvm_main__";
private final String entryName = "__tvm_ffi_main__";


/**
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class Module(tvm.ffi.Object):

def __new__(cls):
instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter
instance.entry_name = "__tvm_main__"
instance.entry_name = "__tvm_ffi_main__"
instance._entry = None
return instance

Expand All @@ -118,7 +118,7 @@ def entry_func(self):
"""
if self._entry:
return self._entry
self._entry = self.get_function("__tvm_main__")
self._entry = self.get_function("__tvm_ffi_main__")
return self._entry

def implements_function(self, name, query_imports=False):
Expand Down
1 change: 0 additions & 1 deletion src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ class CUDAPrepGlobalBarrier {
ffi::Function CUDAModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>& sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
if (name == symbol::tvm_prepare_global_barrier) {
return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
}
Expand Down
10 changes: 1 addition & 9 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,7 @@ class LibraryModuleNode final : public ModuleNode {

ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
TVMFFISafeCallType faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name =
reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
ICHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
}
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
if (faddr == nullptr) return ffi::Function();
return packed_func_wrapper_(faddr, sptr_to_self);
}
Expand Down
1 change: 0 additions & 1 deletion src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
ffi::Function ret;
AUTORELEASEPOOL {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) {
ret = ffi::Function();
Expand Down
1 change: 0 additions & 1 deletion src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() {
ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name,
const ObjectPtr<Object>& sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return ffi::Function();
const FunctionInfo& info = it->second;
Expand Down
1 change: 0 additions & 1 deletion src/runtime/rocm/rocm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ class ROCMWrappedFunc {
ffi::Function ROCMModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>& sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return ffi::Function();
const FunctionInfo& info = it->second;
Expand Down
1 change: 0 additions & 1 deletion src/runtime/vulkan/vulkan_wrapped_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ VulkanModuleNode::~VulkanModuleNode() {
ffi::Function VulkanModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>& sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return ffi::Function();
const FunctionInfo& info = it->second;
Expand Down
50 changes: 32 additions & 18 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,28 +229,42 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
}

void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
llvm::Function* f = module_->getFunction(entry_func_name);
ICHECK(f) << "Function " << entry_func_name << "does not in module";
llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
llvm::GlobalVariable* global =
new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr,
runtime::symbol::tvm_module_main);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(1));
#else
global->setAlignment(1);
#endif
// comdat is needed for windows select any linking to work
// set comdat to Any(weak linking)
// create a wrapper function with tvm_ffi_main name and redirects to the entry function
llvm::Function* target_func = module_->getFunction(entry_func_name);
ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module";

// Create wrapper function
llvm::Function* wrapper_func =
llvm::Function::Create(target_func->getFunctionType(), llvm::Function::WeakAnyLinkage,
runtime::symbol::tvm_ffi_main, module_.get());

// Set attributes (Windows comdat, DLL export, etc.)
if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main);
comdat->setSelectionKind(llvm::Comdat::Any);
global->setComdat(comdat);
wrapper_func->setComdat(comdat);
}

global->setInitializer(
llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name));
global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
wrapper_func->setCallingConv(llvm::CallingConv::C);
wrapper_func->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

// Create simple tail call
llvm::BasicBlock* entry =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", wrapper_func);
builder_->SetInsertPoint(entry);

// Forward all arguments to target function
std::vector<llvm::Value*> call_args;
for (llvm::Value& arg : wrapper_func->args()) {
call_args.push_back(&arg);
}

llvm::Value* result = builder_->CreateCall(target_func, call_args);
if (target_func->getReturnType()->isVoidTy()) {
builder_->CreateRetVoid();
} else {
builder_->CreateRet(result);
}
}

std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
Expand Down
10 changes: 1 addition & 9 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name,

TVMFFISafeCallType faddr;
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target));
ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main
<< " is not presented";
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name, *llvm_target));
} else {
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target));
}
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target));
if (faddr == nullptr) return ffi::Function();
return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self);
}
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
<< "CodeGenCHost: The entry func must have the global_symbol attribute, "
<< "but function " << gvar << " only has attributes " << func->attrs;

function_names_.push_back(runtime::symbol::tvm_module_main);
function_names_.push_back(runtime::symbol::tvm_ffi_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(func->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
stream << " " << tvm::runtime::symbol::tvm_ffi_main
<< "(void* self, void* args,int num_args, void* result) {\n";
stream << " return " << global_symbol.value() << "(self, args, num_args, result);\n";
stream << "}\n";
Expand Down
8 changes: 6 additions & 2 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,13 @@ def evaluate(

if tvm.testing.utils.IS_IN_CI:
# Run with reduced number and repeat for CI
timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1)
timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
)
else:
timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)
timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
)

time = timer(a_hexagon, b_hexagon, c_hexagon)
if expected_output is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_hexagon/test_parallel_hvx.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch):
repeat = 1

timer = module.time_evaluator(
"__tvm_main__", hexagon_session.device, number=number, repeat=repeat
"__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat
)
runtime = timer(a_hexagon, b_hexagon, c_hexagon)
tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global")
repeat = 1

timer = module.time_evaluator(
"__tvm_main__", hexagon_session.device, number=number, repeat=repeat
"__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat
)
time = timer(a_hexagon, b_hexagon, c_hexagon)
gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
Expand Down Expand Up @@ -365,7 +365,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations):
repeat = 1

timer = module.time_evaluator(
"__tvm_main__", hexagon_session.device, number=number, repeat=repeat
"__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat
)
time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon)
gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_hexagon/test_parallel_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def evaluate(hexagon_session, operations, expected, sch):
repeat = 1

timer = module.time_evaluator(
"__tvm_main__", hexagon_session.device, number=number, repeat=repeat
"__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat
)
runtime = timer(a_hexagon, b_hexagon, c_hexagon)

Expand Down
8 changes: 6 additions & 2 deletions tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,13 @@ def evaluate(hexagon_session, sch, size):

if tvm.testing.utils.IS_IN_CI:
# Run with reduced number and repeat for CI
timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1)
timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
)
else:
timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)
timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
)

runtime = timer(a_hexagon, a_vtcm_hexagon)

Expand Down
Loading