Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion 3rdparty/cutlass_fpA_intB_gemm
6 changes: 5 additions & 1 deletion apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief RPC Server implementation.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__)
#include <signal.h>
#include <sys/select.h>
Expand Down Expand Up @@ -398,6 +399,9 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track
rpc.Start();
}

TVM_FFI_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("rpc.ServerCreate", RPCServerCreate);
});
} // namespace runtime
} // namespace tvm
78 changes: 42 additions & 36 deletions apps/ios_rpc/tvmrpc/TVMRuntime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#import <Foundation/Foundation.h>

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>

#include "RPCArgs.h"

Expand Down Expand Up @@ -51,38 +52,40 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s

} // namespace detail

TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
static const std::string base_ = NSTemporaryDirectory().UTF8String;
const auto path = args[0].cast<std::string>();
*rv = base_ + "/" + path;
});

TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
auto name = args[0].cast<std::string>();
std::string fmt = GetFileFormat(name, "");
NSString* base;
if (fmt == "dylib") {
// only load dylib from frameworks.
NSBundle* bundle = [NSBundle mainBundle];
base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"];

if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) {
// Custom dso loader is present. Will use it.
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("tvm.rpc.server.workpath",
[](ffi::PackedArgs args, ffi::Any* rv) {
static const std::string base_ = NSTemporaryDirectory().UTF8String;
const auto path = args[0].cast<std::string>();
*rv = base_ + "/" + path;
})
.def_packed("tvm.rpc.server.load_module", [](ffi::PackedArgs args, ffi::Any* rv) {
auto name = args[0].cast<std::string>();
std::string fmt = GetFileFormat(name, "");
NSString* base;
if (fmt == "dylib") {
// only load dylib from frameworks.
NSBundle* bundle = [NSBundle mainBundle];
base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"];

if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) {
// Custom dso loader is present. Will use it.
base = NSTemporaryDirectory();
fmt = "dylib_custom";
}
} else {
// Load other modules in tempdir.
base = NSTemporaryDirectory();
fmt = "dylib_custom";
}
} else {
// Load other modules in tempdir.
base = NSTemporaryDirectory();
}
NSString* path =
[base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]];
name = [path UTF8String];
*rv = Module::LoadFromFile(name, fmt);
LOG(INFO) << "Load module from " << name << " ...";
});
NSString* path =
[base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]];
name = [path UTF8String];
*rv = Module::LoadFromFile(name, fmt);
LOG(INFO) << "Load module from " << name << " ...";
});
});

#if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1

Expand All @@ -109,12 +112,15 @@ void Init(const std::string& name) {
};

// Add UnsignedDSOLoader plugin in global registry
TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
auto n = make_object<UnsignedDSOLoader>();
n->Init(args[0]);
*rv = CreateModuleFromLibrary(n);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("runtime.module.loadfile_dylib_custom",
[](ffi::PackedArgs args, ffi::Any* rv) {
auto n = make_object<UnsignedDSOLoader>();
n->Init(args[0]);
*rv = CreateModuleFromLibrary(n);
});
});

#endif

Expand Down
191 changes: 97 additions & 94 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

Expand Down Expand Up @@ -269,100 +270,102 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
return res;
}

