Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ using PassContext = tvm::transform::PassContext;
using PassContextNode = tvm::transform::PassContextNode;
using Sequential = tvm::transform::Sequential;

/*!
* \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
*
* Called before the default lowering passes.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
using FTVMRelayToTIR = tvm::transform::Pass;

/*
* \brief Create a function pass.
*
Expand Down
10 changes: 0 additions & 10 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#define TVM_TARGET_TARGET_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/support/with.h>
#include <tvm/target/target_kind.h>
Expand Down Expand Up @@ -284,14 +283,5 @@ class Target : public ObjectRef {
*/
void CheckAndUpdateHostConsistency(Target* target, Target* host);

/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param ir_modules The pointer to a Map objects with keys being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* ir_modules, Target* host);

} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
27 changes: 1 addition & 26 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#ifndef TVM_TARGET_TARGET_KIND_H_
#define TVM_TARGET_TARGET_KIND_H_

#include <tvm/ir/transform.h>
#include <tvm/node/attr_registry_map.h>
#include <tvm/node/node.h>

Expand All @@ -50,31 +49,7 @@ using TargetFeatures = Map<String, ObjectRef>;
* \return The transformed Target JSON object.
*/
using TargetJSON = Map<String, ObjectRef>;
using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;

/*!
* \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
*
* Called before the default lowering passes.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
using FTVMRelayToTIR = transform::Pass;

/*!
* \brief TIRToRuntime conversion specific to a TargetKind
*
* This function is responsible for scanning an IRModule for appropriate Target-specific functions
and generating a Runtime module representing the compiled output
*
* \param ir_module Unified IRModule
* \param target Target to filter on or retrieve arguments from
* \return Runtime Module containing compiled functions
*/
using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;
using FTVMTargetParser = runtime::TypedPackedFunc<TargetJSON(TargetJSON)>;

namespace detail {
template <typename, typename, typename>
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/usmp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/ir/expr.h>
#include <tvm/ir/memory_pools.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/target/target.h>
#include <tvm/tir/stmt.h>
Expand Down
17 changes: 17 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,23 @@ std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target&
return {host_mod, device_mod};
}

/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param ir_modules The pointer to a Map objects with keys being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) {
Map<Target, IRModule> new_targets;
for (auto& it : *targets) {
auto target = it.first;
CheckAndUpdateHostConsistency(&target, host);
new_targets.Set(target, it.second);
}
*targets = new_targets;
}

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) {
std::vector<runtime::Module> device_modules;
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/contrib/cmsisnn/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ namespace cmsisnn {

tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);
using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;

TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.add_attr_option<Bool>("debug_last_error")
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/codegen_c/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace contrib {
*/
TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, CCompilerPass())
.set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, CCompilerPass())
// Value is prepended to every output CModule.
.add_attr_option<String>("header", String(""));

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/cutlass/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace cutlass {
*/
TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
.set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForCutlass())
.set_attr<tvm::transform::Pass>("RelayToTIR", CompileForCutlass())
// An integer specifying the compute capability. For example, 75 for Turing and
// 80 or 86 for Ampere.
.add_attr_option<Integer>("sm", Integer(80))
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ namespace relay {
namespace contrib {
namespace ethosu {

using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;

/*!
* \brief This mutator outlines functions that are marked with a named
* "Compiler" attribute. Functions that do not match this condition remain
Expand Down Expand Up @@ -320,7 +322,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {

TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
.set_attr<Bool>("use_device_api", Bool(true))
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

} // namespace ethosu
Expand Down
5 changes: 4 additions & 1 deletion src/relay/backend/contrib/example_target_hooks/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

namespace tvm {

using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;

namespace relay {
namespace contrib {
namespace example_target_hooks {
Expand All @@ -33,7 +35,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);

TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
.set_attr<Bool>("use_device_api", Bool(true))
.set_attr<FTVMRelayToTIR>(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR())
.set_attr<relay::transform::FTVMRelayToTIR>(attr::kRelayToTIR,
relay::contrib::example_target_hooks::RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime)
.add_attr_option<Integer>("example_attribute", Integer(0));

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/tensorrt/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace tensorrt {
*/
TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
.set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForTensorRT())
.set_attr<tvm::transform::Pass>("RelayToTIR", CompileForTensorRT())
// A array of three integers given the major, minor, and patch numbers for the supported
// TensorRT compiler version. If empty will be auto-detected from linked library. Default empty.
.add_attr_option<Array<Integer>>("tensorrt_version", Array<Integer>())
Expand Down
6 changes: 4 additions & 2 deletions src/relay/backend/contrib/uma/targets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

namespace tvm {

using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;

namespace relay {
namespace contrib {
namespace uma {
Expand Down Expand Up @@ -57,8 +59,8 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
.add_attr_option<Array<String>>("libs")
.add_attr_option<Target>("host")
.add_attr_option<Integer>("from_device")
.set_attr<FTVMRelayToTIR>(attr::kRelayToTIR,
relay::contrib::uma::RelayToTIR(target_name))
.set_attr<relay::transform::FTVMRelayToTIR>(
attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name))
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::uma::TIRToRuntime);

// target kind attrs inventory
Expand Down
13 changes: 13 additions & 0 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,21 @@
#include <vector>

namespace tvm {

namespace codegen {

/*!
* \brief TIRToRuntime conversion specific to a TargetKind
*
* This function is responsible for scanning an IRModule for appropriate Target-specific functions
and generating a Runtime module representing the compiled output
*
* \param ir_module Unified IRModule
* \param target Target to filter on or retrieve arguments from
* \return Runtime Module containing compiled functions
*/
using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;

runtime::Module Build(IRModule mod, Target target) {
if (transform::PassContext::Current()
->GetConfig<Bool>("tir.disable_assert", Bool(false))
Expand Down
15 changes: 3 additions & 12 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file src/target/target.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -91,16 +92,6 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) {
*host = (*target)->GetHost().value_or(Target());
}

void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) {
Map<Target, IRModule> new_targets;
for (auto& it : *targets) {
auto target = it.first;
CheckAndUpdateHostConsistency(&target, host);
new_targets.Set(target, it.second);
}
*targets = new_targets;
}

static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) {
std::vector<String> new_keys;
for (size_t i = 0; i < keys.size(); ++i) {
Expand Down Expand Up @@ -614,8 +605,8 @@ Target::Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<Stri
bool Target::IsExternalCodegen() const {
TargetKindAttrMap<Bool> is_external_codegen_map =
TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen);
TargetKindAttrMap<FTVMRelayToTIR> relay_to_tir_map =
TargetKind::GetAttrMap<FTVMRelayToTIR>(tvm::attr::kRelayToTIR);
TargetKindAttrMap<tvm::transform::Pass> relay_to_tir_map =
TargetKind::GetAttrMap<tvm::transform::Pass>(tvm::attr::kRelayToTIR);
return is_external_codegen_map.get(get()->kind, Bool(false)) ||
relay_to_tir_map.count(get()->kind);
}
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/target_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));

TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU)
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType());
.set_attr<tvm::relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR,
tvm::relay::transform::InferType());

TEST(Target, ExternalCodegen) {
Target regular("cuda");
Expand Down