TVM_FFI_REGISTER_GLOBAL("arith.CreateAnalyzer")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
using ffi::Function;
using ffi::TypedFunction;
auto self = std::make_shared<Analyzer>();
auto f = [self](std::string name) -> ffi::Function {
if (name == "const_int_bound") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound(args[0].cast<PrimExpr>());
});
} else if (name == "modular_set") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->modular_set(args[0].cast<PrimExpr>());
});
} else if (name == "const_int_bound_update") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
self->const_int_bound.Update(args[0].cast<Var>(), args[1].cast<ConstIntBound>(),
args[2].cast<bool>());
});
} else if (name == "const_int_bound_is_bound") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound.IsBound(args[0].cast<Var>());
});
} else if (name == "Simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (args.size() == 1) {
*ret = self->Simplify(args[0].cast<PrimExpr>());
} else if (args.size() == 2) {
*ret = self->Simplify(args[0].cast<PrimExpr>(), args[1].cast<int>());
} else {
LOG(FATAL) << "Invalid size of argument (" << args.size() << ")";
}
});
} else if (name == "rewrite_simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->rewrite_simplify(args[0].cast<PrimExpr>());
});
} else if (name == "get_rewrite_simplify_stats") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->rewrite_simplify.GetStatsCounters();
});
} else if (name == "reset_rewrite_simplify_stats") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
self->rewrite_simplify.ResetStatsCounters();
});
} else if (name == "canonical_simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->canonical_simplify(args[0].cast<PrimExpr>());
});
} else if (name == "int_set") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->int_set(args[0].cast<PrimExpr>(), args[1].cast<Map<Var, IntSet>>());
});
} else if (name == "bind") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (auto opt_range = args[1].try_cast<Range>()) {
self->Bind(args[0].cast<Var>(), opt_range.value());
} else {
self->Bind(args[0].cast<Var>(), args[1].cast<PrimExpr>());
}
});
} else if (name == "can_prove") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
int strength = args[1].cast<int>();
*ret = self->CanProve(args[0].cast<PrimExpr>(), static_cast<ProofStrength>(strength));
});
} else if (name == "enter_constraint_context") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx = std::shared_ptr<With<ConstraintContext>>(
new With<ConstraintContext>(self.get(), args[0].cast<PrimExpr>()));
auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); };
*ret = ffi::Function::FromPacked(fexit);
});
} else if (name == "can_prove_equal") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->CanProveEqual(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>());
});
} else if (name == "get_enabled_extensions") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions());
});
} else if (name == "set_enabled_extensions") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
int64_t flags = args[0].cast<int64_t>();
self->rewrite_simplify.SetEnabledExtensions(
static_cast<RewriteSimplifier::Extension>(flags));
});
}
return ffi::Function();
};
*ret = ffi::TypedFunction<ffi::Function(std::string)>(f);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs args, ffi::Any* ret) {
using ffi::Function;
using ffi::TypedFunction;
auto self = std::make_shared<Analyzer>();
auto f = [self](std::string name) -> ffi::Function {
if (name == "const_int_bound") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound(args[0].cast<PrimExpr>());
});
} else if (name == "modular_set") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->modular_set(args[0].cast<PrimExpr>());
});
} else if (name == "const_int_bound_update") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
self->const_int_bound.Update(args[0].cast<Var>(), args[1].cast<ConstIntBound>(),
args[2].cast<bool>());
});
} else if (name == "const_int_bound_is_bound") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound.IsBound(args[0].cast<Var>());
});
} else if (name == "Simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (args.size() == 1) {
*ret = self->Simplify(args[0].cast<PrimExpr>());
} else if (args.size() == 2) {
*ret = self->Simplify(args[0].cast<PrimExpr>(), args[1].cast<int>());
} else {
LOG(FATAL) << "Invalid size of argument (" << args.size() << ")";
}
});
} else if (name == "rewrite_simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->rewrite_simplify(args[0].cast<PrimExpr>());
});
} else if (name == "get_rewrite_simplify_stats") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->rewrite_simplify.GetStatsCounters();
});
} else if (name == "reset_rewrite_simplify_stats") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
self->rewrite_simplify.ResetStatsCounters();
});
} else if (name == "canonical_simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->canonical_simplify(args[0].cast<PrimExpr>());
});
} else if (name == "int_set") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->int_set(args[0].cast<PrimExpr>(), args[1].cast<Map<Var, IntSet>>());
});
} else if (name == "bind") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (auto opt_range = args[1].try_cast<Range>()) {
self->Bind(args[0].cast<Var>(), opt_range.value());
} else {
self->Bind(args[0].cast<Var>(), args[1].cast<PrimExpr>());
}
});
} else if (name == "can_prove") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
int strength = args[1].cast<int>();
*ret = self->CanProve(args[0].cast<PrimExpr>(), static_cast<ProofStrength>(strength));
});
} else if (name == "enter_constraint_context") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx = std::shared_ptr<With<ConstraintContext>>(
new With<ConstraintContext>(self.get(), args[0].cast<PrimExpr>()));
auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); };
*ret = ffi::Function::FromPacked(fexit);
});
} else if (name == "can_prove_equal") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->CanProveEqual(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>());
});
} else if (name == "get_enabled_extensions") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions());
});
} else if (name == "set_enabled_extensions") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
int64_t flags = args[0].cast<int64_t>();
self->rewrite_simplify.SetEnabledExtensions(
static_cast<RewriteSimplifier::Extension>(flags));
});
}
return ffi::Function();
};
*ret = ffi::TypedFunction<ffi::Function(std::string)>(f);
});
});

} // namespace arith
} // namespace tvm
13 changes: 8 additions & 5 deletions src/arith/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>

Expand Down Expand Up @@ -402,11 +403,13 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map<Var, IntSet>& hint_map,
return DeduceBound(v, e, hmap, rmap);
}

TVM_FFI_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed([](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map) {
return DeduceBound(v, cond, hint_map, relax_map);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"arith.DeduceBound",
[](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map) { return DeduceBound(v, cond, hint_map, relax_map); });
});

} // namespace arith
} // namespace tvm
6 changes: 5 additions & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr_functor.h>

Expand Down Expand Up @@ -53,7 +54,10 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}

TVM_FFI_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("arith.ConstIntBound", MakeConstIntBound);
});

inline void PrintBoundValue(std::ostream& os, int64_t val) {
if (val == ConstIntBound::kPosInf) {
Expand Down
6 changes: 5 additions & 1 deletion src/arith/detect_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file detect_common_subexpr.cc
* \brief Utility to detect common sub expressions.
*/
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/tir/expr.h>

#include <limits>
Expand Down Expand Up @@ -69,6 +70,9 @@ Map<PrimExpr, Integer> DetectCommonSubExpr(const PrimExpr& e, int thresh) {
return results;
}

TVM_FFI_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("arith.DetectCommonSubExpr", DetectCommonSubExpr);
});
} // namespace arith
} // namespace tvm
Loading
Loading