diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index b71d94a4ccd6..0697a5113661 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit b71d94a4ccd6573c9cbd4056c9ce660f110d33f0
+Subproject commit 0697a511366194fc305649da0746308439fd7a75
diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc
index 2f74dd309f42..07fa000ffc63 100644
--- a/apps/cpp_rpc/rpc_server.cc
+++ b/apps/cpp_rpc/rpc_server.cc
@@ -22,6 +22,7 @@
  * \brief RPC Server implementation.
  */
 #include
+#include
 #if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__)
 #include
 #include
@@ -398,6 +399,9 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track
   rpc.Start();
 }

-TVM_FFI_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate);
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("rpc.ServerCreate", RPCServerCreate);
+});
 }  // namespace runtime
 }  // namespace tvm
diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm
index 8d0ae7368d8a..213a12539d93 100644
--- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm
+++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm
@@ -24,6 +24,7 @@
 #import
 #include
+#include

 #include "RPCArgs.h"

@@ -51,38 +52,40 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s
 }  // namespace detail

-TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.workpath")
-    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
-      static const std::string base_ = NSTemporaryDirectory().UTF8String;
-      const auto path = args[0].cast();
-      *rv = base_ + "/" + path;
-    });
-
-TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.load_module")
-    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
-      auto name = args[0].cast();
-      std::string fmt = GetFileFormat(name, "");
-      NSString* base;
-      if (fmt == "dylib") {
-        // only load dylib from frameworks.
-        NSBundle* bundle = [NSBundle mainBundle];
-        base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"];
-
-        if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) {
-          // Custom dso loader is present. Will use it.
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef()
+      .def_packed("tvm.rpc.server.workpath",
+                  [](ffi::PackedArgs args, ffi::Any* rv) {
+                    static const std::string base_ = NSTemporaryDirectory().UTF8String;
+                    const auto path = args[0].cast();
+                    *rv = base_ + "/" + path;
+                  })
+      .def_packed("tvm.rpc.server.load_module", [](ffi::PackedArgs args, ffi::Any* rv) {
+        auto name = args[0].cast();
+        std::string fmt = GetFileFormat(name, "");
+        NSString* base;
+        if (fmt == "dylib") {
+          // only load dylib from frameworks.
+          NSBundle* bundle = [NSBundle mainBundle];
+          base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"];
+
+          if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) {
+            // Custom dso loader is present. Will use it.
+            base = NSTemporaryDirectory();
+            fmt = "dylib_custom";
+          }
+        } else {
+          // Load other modules in tempdir.
           base = NSTemporaryDirectory();
-        }
-      } else {
-        // Load other modules in tempdir.
- base = NSTemporaryDirectory(); - } - NSString* path = - [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; - name = [path UTF8String]; - *rv = Module::LoadFromFile(name, fmt); - LOG(INFO) << "Load module from " << name << " ..."; - }); + NSString* path = + [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; + name = [path UTF8String]; + *rv = Module::LoadFromFile(name, fmt); + LOG(INFO) << "Load module from " << name << " ..."; + }); +}); #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 @@ -109,12 +112,15 @@ void Init(const std::string& name) { }; // Add UnsignedDSOLoader plugin in global registry -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("runtime.module.loadfile_dylib_custom", + [](ffi::PackedArgs args, ffi::Any* rv) { + auto n = make_object(); + n->Init(args[0]); + *rv = CreateModuleFromLibrary(n); + }); +}); #endif diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 89cdf1c27876..62b54ac81e35 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include @@ -269,100 +270,102 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -TVM_FFI_REGISTER_GLOBAL("arith.CreateAnalyzer") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - using ffi::Function; - using ffi::TypedFunction; - auto self = std::make_shared(); - auto f = [self](std::string name) -> ffi::Function { - if (name == "const_int_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound(args[0].cast()); - }); - } else if (name == "modular_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->modular_set(args[0].cast()); - }); - } else if (name == "const_int_bound_update") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->const_int_bound.Update(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - } else if (name == "const_int_bound_is_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound.IsBound(args[0].cast()); - }); - } else if (name == "Simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = self->Simplify(args[0].cast()); - } else if (args.size() == 2) { - *ret = self->Simplify(args[0].cast(), args[1].cast()); - } else { - LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; - } - }); - } else if (name == "rewrite_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify(args[0].cast()); - }); - } else if (name == "get_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify.GetStatsCounters(); - }); - } else if (name == "reset_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->rewrite_simplify.ResetStatsCounters(); - }); - } else if (name == "canonical_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->canonical_simplify(args[0].cast()); - }); - } else if (name == "int_set") { - return 
ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->int_set(args[0].cast(), args[1].cast>()); - }); - } else if (name == "bind") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (auto opt_range = args[1].try_cast()) { - self->Bind(args[0].cast(), opt_range.value()); - } else { - self->Bind(args[0].cast(), args[1].cast()); - } - }); - } else if (name == "can_prove") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int strength = args[1].cast(); - *ret = self->CanProve(args[0].cast(), static_cast(strength)); - }); - } else if (name == "enter_constraint_context") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr>( - new With(self.get(), args[0].cast())); - auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; - *ret = ffi::Function::FromPacked(fexit); - }); - } else if (name == "can_prove_equal") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); - }); - } else if (name == "get_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); - }); - } else if (name == "set_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int64_t flags = args[0].cast(); - self->rewrite_simplify.SetEnabledExtensions( - static_cast(flags)); - }); - } - return ffi::Function(); - }; - *ret = ffi::TypedFunction(f); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs args, ffi::Any* ret) { + using ffi::Function; + using ffi::TypedFunction; + auto self = std::make_shared(); + auto f = [self](std::string name) -> ffi::Function { + if (name == "const_int_bound") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->const_int_bound(args[0].cast()); + }); + } else if (name == "modular_set") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->modular_set(args[0].cast()); + }); + } else if (name == "const_int_bound_update") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + self->const_int_bound.Update(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + } else if (name == "const_int_bound_is_bound") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->const_int_bound.IsBound(args[0].cast()); + }); + } else if (name == "Simplify") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = self->Simplify(args[0].cast()); + } else if (args.size() == 2) { + *ret = self->Simplify(args[0].cast(), args[1].cast()); + } else { + LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + } + }); + } else if (name == "rewrite_simplify") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->rewrite_simplify(args[0].cast()); + }); + } else if (name == "get_rewrite_simplify_stats") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + 
self->rewrite_simplify.ResetStatsCounters(); + }); + } else if (name == "canonical_simplify") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->canonical_simplify(args[0].cast()); + }); + } else if (name == "int_set") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->int_set(args[0].cast(), args[1].cast>()); + }); + } else if (name == "bind") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + if (auto opt_range = args[1].try_cast()) { + self->Bind(args[0].cast(), opt_range.value()); + } else { + self->Bind(args[0].cast(), args[1].cast()); + } + }); + } else if (name == "can_prove") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + int strength = args[1].cast(); + *ret = self->CanProve(args[0].cast(), static_cast(strength)); + }); + } else if (name == "enter_constraint_context") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + // can't use make_shared due to noexcept(false) decl in destructor, + // see https://stackoverflow.com/a/43907314 + auto ctx = std::shared_ptr>( + new With(self.get(), args[0].cast())); + auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; + *ret = ffi::Function::FromPacked(fexit); + }); + } else if (name == "can_prove_equal") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); + }); + } else if (name == "get_enabled_extensions") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + int64_t flags = args[0].cast(); + self->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); + } + return ffi::Function(); + }; + *ret = ffi::TypedFunction(f); + }); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index b8b5d6482428..12d8c8710ee3 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -402,11 +403,13 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, return DeduceBound(v, e, hmap, rmap); } -TVM_FFI_REGISTER_GLOBAL("arith.DeduceBound") - .set_body_typed([](PrimExpr v, PrimExpr cond, const Map hint_map, - const Map relax_map) { - return DeduceBound(v, cond, hint_map, relax_map); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "arith.DeduceBound", + [](PrimExpr v, PrimExpr cond, const Map hint_map, + const Map relax_map) { return DeduceBound(v, cond, hint_map, relax_map); }); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index b57c04752ff2..e538a11025b4 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include @@ -53,7 +54,10 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_FFI_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.ConstIntBound", MakeConstIntBound); +}); inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == 
ConstIntBound::kPosInf) { diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index 303360002e03..e4bc5eab771e 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -21,6 +21,7 @@ * \file detect_common_subexpr.cc * \brief Utility to detect common sub expressions. */ +#include #include #include @@ -69,6 +70,9 @@ Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { return results; } -TVM_FFI_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.DetectCommonSubExpr", DetectCommonSubExpr); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 0dcbc7623590..a7bc273ad957 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -290,11 +291,12 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_FFI_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); - -TVM_FFI_REGISTER_GLOBAL("arith.DetectClipBound") - .set_body_typed([](const PrimExpr& e, const Array& vars) { - return DetectClipBound(e, vars); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("arith.DetectLinearEquation", DetectLinearEquation) + .def("arith.DetectClipBound", + [](const PrimExpr& e, const Array& vars) { return DetectClipBound(e, vars); }); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 5f9d78003001..7c102094c43e 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -22,6 +22,7 @@ * \brief Utility to deduce bound of expression */ #include +#include #include #include #include @@ -162,8 +163,12 @@ Map> DomainTouchedAccessMap(const PrimFunc& func) { return ret; } -TVM_FFI_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); -TVM_FFI_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("arith.DomainTouched", DomainTouched) + .def("arith.DomainTouchedAccessMap", DomainTouchedAccessMap); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index afe7f09676b6..1fadbad07464 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -201,25 +202,24 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); -TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds") - .set_body_typed([](PrimExpr coef, Array lower, Array equal, - Array upper) { - return IntGroupBounds(coef, lower, equal, upper); - }); - -TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds_from_range") - .set_body_typed(IntGroupBounds::FromRange); - -TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK(args.size() == 1 || args.size() == 2); - auto bounds = args[0].cast(); - if (args.size() == 1) { - *ret = bounds.FindBestRange(); - } else if (args.size() == 2) { - *ret = bounds.FindBestRange(args[1].cast>()); - } - }); 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("arith.IntGroupBounds", + [](PrimExpr coef, Array lower, Array equal, Array upper) { + return IntGroupBounds(coef, lower, equal, upper); + }) + .def("arith.IntGroupBounds_from_range", IntGroupBounds::FromRange) + .def_packed("arith.IntGroupBounds_FindBestRange", [](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK(args.size() == 1 || args.size() == 2); + auto bounds = args[0].cast(); + if (args.size() == 1) { + *ret = bounds.FindBestRange(); + } else if (args.size() == 2) { + *ret = bounds.FindBestRange(args[1].cast>()); + } + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -250,10 +250,13 @@ IntConstraints::IntConstraints(Array variables, Map ranges, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); -TVM_FFI_REGISTER_GLOBAL("arith.IntConstraints") - .set_body_typed([](Array variables, Map ranges, Array relations) { - return IntConstraints(variables, ranges, relations); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.IntConstraints", [](Array variables, Map ranges, + Array relations) { + return IntConstraints(variables, ranges, relations); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -295,11 +298,14 @@ IntConstraintsTransform IntConstraintsTransform::operator+( TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); -TVM_FFI_REGISTER_GLOBAL("arith.IntConstraintsTransform") - .set_body_typed([](IntConstraints src, IntConstraints dst, Map src_to_dst, - Map dst_to_src) { - return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.IntConstraintsTransform", + [](IntConstraints src, IntConstraints dst, Map src_to_dst, + Map dst_to_src) { + return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 0a347040b76b..565819c00fd1 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -59,7 +60,10 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_FFI_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.IntervalSet", MakeIntervalSet); +}); IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); @@ -1194,42 +1198,38 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "[" << op->min_value << ", " << op->max_value << ']'; }); -TVM_FFI_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint); - -TVM_FFI_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector); - -TVM_FFI_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval); - -TVM_FFI_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); - -TVM_FFI_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); - -TVM_FFI_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing); - 
-TVM_FFI_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything); - -TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") - .set_body_typed([](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { - Analyzer analyzer; - return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); - }); -TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionStrictBound") - .set_body_typed([](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { - Analyzer analyzer; - return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); - }); -TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionUpperBound") - .set_body_typed([](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { - Analyzer analyzer; - return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); - }); - -TVM_FFI_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); -TVM_FFI_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); -TVM_FFI_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("arith.intset_single_point", IntSet::SinglePoint) + .def("arith.intset_vector", IntSet::Vector) + .def("arith.intset_interval", IntSet::Interval) + .def_method("arith.IntervalSetGetMin", &IntSet::min) + .def_method("arith.IntervalSetGetMax", &IntSet::max) + .def_method("arith.IntSetIsNothing", &IntSet::IsNothing) + .def_method("arith.IntSetIsEverything", &IntSet::IsEverything) + .def("arith.EstimateRegionLowerBound", + [](Array region, Map var_dom, + PrimExpr predicate) -> Optional> { + Analyzer analyzer; + return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); + }) + .def("arith.EstimateRegionStrictBound", + [](Array region, Map var_dom, + PrimExpr predicate) -> Optional> { + Analyzer analyzer; + return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); + }) + .def("arith.EstimateRegionUpperBound", + [](Array region, Map var_dom, + PrimExpr predicate) -> Optional> { + Analyzer analyzer; + return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); + }) + .def("arith.PosInf", []() { return SymbolicLimits::pos_inf_; }) + .def("arith.NegInf", []() { return SymbolicLimits::neg_inf_; }) + .def("arith.UnionLowerBound", UnionLowerBound); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 52ed71edeac3..4aa48843dc88 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -54,8 +55,10 @@ IterMark::IterMark(PrimExpr source, PrimExpr extent) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { - return IterMark(source, extent); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.IterMark", + [](PrimExpr source, PrimExpr extent) { return IterMark(source, extent); }); }); TVM_REGISTER_NODE_TYPE(IterMarkNode); @@ -99,10 +102,13 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr ex data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("arith.IterSplitExpr") - .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { - return IterSplitExpr(source, lower_factor, extent, scale); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace 
refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.IterSplitExpr", [](IterMark source, PrimExpr lower_factor, + PrimExpr extent, PrimExpr scale) { + return IterSplitExpr(source, lower_factor, extent, scale); + }); +}); TVM_REGISTER_NODE_TYPE(IterSplitExprNode); @@ -121,10 +127,12 @@ IterSumExpr::IterSumExpr(Array args, PrimExpr base) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("arith.IterSumExpr") - .set_body_typed([](Array args, PrimExpr base) { - return IterSumExpr(args, base); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.IterSumExpr", [](Array args, PrimExpr base) { + return IterSumExpr(args, base); + }); +}); TVM_REGISTER_NODE_TYPE(IterSumExprNode); @@ -1520,14 +1528,17 @@ IterMapResult DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, - const PrimExpr& input_pred, int check_level, - bool simplify_trivial_iterators) { - arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, - simplify_trivial_iterators); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "arith.DetectIterMap", + [](const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { + arith::Analyzer ana; + return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, + simplify_trivial_iterators); + }); +}); IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iters, arith::Analyzer* analyzer) { @@ -1545,11 +1556,14 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iter return rewriter.RewriteToNormalizedIterSum(index); } -TVM_FFI_REGISTER_GLOBAL("arith.NormalizeToIterSum") - .set_body_typed([](PrimExpr index, const Map& input_iters) { - arith::Analyzer ana; - return NormalizeToIterSum(index, input_iters, &ana); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.NormalizeToIterSum", + [](PrimExpr index, const Map& input_iters) { + arith::Analyzer ana; + return NormalizeToIterSum(index, input_iters, &ana); + }); +}); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { auto var = GetRef(op); @@ -2144,7 +2158,10 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { return normalizer.Convert(expr); } -TVM_FFI_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.NormalizeIterMapToExpr", NormalizeIterMapToExpr); +}); Array IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, @@ -2173,14 +2190,17 @@ Array IterMapSimplify(const Array& indices, const Map& indices, const Map& input_iters, - const PrimExpr& input_pred, int check_level, - bool simplify_trivial_iterators) { - arith::Analyzer ana; - return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, - simplify_trivial_iterators); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "arith.IterMapSimplify", + [](const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { + arith::Analyzer ana; + return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, + simplify_trivial_iterators); + }); +}); /*! 
* \brief Divider to divide the bindings into two sets of bindings(outer and inner) @@ -2506,14 +2526,17 @@ Array> SubspaceDivide(const Array& bindings, return results; } -TVM_FFI_REGISTER_GLOBAL("arith.SubspaceDivide") - .set_body_typed([](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, int check_level, - bool simplify_trivial_iterators) { - arith::Analyzer ana; - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), - &ana, simplify_trivial_iterators); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "arith.SubspaceDivide", [](const Array& bindings, const Map& root_iters, + const Array& sub_iters, const PrimExpr& predicate, + int check_level, bool simplify_trivial_iterators) { + arith::Analyzer ana; + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), + &ana, simplify_trivial_iterators); + }); +}); class InverseAffineIterMapTransformer { public: @@ -2645,7 +2668,10 @@ Map InverseAffineIterMap(const Array& iter_map, return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); } -TVM_FFI_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.InverseAffineIterMap", InverseAffineIterMap); +}); TVM_REGISTER_NODE_TYPE(IterMapResultNode); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index e4170f6c3c68..9f755c30e819 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -59,7 +60,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_FFI_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.ModularSet", MakeModularSet); +}); // internal entry for const int bound struct ModularSetAnalyzer::Entry { diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index e998ba65f354..1d67308da2e5 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -213,8 +214,10 @@ PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameter return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); } -TVM_FFI_REGISTER_GLOBAL("arith.NarrowPredicateExpression") - .set_body_typed(NarrowPredicateExpression); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("arith.NarrowPredicateExpression", NarrowPredicateExpression); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 1af1bb2e39bf..5ff97f391a94 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -274,7 +275,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); } -TVM_FFI_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + 
refl::GlobalDef().def("arith.PresburgerSet", MakePresburgerSet); +}); TVM_REGISTER_NODE_TYPE(PresburgerSetNode); diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 4d90c61ea3cb..09c06395000a 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -454,21 +455,24 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol return transform; } -TVM_FFI_REGISTER_GLOBAL("arith.SolveLinearEquations") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = SolveLinearEquations(args[0].cast()); - } else if (args.size() == 3) { - auto opt_vars = args[0].cast>>(); - auto opt_map = args[1].cast>>(); - auto opt_relations = args[2].cast>>(); - IntConstraints problem(opt_vars.value_or({}), opt_map.value_or({}), - opt_relations.value_or({})); - *ret = SolveLinearEquations(problem); - } else { - LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "arith.SolveLinearEquations", [](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = SolveLinearEquations(args[0].cast()); + } else if (args.size() == 3) { + auto opt_vars = args[0].cast>>(); + auto opt_map = args[1].cast>>(); + auto opt_relations = args[2].cast>>(); + IntConstraints problem(opt_vars.value_or({}), opt_map.value_or({}), + opt_relations.value_or({})); + *ret = SolveLinearEquations(problem); + } else { + LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); + } + }); +}); } // namespace arith } // namespace tvm diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 62f314d1902f..352f34c1ec6c 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -535,53 +536,55 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ return transform; } -TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - IntConstraints problem; - PartialSolvedInequalities ret_ineq; - if (args.size() == 1) { - problem = args[0].cast(); - ret_ineq = SolveLinearInequalities(problem); - } else if (args.size() == 3) { - problem = IntConstraints(args[0].cast>(), args[1].cast>(), +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed( + "arith.SolveInequalitiesAsCondition", + [](ffi::PackedArgs args, ffi::Any* ret) { + IntConstraints problem; + PartialSolvedInequalities ret_ineq; + if (args.size() == 1) { + problem = args[0].cast(); + ret_ineq = SolveLinearInequalities(problem); + } else if (args.size() == 3) { + problem = IntConstraints(args[0].cast>(), args[1].cast>(), + args[2].cast>()); + ret_ineq = SolveLinearInequalities(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " + << args.size(); + } + *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); + }) + .def_packed("arith.SolveInequalitiesToRange", + [](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = SolveInequalitiesToRange(args[0].cast()); + } else if (args.size() == 3) { + auto opt_map = 
args[1].cast>>(); + IntConstraints problem(args[0].cast>(), opt_map.value_or({}), + args[2].cast>()); + *ret = SolveInequalitiesToRange(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " + << args.size(); + } + }) + .def_packed("arith.SolveInequalitiesDeskewRange", [](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = SolveInequalitiesDeskewRange(args[0].cast()); + } else if (args.size() == 3) { + auto opt_map = args[1].cast>>(); + IntConstraints problem(args[0].cast>(), opt_map.value_or({}), args[2].cast>()); - ret_ineq = SolveLinearInequalities(problem); - } else { - LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " - << args.size(); - } - *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); - }); - -TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = SolveInequalitiesToRange(args[0].cast()); - } else if (args.size() == 3) { - auto opt_map = args[1].cast>>(); - IntConstraints problem(args[0].cast>(), opt_map.value_or({}), - args[2].cast>()); - *ret = SolveInequalitiesToRange(problem); - } else { - LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " - << args.size(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = SolveInequalitiesDeskewRange(args[0].cast()); - } else if (args.size() == 3) { - auto opt_map = args[1].cast>>(); - IntConstraints problem(args[0].cast>(), opt_map.value_or({}), - args[2].cast>()); - *ret = SolveInequalitiesDeskewRange(problem); - } else { - LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " - << args.size(); - } - }); + *ret = SolveInequalitiesDeskewRange(problem); + } else { + LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " + << args.size(); + } + }); +}); } // namespace arith } // namespace tvm diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 32d9a623eafa..d9a972f6ec0a 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -23,6 +23,8 @@ #include "graph.h" +#include + #include #include #include @@ -1446,243 +1448,196 @@ TVM_REGISTER_NODE_TYPE(MSCGraphNode); TVM_REGISTER_NODE_TYPE(WeightGraphNode); -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensor") - .set_body_typed([](const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias, - const Array& prims) -> MSCTensor { - return MSCTensor(name, dtype, layout, shape, alias, prims); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorToJson") - .set_body_typed([](const MSCTensor& tensor) -> String { - const auto& tensor_json = tensor->ToJson(); - std::ostringstream os; - dmlc::JSONWriter writer(&os); - tensor_json.Save(&writer); - return os.str(); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorFromJson") - .set_body_typed([](const String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJoint") - .set_body_typed([](Integer index, const String& name, const String& shared_ref, - const String& optype, const Map& attrs, - const Array& scope, const Array& parents, - const Array out_indices, const Array& outputs, - const Map& weights) -> MSCJoint { - std::vector> inputs; - for (size_t i = 0; i < parents.size(); 
i++) { - inputs.push_back(std::make_pair(parents[i], out_indices[i]->value)); - } - return MSCJoint(index->value, name, shared_ref, optype, attrs, scope, inputs, outputs, - weights); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCPrim") - .set_body_typed([](Integer index, const String& name, const String& optype, - const Map& attrs, const Array& parents) -> MSCPrim { - Array b_parents; - for (const auto& p : parents) { - b_parents.push_back(p); - } - return MSCPrim(index->value, name, optype, b_parents, attrs); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJoint") - .set_body_typed([](Integer index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, - const Array parents, const Map& attrs, - const Array& friends) -> WeightJoint { - Array b_parents, b_friends; - for (const auto& p : parents) { - b_parents.push_back(p); - } - for (const auto& f : friends) { - b_friends.push_back(f); - } - return WeightJoint(index->value, name, shared_ref, weight_type, weight, b_parents, attrs, - b_friends); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") - .set_body_typed([](const WeightJoint& node, const String& key, const String& value) { - node->attrs.Set(key, value); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraph") - .set_body_typed([](const String& name, const Array& nodes, - const Array& input_names, const Array& output_names, - const Array& prims) -> MSCGraph { - return MSCGraph(name, nodes, input_names, output_names, prims); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraph") - .set_body_typed([](const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) -> WeightGraph { - return WeightGraph(graph, main_wtypes, relation_wtypes); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("msc.core.MSCTensor", + [](const String& name, const DataType& dtype, const String& layout, + const Array& shape, const String& alias, + const Array& prims) -> MSCTensor { + return MSCTensor(name, dtype, layout, shape, alias, prims); + }) + .def("msc.core.MSCTensorToJson", + [](const MSCTensor& tensor) -> String { + const auto& tensor_json = tensor->ToJson(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + tensor_json.Save(&writer); + return os.str(); + }) + .def("msc.core.MSCTensorFromJson", + [](const String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }) + .def("msc.core.MSCJoint", + [](Integer index, const String& name, const String& shared_ref, const String& optype, + const Map& attrs, const Array& scope, + const Array& parents, const Array out_indices, + const Array& outputs, const Map& weights) -> MSCJoint { + std::vector> inputs; + for (size_t i = 0; i < parents.size(); i++) { + inputs.push_back(std::make_pair(parents[i], out_indices[i]->value)); + } + return MSCJoint(index->value, name, shared_ref, optype, attrs, scope, inputs, outputs, + weights); + }) + .def("msc.core.MSCPrim", + [](Integer index, const String& name, const String& optype, + const Map& attrs, const Array& parents) -> MSCPrim { + Array b_parents; + for (const auto& p : parents) { + b_parents.push_back(p); + } + return MSCPrim(index->value, name, optype, b_parents, attrs); + }) + .def("msc.core.WeightJoint", + [](Integer index, const String& name, const String& shared_ref, + const String& weight_type, const MSCTensor& weight, const Array parents, + const Map& attrs, const Array& friends) -> WeightJoint { + Array b_parents, b_friends; + for (const auto& p : parents) { + 
b_parents.push_back(p); + } + for (const auto& f : friends) { + b_friends.push_back(f); + } + return WeightJoint(index->value, name, shared_ref, weight_type, weight, b_parents, + attrs, b_friends); + }) + .def("msc.core.WeightJointSetAttr", [](const WeightJoint& node, const String& key, + const String& value) { node->attrs.Set(key, value); }) + .def("msc.core.MSCGraph", + [](const String& name, const Array& nodes, const Array& input_names, + const Array& output_names, const Array& prims) -> MSCGraph { + return MSCGraph(name, nodes, input_names, output_names, prims); + }) + .def("msc.core.WeightGraph", + [](const MSCGraph& graph, const Map>& main_wtypes, + const Map& relation_wtypes) -> WeightGraph { + return WeightGraph(graph, main_wtypes, relation_wtypes); + }); +}); // MSC Graph APIS -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") - .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { - return Bool(graph->HasNode(name)); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") - .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { - return graph->FindNode(name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") - .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCPrim { - return graph->FindPrim(name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") - .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { - return Bool(graph->HasTensor(name)); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") - .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCTensor { - return graph->FindTensor(name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias") - .set_body_typed([](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { - tensor->alias = alias; - graph->tensor_alias.Set(alias, tensor->name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer") - .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { - return graph->FindProducer(name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindConsumers") - .set_body_typed([](const MSCGraph& graph, const String& name) -> Array { - return graph->FindConsumers(name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphInputAt") - .set_body_typed([](const MSCGraph& graph, int index) -> MSCTensor { - return graph->InputAt(index); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphOutputAt") - .set_body_typed([](const MSCGraph& graph, int index) -> MSCTensor { - return graph->OutputAt(index); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphGetInputs") - .set_body_typed([](const MSCGraph& graph) -> Array { return graph->GetInputs(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphGetOutputs") - .set_body_typed([](const MSCGraph& graph) -> Array { return graph->GetOutputs(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphToJson") - .set_body_typed([](const MSCGraph& graph) -> String { - const auto& graph_json = graph->ToJson(); - std::ostringstream os; - dmlc::JSONWriter writer(&os); - graph_json.Save(&writer); - return os.str(); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFromJson") - .set_body_typed([](const String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphToPrototxt") - .set_body_typed([](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + 
.def("msc.core.MSCGraphHasNode", + [](const MSCGraph& graph, const String& name) -> Bool { + return Bool(graph->HasNode(name)); + }) + .def("msc.core.MSCGraphFindNode", + [](const MSCGraph& graph, const String& name) -> MSCJoint { + return graph->FindNode(name); + }) + .def("msc.core.MSCGraphFindPrim", + [](const MSCGraph& graph, const String& name) -> MSCPrim { + return graph->FindPrim(name); + }) + .def("msc.core.MSCGraphHasTensor", + [](const MSCGraph& graph, const String& name) -> Bool { + return Bool(graph->HasTensor(name)); + }) + .def("msc.core.MSCGraphFindTensor", + [](const MSCGraph& graph, const String& name) -> MSCTensor { + return graph->FindTensor(name); + }) + .def("msc.core.MSCGraphSetTensorAlias", + [](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { + tensor->alias = alias; + graph->tensor_alias.Set(alias, tensor->name); + }) + .def("msc.core.MSCGraphFindProducer", + [](const MSCGraph& graph, const String& name) -> MSCJoint { + return graph->FindProducer(name); + }) + .def("msc.core.MSCGraphFindConsumers", + [](const MSCGraph& graph, const String& name) -> Array { + return graph->FindConsumers(name); + }) + .def("msc.core.MSCGraphInputAt", + [](const MSCGraph& graph, int index) -> MSCTensor { return graph->InputAt(index); }) + .def("msc.core.MSCGraphOutputAt", + [](const MSCGraph& graph, int index) -> MSCTensor { return graph->OutputAt(index); }) + .def("msc.core.MSCGraphGetInputs", + [](const MSCGraph& graph) -> Array { return graph->GetInputs(); }) + .def("msc.core.MSCGraphGetOutputs", + [](const MSCGraph& graph) -> Array { return graph->GetOutputs(); }) + .def("msc.core.MSCGraphToJson", + [](const MSCGraph& graph) -> String { + const auto& graph_json = graph->ToJson(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + graph_json.Save(&writer); + return os.str(); + }) + .def("msc.core.MSCGraphFromJson", + [](const String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }) + .def("msc.core.MSCGraphToPrototxt", + [](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); +}); // Weight Graph APIS -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphHasNode") - .set_body_typed([](const WeightGraph& graph, const String& name) -> Bool { - return Bool(graph->HasNode(name)); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphFindNode") - .set_body_typed([](const WeightGraph& graph, const String& name) -> WeightJoint { - return graph->FindNode(name); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphToJson") - .set_body_typed([](const WeightGraph& graph) -> String { - const auto& graph_json = graph->ToJson(); - std::ostringstream os; - dmlc::JSONWriter writer(&os); - graph_json.Save(&writer); - return os.str(); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphFromJson") - .set_body_typed([](const String& graph_json) -> WeightGraph { - return WeightGraph(graph_json); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphToPrototxt") - .set_body_typed([](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointInputAt") - .set_body_typed([](const MSCJoint& node, int index) -> MSCTensor { - return node->InputAt(index); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointOutputAt") - .set_body_typed([](const MSCJoint& node, int index) -> MSCTensor { - return node->OutputAt(index); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointWeightAt") - .set_body_typed([](const MSCJoint& node, const String& wtype) -> MSCTensor { - return node->WeightAt(wtype); - }); - 
-TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetInputs") - .set_body_typed([](const MSCJoint& node) -> Array { return node->GetInputs(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetOutputs") - .set_body_typed([](const MSCJoint& node) -> Array { return node->GetOutputs(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetWeights") - .set_body_typed([](const MSCJoint& node) -> Map { return node->weights; }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointHasAttr") - .set_body_typed([](const MSCJoint& node, const String& key) -> Bool { - return Bool(node->HasAttr(key)); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetAttrs") - .set_body_typed([](const MSCJoint& node) -> Map { return node->attrs; }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointHasAttr") - .set_body_typed([](const WeightJoint& node, const String& key) -> Bool { - return Bool(node->HasAttr(key)); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointGetAttrs") - .set_body_typed([](const WeightJoint& node) -> Map { return node->attrs; }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorDTypeName") - .set_body_typed([](const MSCTensor& tensor) -> String { return tensor->DTypeName(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorDimAt") - .set_body_typed([](const MSCTensor& tensor, const String& axis) -> Integer { - return tensor->DimAt(axis); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorGetSize") - .set_body_typed([](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorSetAlias") - .set_body_typed([](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.PruneWeights") - .set_body_typed([](const MSCGraph& graph, - const Map& pruned_tensors) -> MSCGraph { - return PruneWeights(graph, pruned_tensors); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("msc.core.WeightGraphHasNode", + [](const WeightGraph& graph, const String& name) -> Bool { + return Bool(graph->HasNode(name)); + }) + .def("msc.core.WeightGraphFindNode", + [](const WeightGraph& graph, const String& name) -> WeightJoint { + return graph->FindNode(name); + }) + .def("msc.core.WeightGraphToJson", + [](const WeightGraph& graph) -> String { + const auto& graph_json = graph->ToJson(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + graph_json.Save(&writer); + return os.str(); + }) + .def("msc.core.WeightGraphFromJson", + [](const String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }) + .def("msc.core.WeightGraphToPrototxt", + [](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }) + .def("msc.core.MSCJointInputAt", + [](const MSCJoint& node, int index) -> MSCTensor { return node->InputAt(index); }) + .def("msc.core.MSCJointOutputAt", + [](const MSCJoint& node, int index) -> MSCTensor { return node->OutputAt(index); }) + .def("msc.core.MSCJointWeightAt", + [](const MSCJoint& node, const String& wtype) -> MSCTensor { + return node->WeightAt(wtype); + }) + .def("msc.core.MSCJointGetInputs", + [](const MSCJoint& node) -> Array { return node->GetInputs(); }) + .def("msc.core.MSCJointGetOutputs", + [](const MSCJoint& node) -> Array { return node->GetOutputs(); }) + .def("msc.core.MSCJointGetWeights", + [](const MSCJoint& node) -> Map { return node->weights; }) + .def("msc.core.MSCJointHasAttr", + [](const MSCJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }) + .def("msc.core.MSCJointGetAttrs", + [](const MSCJoint& 
node) -> Map { return node->attrs; }) + .def("msc.core.WeightJointHasAttr", + [](const WeightJoint& node, const String& key) -> Bool { + return Bool(node->HasAttr(key)); + }) + .def("msc.core.WeightJointGetAttrs", + [](const WeightJoint& node) -> Map { return node->attrs; }) + .def("msc.core.MSCTensorDTypeName", + [](const MSCTensor& tensor) -> String { return tensor->DTypeName(); }) + .def("msc.core.MSCTensorDimAt", + [](const MSCTensor& tensor, const String& axis) -> Integer { + return tensor->DimAt(axis); + }) + .def("msc.core.MSCTensorGetSize", + [](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }) + .def("msc.core.MSCTensorSetAlias", + [](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; }) + .def("msc.core.PruneWeights", + [](const MSCGraph& graph, const Map& pruned_tensors) -> MSCGraph { + return PruneWeights(graph, pruned_tensors); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 69903b26a9f0..80997c884141 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -23,6 +23,8 @@ #include "graph_builder.h" +#include + #include #include @@ -834,22 +836,24 @@ void WeightsExtractor::VisitExpr_(const CallNode* op) { } } -TVM_FFI_REGISTER_GLOBAL("msc.core.BuildFromRelax") - .set_body_typed([](const IRModule& module, const String& entry_name, - const String& options) -> MSCGraph { - auto builder = GraphBuilder(module, entry_name, options); - const auto& func_name = - builder.config().byoc_entry.size() > 0 ? String(builder.config().byoc_entry) : entry_name; - const auto& func = Downcast(module->Lookup(func_name)); - return builder.Build(func); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.GetRelaxWeights") - .set_body_typed([](const IRModule& module, - const String& entry_name) -> Map { - const auto& func = Downcast(module->Lookup(entry_name)); - return WeightsExtractor(module).GetWeights(func); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("msc.core.BuildFromRelax", + [](const IRModule& module, const String& entry_name, const String& options) -> MSCGraph { + auto builder = GraphBuilder(module, entry_name, options); + const auto& func_name = builder.config().byoc_entry.size() > 0 + ? 
String(builder.config().byoc_entry) + : entry_name; + const auto& func = Downcast(module->Lookup(func_name)); + return builder.Build(func); + }) + .def("msc.core.GetRelaxWeights", + [](const IRModule& module, const String& entry_name) -> Map { + const auto& func = Downcast(module->Lookup(entry_name)); + return WeightsExtractor(module).GetWeights(func); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc index a5df533eb369..6f8515b26c9e 100644 --- a/src/contrib/msc/core/ir/plugin.cc +++ b/src/contrib/msc/core/ir/plugin.cc @@ -23,6 +23,8 @@ #include "plugin.h" +#include + #include #include #include @@ -312,21 +314,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ PluginNode::RegisterReflection(); }); -TVM_FFI_REGISTER_GLOBAL("msc.core.RegisterPlugin") - .set_body_typed([](const String& name, const String& json_str) { - PluginRegistry::Global()->Register(name, json_str); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array { - return ListPluginNames(); -}); - -TVM_FFI_REGISTER_GLOBAL("msc.core.GetPlugin").set_body_typed([](const String& name) -> Plugin { - return GetPlugin(name); -}); - -TVM_FFI_REGISTER_GLOBAL("msc.core.IsPlugin").set_body_typed([](const String& name) -> Bool { - return Bool(IsPlugin(name)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("msc.core.RegisterPlugin", + [](const String& name, const String& json_str) { + PluginRegistry::Global()->Register(name, json_str); + }) + .def("msc.core.ListPluginNames", []() -> Array { return ListPluginNames(); }) + .def("msc.core.GetPlugin", [](const String& name) -> Plugin { return GetPlugin(name); }) + .def("msc.core.IsPlugin", [](const String& name) -> Bool { return Bool(IsPlugin(name)); }); }); } // namespace msc diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 0225ff319097..8136668c26c8 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -154,7 +155,10 @@ Pass BindNamedParams(String func_name, Map params) { return CreateModulePass(pass_func, 0, "BindNamedParams", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.BindNamedParams").set_body_typed(BindNamedParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.BindNamedParams", BindNamedParams); +}); } // namespace transform diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc index b554e08ab820..d870568be63b 100644 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ b/src/contrib/msc/core/transform/bind_shape.cc @@ -22,6 +22,7 @@ * \brief Pass for fuse ShapeExpr. 
*/ +#include #include #include #include @@ -132,7 +133,10 @@ Pass BindShape(const String& entry_name) { return CreateModulePass(pass_func, 0, "BindShape", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.BindShape").set_body_typed(BindShape); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.BindShape", BindShape); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 1eabf3306f36..79a15f15db79 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -22,6 +22,7 @@ * \brief Pass for fuse ShapeExpr. */ +#include #include #include #include @@ -231,7 +232,10 @@ Pass FuseTuple(const String& target, const String& entry_name) { return CreateModulePass(pass_func, 0, "FuseTuple", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseTuple").set_body_typed(FuseTuple); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FuseTuple", FuseTuple); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index a91eb590af26..47eeada86eaa 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -22,6 +22,7 @@ * \brief Pass for inline Exprs. */ +#include #include #include #include @@ -184,7 +185,10 @@ Pass InlineParams(const String& entry_name) { return CreateModulePass(pass_func, 0, "InlineParams", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.InlineParams").set_body_typed(InlineParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.InlineParams", InlineParams); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 4755ebf38960..fc69287d379f 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -22,6 +22,7 @@ * \brief Pass for fuse ShapeExpr. */ +#include #include #include #include @@ -101,7 +102,10 @@ Pass SetBYOCAttrs(const String& target, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SetBYOCAttrs", SetBYOCAttrs); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index dd87e60e7b80..a8691eadd297 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -22,6 +22,7 @@ * \brief Pass for setting layout for expr and constant. 
*/ +#include #include #include #include @@ -1359,7 +1360,10 @@ Pass SetExprLayout(bool allow_missing, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetExprLayout", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SetExprLayout", SetExprLayout); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 4d0cc0314e18..9d9b813caf70 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -22,6 +22,7 @@ * \brief Pass for setting name for call and constant. */ +#include #include #include #include @@ -324,7 +325,10 @@ Pass SetRelaxExprName(const String& entry_name, const String& target, return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SetRelaxExprName", SetRelaxExprName); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index d03f3ba82b28..9f50d836d4b4 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -23,6 +23,8 @@ #include "utils.h" +#include + #include #include namespace tvm { @@ -523,28 +525,26 @@ const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast(GetStructInfo(expr))->dtype; } -TVM_FFI_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); - -TVM_FFI_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); - -TVM_FFI_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") - .set_body_typed([](const String& key, const String& value) -> Span { - return SpanUtils::CreateWithAttr(key, value); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.SpanSetAttr") - .set_body_typed([](const Span& span, const String& key, const String& value) -> Span { - return SpanUtils::SetAttr(span, key, value); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.CompareVersion") - .set_body_typed([](const Array& given_version, - const Array& target_version) -> Integer { - return Integer(CommonUtils::CompareVersion(given_version, target_version)); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.core.ToAttrKey").set_body_typed([](const String& key) -> String { - return CommonUtils::ToAttrKey(key); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("msc.core.SpanGetAttr", SpanUtils::GetAttr) + .def("msc.core.SpanGetAttrs", SpanUtils::GetAttrs) + .def("msc.core.SpanCreateWithAttr", + [](const String& key, const String& value) -> Span { + return SpanUtils::CreateWithAttr(key, value); + }) + .def("msc.core.SpanSetAttr", + [](const Span& span, const String& key, const String& value) -> Span { + return SpanUtils::SetAttr(span, key, value); + }) + .def( + "msc.core.CompareVersion", + [](const Array& given_version, const Array& target_version) -> Integer { + return Integer(CommonUtils::CompareVersion(given_version, target_version)); + }) + .def("msc.core.ToAttrKey", + [](const String& key) -> String { return CommonUtils::ToAttrKey(key); }); }); } // namespace msc diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 
4bceb76d4699..0d55ad997203 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -22,6 +22,8 @@ */ #include "codegen.h" +#include + namespace tvm { namespace contrib { namespace msc { @@ -150,13 +152,16 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") - .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { - TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("msc.framework.tensorflow.GetTensorflowSources", + [](const MSCGraph& graph, const String& codegen_config, + const String& print_config) -> Map { + TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); + codegen.Init(); + return codegen.GetSources(print_config); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 8b85f2e88f04..5f15d12a4499 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -24,6 +24,7 @@ #include "codegen.h" +#include #include #include @@ -574,20 +575,23 @@ const Map TensorRTCodeGen::GetStepCtx() { return step_ctx; } -TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") - .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { - TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); - -TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTRoot").set_body_typed([]() -> String { +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("msc.framework.tensorrt.GetTensorRTSources", + [](const MSCGraph& graph, const String& codegen_config, + const String& print_config) -> Map { + TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); + codegen.Init(); + return codegen.GetSources(print_config); + }) + .def("msc.framework.tensorrt.GetTensorRTRoot", []() -> String { #ifdef TENSORRT_ROOT_DIR - return TENSORRT_ROOT_DIR; + return TENSORRT_ROOT_DIR; #else return ""; #endif + }); }); /*! 
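Aside on the rewrite that the hunks above keep applying: each standalone TVM_FFI_REGISTER_GLOBAL("...").set_body_typed(...) or .set_body_packed(...) registration is folded into a TVM_FFI_STATIC_INIT_BLOCK that chains .def(...) / .def_packed(...) calls on tvm::ffi::reflection::GlobalDef(), and preprocessor branches such as the TENSORRT_ROOT_DIR check above simply move inside the new lambda body. Below is a minimal sketch of that shape, not part of the patch: the example.* global names are hypothetical, and the include path is an assumption because the #include targets are stripped in this diff.

#include <tvm/ffi/reflection/registry.h>  // assumed header path; #include targets are stripped above

// Old style, one macro invocation per global (shown only as a comment):
//   TVM_FFI_REGISTER_GLOBAL("example.Add").set_body_typed([](int a, int b) { return a + b; });

// New style: registrations grouped into one static-init block and chained.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      // typed form (replaces set_body_typed): the signature comes from the lambda
      .def("example.Add", [](int a, int b) { return a + b; })
      // packed form (replaces set_body_packed): untyped args, result written through rv
      .def_packed("example.First", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) {
        *rv = args[0].cast<tvm::String>();
      });
});
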
@@ -618,7 +622,10 @@ Array MSCTensorRTCompiler(Array functions, return compiled_functions; } -TVM_FFI_REGISTER_GLOBAL("relax.ext.msc_tensorrt").set_body_typed(MSCTensorRTCompiler); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ext.msc_tensorrt", MSCTensorRTCompiler); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 67f453268e2a..a847fe5e40c1 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -913,7 +914,10 @@ Pass TransformTensorRT(const String& config) { return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.TransformTensorRT").set_body_typed(TransformTensorRT); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.TransformTensorRT", TransformTensorRT); +}); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 228efa4381ee..ee169c191778 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -22,6 +22,8 @@ */ #include "codegen.h" +#include + namespace tvm { namespace contrib { namespace msc { @@ -151,13 +153,16 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_FFI_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") - .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { - TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("msc.framework.torch.GetTorchSources", + [](const MSCGraph& graph, const String& codegen_config, + const String& print_config) -> Map { + TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); + codegen.Init(); + return codegen.GetSources(print_config); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 53d1bc0562fc..cd80b2e0c724 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -22,6 +22,8 @@ */ #include "codegen.h" +#include + namespace tvm { namespace contrib { namespace msc { @@ -210,13 +212,16 @@ const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_FFI_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") - .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { - RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("msc.framework.tvm.GetRelaxSources", + [](const MSCGraph& graph, const String& codegen_config, + const String& print_config) -> Map { + RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); + codegen.Init(); + return codegen.GetSources(print_config); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc index 
02904c3bd9c8..1c8c2edded46 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.cc +++ b/src/contrib/msc/plugin/tensorrt_codegen.cc @@ -22,6 +22,8 @@ */ #include "tensorrt_codegen.h" +#include + #include namespace tvm { namespace contrib { @@ -883,18 +885,21 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { } } -TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTensorRTPluginSources") - .set_body_typed([](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { - TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); - if (codegen_type == "build") { - return codegen.GetBuildSources(print_config); - } - if (codegen_type == "manager") { - return codegen.GetManagerSources(print_config); - } - return Map(); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("msc.plugin.GetTensorRTPluginSources", + [](const String& codegen_config, const String& print_config, + const String& codegen_type) -> Map { + TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); + if (codegen_type == "build") { + return codegen.GetBuildSources(print_config); + } + if (codegen_type == "manager") { + return codegen.GetManagerSources(print_config); + } + return Map(); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index 59b99f22c7ce..448e8ce55836 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -22,6 +22,8 @@ */ #include "torch_codegen.h" +#include + namespace tvm { namespace contrib { namespace msc { @@ -492,18 +494,21 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi } } -TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTorchPluginSources") - .set_body_typed([](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { - TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); - if (codegen_type == "build") { - return codegen.GetBuildSources(print_config); - } - if (codegen_type == "manager") { - return codegen.GetManagerSources(print_config); - } - return Map(); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("msc.plugin.GetTorchPluginSources", + [](const String& codegen_config, const String& print_config, + const String& codegen_type) -> Map { + TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); + if (codegen_type == "build") { + return codegen.GetBuildSources(print_config); + } + if (codegen_type == "manager") { + return codegen.GetManagerSources(print_config); + } + return Map(); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 610fbc4c3282..b42146b94177 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -22,6 +22,8 @@ */ #include "tvm_codegen.h" +#include + namespace tvm { namespace contrib { namespace msc { @@ -393,18 +395,21 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device } } -TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTVMPluginSources") - .set_body_typed([](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { - TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); - if (codegen_type == "build") { - return codegen.GetBuildSources(print_config); - } - if 
(codegen_type == "manager") { - return codegen.GetManagerSources(print_config); - } - return Map(); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("msc.plugin.GetTVMPluginSources", + [](const String& codegen_config, const String& print_config, + const String& codegen_type) -> Map { + TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); + if (codegen_type == "build") { + return codegen.GetBuildSources(print_config); + } + if (codegen_type == "manager") { + return codegen.GetManagerSources(print_config); + } + return Map(); + }); +}); } // namespace msc } // namespace contrib diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 26a348bceee1..d1bd278a2c53 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -21,6 +21,7 @@ * \file src/ir/analysis.cc * \brief Analysis functions that must span multiple IR types */ +#include #include #include "../support/ordered_set.h" @@ -43,7 +44,10 @@ Map> CollectCallMap(const IRModule& mod) { return call_map; } -TVM_FFI_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.analysis.CollectCallMap", CollectCallMap); +}); } // namespace ir } // namespace tvm diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 1b6a2896cccb..61c31c616951 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -22,6 +22,7 @@ * \brief Utility transformation that applies an inner pass to a subset of an IRModule */ #include +#include #include #include #include @@ -129,7 +130,10 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, return CreateModulePass(pass_func, 0, pass_name, {}); } -TVM_FFI_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("transform.ApplyPassToFunction", ApplyPassToFunction); +}); } // namespace transform } // namespace tvm diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index ff19cc0c03e9..37c08ccac00e 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -75,8 +75,9 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); TVM_FFI_STATIC_INIT_BLOCK({ tvm::ffi::reflection::ObjectDef(); }); -TVM_FFI_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { - return attrs->dict; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.DictAttrsGetDict", [](DictAttrs attrs) { return attrs->dict; }); }); } // namespace tvm diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index c75bf0c7361a..7ca61ff20196 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -21,6 +21,7 @@ * \file src/ir/diagnostic.cc * \brief Implementation of DiagnosticContext and friends. 
*/ +#include #include #include @@ -39,10 +40,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* Diagnostic */ TVM_REGISTER_NODE_TYPE(DiagnosticNode); -TVM_FFI_REGISTER_GLOBAL("diagnostics.Diagnostic") - .set_body_typed([](int level, Span span, String message) { - return Diagnostic(static_cast(level), span, message); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("diagnostics.Diagnostic", [](int level, Span span, String message) { + return Diagnostic(static_cast(level), span, message); + }); +}); Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message) { auto n = make_object(); @@ -112,10 +115,13 @@ TVM_DLL DiagnosticRenderer::DiagnosticRenderer( data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") - .set_body_typed([](ffi::TypedFunction renderer) { - return DiagnosticRenderer(renderer); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("diagnostics.DiagnosticRenderer", + [](ffi::TypedFunction renderer) { + return DiagnosticRenderer(renderer); + }); +}); /* Diagnostic Context */ TVM_REGISTER_NODE_TYPE(DiagnosticContextNode); @@ -140,10 +146,12 @@ void DiagnosticContext::Render() { } } -TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") - .set_body_typed([](DiagnosticRenderer renderer, DiagnosticContext ctx) { - renderer.Render(ctx); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "diagnostics.DiagnosticRendererRender", + [](DiagnosticRenderer renderer, DiagnosticContext ctx) { renderer.Render(ctx); }); +}); DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function"; @@ -153,23 +161,27 @@ DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRen data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticContext") - .set_body_typed([](const IRModule& module, const DiagnosticRenderer& renderer) { - return DiagnosticContext(module, renderer); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("diagnostics.DiagnosticContext", + [](const IRModule& module, const DiagnosticRenderer& renderer) { + return DiagnosticContext(module, renderer); + }); +}); /*! \brief Emit a diagnostic. */ void DiagnosticContext::Emit(const Diagnostic& diagnostic) { (*this)->diagnostics.push_back(diagnostic); } -TVM_FFI_REGISTER_GLOBAL("diagnostics.Emit") - .set_body_typed([](DiagnosticContext ctx, const Diagnostic& diagnostic) { - return ctx.Emit(diagnostic); - }); - -TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender") - .set_body_typed([](DiagnosticContext context) { return context.Render(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("diagnostics.Emit", + [](DiagnosticContext ctx, const Diagnostic& diagnostic) { return ctx.Emit(diagnostic); }) + .def("diagnostics.DiagnosticContextRender", + [](DiagnosticContext context) { return context.Render(); }); +}); /*! \brief Emit a diagnostic. 
*/ void DiagnosticContext::EmitFatal(const Diagnostic& diagnostic) { @@ -201,8 +213,10 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } -TVM_FFI_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { - return DiagnosticContext::Default(module); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("diagnostics.Default", + [](const IRModule& module) { return DiagnosticContext::Default(module); }); }); std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level, @@ -317,14 +331,13 @@ DiagnosticRenderer TerminalRenderer(std::ostream& out) { }); } -TVM_FFI_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { - return TerminalRenderer(std::cerr); -}); - -TVM_FFI_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); }); - -TVM_FFI_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() { - tvm::ffi::Function::RemoveGlobal(OVERRIDE_RENDERER); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def(DEFAULT_RENDERER, []() { return TerminalRenderer(std::cerr); }) + .def("diagnostics.GetRenderer", []() { return GetRenderer(); }) + .def("diagnostics.ClearRenderer", + []() { tvm::ffi::Function::RemoveGlobal(OVERRIDE_RENDERER); }); }); } // namespace tvm diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index e95b44700619..2bb2588da2b5 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -21,6 +21,7 @@ * \file env_func.cc */ #include +#include #include #include @@ -49,16 +50,17 @@ ObjectPtr CreateEnvNode(const std::string& name) { EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); - -TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - EnvFunc env = args[0].cast(); - ICHECK_GE(args.size(), 1); - env->func.CallPacked(args.Slice(1), rv); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncGetFunction").set_body_typed([](const EnvFunc& n) { - return n->func; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.EnvFuncGet", EnvFunc::Get) + .def_packed("ir.EnvFuncCall", + [](ffi::PackedArgs args, ffi::Any* rv) { + EnvFunc env = args[0].cast(); + ICHECK_GE(args.size(), 1); + env->func.CallPacked(args.Slice(1), rv); + }) + .def("ir.EnvFuncGetFunction", [](const EnvFunc& n) { return n->func; }); }); TVM_REGISTER_NODE_TYPE(EnvFuncNode) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 3e2f867e0897..522edcbe181e 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -77,8 +78,11 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) { - return IntImm(dtype, value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.IntImm", [](DataType dtype, int64_t value, Span span) { + return IntImm(dtype, value, span); + }); }); TVM_REGISTER_NODE_TYPE(IntImmNode); @@ -179,8 +183,11 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) { - return FloatImm(dtype, value, span); 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.FloatImm", [](DataType dtype, double value, Span span) { + return FloatImm(dtype, value, span); + }); }); TVM_REGISTER_NODE_TYPE(FloatImmNode); @@ -192,16 +199,18 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { return Range(make_object(min, extent, span)); } -TVM_FFI_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); - -TVM_FFI_REGISTER_GLOBAL("ir.Range") - .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { - if (end.defined()) { - return Range(begin, end.value(), span); - } else { - return Range(IntImm(begin->dtype, 0), begin, span); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.Range_from_min_extent", Range::FromMinExtent) + .def("ir.Range", [](PrimExpr begin, Optional end, Span span) -> Range { + if (end.defined()) { + return Range(begin, end.value(), span); + } else { + return Range(IntImm(begin->dtype, 0), begin, span); + } + }); +}); TVM_REGISTER_NODE_TYPE(RangeNode); @@ -214,12 +223,15 @@ GlobalVar::GlobalVar(String name_hint, Span span) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); - -TVM_FFI_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { - std::stringstream ss; - ss << ref; - return ss.str(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.GlobalVar", [](String name) { return GlobalVar(name); }) + .def("ir.DebugPrint", [](ObjectRef ref) { + std::stringstream ss; + ss << ref; + return ss.str(); + }); }); } // namespace tvm diff --git a/src/ir/function.cc b/src/ir/function.cc index 66d66e3c8133..75be4619a537 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -22,6 +22,7 @@ * \brief The function data structure. 
*/ #include +#include #include #include #include @@ -29,51 +30,47 @@ namespace tvm { -TVM_FFI_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { - return func->attrs; -}); - -TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); - -TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithAttr") - .set_body_typed([](ffi::RValueRef func_ref, String key, Any value) -> BaseFunc { - BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") - .set_body_typed([](ffi::RValueRef func_ref, - Map attr_map) -> BaseFunc { - BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithAttrs(Downcast(std::move(func)), attr_map); - } - if (const auto f = tvm::ffi::Function::GetGlobal("relax.FuncWithAttrs")) { - if (auto ret = (*f)(func, attr_map).cast>()) { - return ret.value(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.BaseFunc_Attrs", [](BaseFunc func) { return func->attrs; }) + .def("ir.BaseFuncCopy", [](BaseFunc func) { return func; }) + .def("ir.BaseFuncWithAttr", + [](ffi::RValueRef func_ref, String key, Any value) -> BaseFunc { + BaseFunc func = *std::move(func_ref); + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + } + }) + .def("ir.BaseFuncWithAttrs", + [](ffi::RValueRef func_ref, Map attr_map) -> BaseFunc { + BaseFunc func = *std::move(func_ref); + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } + if (const auto f = tvm::ffi::Function::GetGlobal("relax.FuncWithAttrs")) { + if (auto ret = (*f)(func, attr_map).cast>()) { + return ret.value(); + } + } + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_UNREACHABLE(); + }) + .def("ir.BaseFuncWithoutAttr", [](ffi::RValueRef func_ref, String key) -> BaseFunc { + BaseFunc func = *std::move(func_ref); + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_UNREACHABLE(); } - } - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - TVM_FFI_UNREACHABLE(); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") - .set_body_typed([](ffi::RValueRef func_ref, String key) -> BaseFunc { - BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } else if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - TVM_FFI_UNREACHABLE(); - } - }); + }); +}); } // namespace tvm diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index fb04b53964f1..e1c187cf606c 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -22,6 +22,7 @@ * \brief Module global info. 
*/ +#include #include namespace tvm { @@ -31,9 +32,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); -TVM_FFI_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { - auto n = DummyGlobalInfo(make_object()); - return n; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.DummyGlobalInfo", []() { + auto n = DummyGlobalInfo(make_object()); + return n; + }); }); VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { @@ -45,8 +49,10 @@ VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { } TVM_REGISTER_NODE_TYPE(VDeviceNode); -TVM_FFI_REGISTER_GLOBAL("ir.VDevice") - .set_body_typed([](Target tgt, int dev_id, MemoryScope mem_scope) { - return VDevice(tgt, dev_id, mem_scope); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.VDevice", [](Target tgt, int dev_id, MemoryScope mem_scope) { + return VDevice(tgt, dev_id, mem_scope); + }); +}); } // namespace tvm diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 901202e81e67..b7a00655e0a7 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -24,6 +24,7 @@ #include "tvm/ir/global_var_supply.h" #include +#include #include @@ -95,23 +96,18 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode); -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply") - .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }); - -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) { - return GlobalVarSupply(std::move(mod)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.GlobalVarSupply_NameSupply", + [](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }) + .def("ir.GlobalVarSupply_IRModule", + [](IRModule mod) { return GlobalVarSupply(std::move(mod)); }) + .def("ir.GlobalVarSupply_IRModules", + [](const Array& mods) { return GlobalVarSupply(mods); }) + .def_method("ir.GlobalVarSupply_FreshGlobal", &GlobalVarSupplyNode::FreshGlobal) + .def_method("ir.GlobalVarSupply_UniqueGlobalFor", &GlobalVarSupplyNode::UniqueGlobalFor) + .def_method("ir.GlobalVarSupply_ReserveGlobalVar", &GlobalVarSupplyNode::ReserveGlobalVar); }); -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules") - .set_body_typed([](const Array& mods) { return GlobalVarSupply(mods); }); - -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal") - .set_body_method(&GlobalVarSupplyNode::FreshGlobal); - -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor") - .set_body_method(&GlobalVarSupplyNode::UniqueGlobalFor); - -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar") - .set_body_method(&GlobalVarSupplyNode::ReserveGlobalVar); - } // namespace tvm diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index cd52e2b88680..03897cb6db5c 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -177,16 +178,19 @@ void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); -TVM_FFI_REGISTER_GLOBAL("instrument.PassInstrument") - .set_body_typed( - [](String name, ffi::TypedFunction enter_pass_ctx, - ffi::TypedFunction exit_pass_ctx, - ffi::TypedFunction should_run, - ffi::TypedFunction run_before_pass, - ffi::TypedFunction run_after_pass) { - 
return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, - run_before_pass, run_after_pass); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "instrument.PassInstrument", + [](String name, ffi::TypedFunction enter_pass_ctx, + ffi::TypedFunction exit_pass_ctx, + ffi::TypedFunction should_run, + ffi::TypedFunction run_before_pass, + ffi::TypedFunction run_after_pass) { + return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, run_before_pass, + run_after_pass); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -310,23 +314,26 @@ String RenderPassProfiles() { return os.str(); } -TVM_FFI_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); - -TVM_FFI_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { - auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { - PassProfile::EnterPass(pass_info->name); - return true; - }; - - auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) { - PassProfile::ExitPass(); - }; - - auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; - - return BasePassInstrument("PassTimingInstrument", - /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr, - run_before_pass, run_after_pass); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("instrument.RenderTimePassProfiles", RenderPassProfiles) + .def("instrument.MakePassTimingInstrument", []() { + auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::EnterPass(pass_info->name); + return true; + }; + + auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::ExitPass(); + }; + + auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; + + return BasePassInstrument("PassTimingInstrument", + /* enter_pass_ctx */ nullptr, exit_pass_ctx, + /* should_run */ nullptr, run_before_pass, run_after_pass); + }); }); } // namespace instrument diff --git a/src/ir/module.cc b/src/ir/module.cc index 91db645b712a..6b513805e208 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -244,116 +245,93 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr, TVM_REGISTER_NODE_TYPE(IRModuleNode); -TVM_FFI_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, tvm::ObjectRef attrs, - Map> global_infos) { - auto dict_attrs = [&attrs]() { - if (!attrs.defined()) { - return DictAttrs(); - } else if (auto* as_dict_attrs = attrs.as()) { - return GetRef(as_dict_attrs); - } else if (attrs.as()) { - return tvm::DictAttrs(Downcast>(attrs)); - } else { - LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map"; - } - }(); - - return IRModule(funcs, {}, dict_attrs, global_infos); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule { - IRModule clone = mod; - clone.CopyOnWrite(); - return clone; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.IRModule", + [](tvm::Map funcs, tvm::ObjectRef attrs, + Map> global_infos) { + auto dict_attrs = [&attrs]() { + if (!attrs.defined()) { + return DictAttrs(); + } else if (auto* as_dict_attrs = attrs.as()) { + return GetRef(as_dict_attrs); + } else if 
(attrs.as()) { + return tvm::DictAttrs(Downcast>(attrs)); + } else { + LOG(FATAL) + << "Expected attrs argument to be either DictAttrs or Map"; + } + }(); + + return IRModule(funcs, {}, dict_attrs, global_infos); + }) + .def("ir.Module_Clone", + [](IRModule mod) -> IRModule { + IRModule clone = mod; + clone.CopyOnWrite(); + return clone; + }) + .def("ir.Module_Add", + [](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { + ICHECK(val->IsInstance()); + mod->Add(var, Downcast(val), update); + return mod; + }) + .def("ir.Module_Remove", + [](IRModule mod, Variant var) -> IRModule { + GlobalVar gvar = [&]() { + if (auto opt = var.as()) { + return opt.value(); + } else if (auto opt = var.as()) { + return mod->GetGlobalVar(opt.value()); + } else { + LOG(FATAL) << "InternalError: " + << "Variant didn't contain any of the allowed types"; + } + }(); + mod->Remove(gvar); + return mod; + }) + .def("ir.Module_Contains", + [](IRModule mod, Variant var) -> bool { + if (auto opt = var.as()) { + return mod->functions.count(opt.value()); + } else if (auto opt = var.as()) { + return mod->global_var_map_.count(opt.value()); + } else { + LOG(FATAL) << "InternalError: " + << "Variant didn't contain any of the allowed types"; + } + }) + .def_method("ir.Module_GetGlobalVar", &IRModuleNode::GetGlobalVar) + .def_method("ir.Module_GetGlobalVars", &IRModuleNode::GetGlobalVars) + .def_method("ir.Module_ContainGlobalVar", &IRModuleNode::ContainGlobalVar) + .def("ir.Module_Lookup", [](IRModule mod, GlobalVar var) { return mod->Lookup(var); }) + .def("ir.Module_Lookup_str", [](IRModule mod, String var) { return mod->Lookup(var); }) + .def("ir.Module_FromExpr", &IRModule::FromExpr) + .def("ir.Module_Update", [](IRModule mod, IRModule from) { mod->Update(from); }) + .def("ir.Module_UpdateFunction", + [](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }) + .def("ir.Module_UpdateGlobalInfo", + [](IRModule mod, String name, Array global_info) { + mod->UpdateGlobalInfo(name, global_info); + }) + .def("ir.Module_GetAttrs", [](IRModule mod) -> ObjectRef { return mod->GetAttrs(); }) + .def("ir.Module_WithAttr", + [](ffi::RValueRef mod, String key, ffi::Any value) -> IRModule { + return WithAttr(*std::move(mod), key, value); + }) + .def("ir.Module_WithoutAttr", + [](ffi::RValueRef mod, String key) -> IRModule { + return WithoutAttr(*std::move(mod), key); + }) + .def("ir.Module_WithAttrs", + [](ffi::RValueRef mod, Map attr_map) -> IRModule { + return WithAttrs(*std::move(mod), attr_map); + }) + .def("ir.Module_GetAttr", + [](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); }); -TVM_FFI_REGISTER_GLOBAL("ir.Module_Add") - .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { - ICHECK(val->IsInstance()); - mod->Add(var, Downcast(val), update); - return mod; - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_Remove") - .set_body_typed([](IRModule mod, Variant var) -> IRModule { - GlobalVar gvar = [&]() { - if (auto opt = var.as()) { - return opt.value(); - } else if (auto opt = var.as()) { - return mod->GetGlobalVar(opt.value()); - } else { - LOG(FATAL) << "InternalError: " - << "Variant didn't contain any of the allowed types"; - } - }(); - mod->Remove(gvar); - return mod; - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_Contains") - .set_body_typed([](IRModule mod, Variant var) -> bool { - if (auto opt = var.as()) { - return mod->functions.count(opt.value()); - } else if (auto opt = var.as()) { - return mod->global_var_map_.count(opt.value()); - } 
else { - LOG(FATAL) << "InternalError: " - << "Variant didn't contain any of the allowed types"; - } - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_GetGlobalVar").set_body_method(&IRModuleNode::GetGlobalVar); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_GetGlobalVars").set_body_method(&IRModuleNode::GetGlobalVars); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") - .set_body_method(&IRModuleNode::ContainGlobalVar); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { - return mod->Lookup(var); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { - return mod->Lookup(var); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { - mod->Update(from); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_UpdateFunction") - .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") - .set_body_typed([](IRModule mod, String name, Array global_info) { - mod->UpdateGlobalInfo(name, global_info); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { - return mod->GetAttrs(); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_WithAttr") - .set_body_typed([](ffi::RValueRef mod, String key, ffi::Any value) -> IRModule { - return WithAttr(*std::move(mod), key, value); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_WithoutAttr") - .set_body_typed([](ffi::RValueRef mod, String key) -> IRModule { - return WithoutAttr(*std::move(mod), key); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_WithAttrs") - .set_body_typed([](ffi::RValueRef mod, Map attr_map) -> IRModule { - return WithAttrs(*std::move(mod), attr_map); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.Module_GetAttr") - .set_body_typed([](IRModule mod, String key) -> ObjectRef { - return mod->GetAttr(key); - }); - } // namespace tvm diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index e73b0e63e3d0..692a2720891f 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -24,6 +24,7 @@ #include "tvm/ir/name_supply.h" #include +#include #include @@ -92,15 +93,13 @@ std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) TVM_REGISTER_NODE_TYPE(NameSupplyNode); -TVM_FFI_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) { - return NameSupply(prefix); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.NameSupply", [](String prefix) { return NameSupply(prefix); }) + .def_method("ir.NameSupply_FreshName", &NameSupplyNode::FreshName) + .def_method("ir.NameSupply_ReserveName", &NameSupplyNode::ReserveName) + .def_method("ir.NameSupply_ContainsName", &NameSupplyNode::ContainsName); }); -TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_FreshName").set_body_method(&NameSupplyNode::FreshName); - -TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_ReserveName").set_body_method(&NameSupplyNode::ReserveName); - -TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_ContainsName") - .set_body_method(&NameSupplyNode::ContainsName); - } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index 0442a1038b65..a07ddfa51b0b 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -22,6 +22,7 @@ * \brief Primitive operators and intrinsics. 
*/ #include +#include #include #include #include @@ -77,83 +78,79 @@ void OpRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { } // Frontend APIs -TVM_FFI_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { - return OpRegistry::Global()->ListAllNames(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.ListOpNames", []() { return OpRegistry::Global()->ListAllNames(); }) + .def("ir.GetOp", [](String name) -> Op { return Op::Get(name); }) + .def("ir.OpGetAttr", + [](Op op, String attr_name) -> ffi::Any { + auto op_map = Op::GetAttrMap(attr_name); + ffi::Any rv; + if (op_map.count(op)) { + rv = op_map[op]; + } + return rv; + }) + .def("ir.OpHasAttr", + [](Op op, String attr_name) -> bool { return Op::HasAttrMap(attr_name); }) + .def("ir.OpSetAttr", + [](Op op, String attr_name, ffi::AnyView value, int plevel) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_attr(attr_name, value, plevel); + }) + .def("ir.OpResetAttr", + [](Op op, String attr_name) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); + reg.reset_attr(attr_name); + }) + .def("ir.RegisterOp", + [](String op_name, String descr) { + const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); + ICHECK(reg == nullptr) + << "AttributeError: Operator " << op_name << " is registered before"; + auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + op.describe(descr); + }) + .def("ir.OpAddArgument", + [](Op op, String name, String type, String description) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.add_argument(name, type, description); + }) + .def("ir.OpSetSupportLevel", + [](Op op, int level) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_support_level(level); + }) + .def("ir.OpSetNumInputs", + [](Op op, int n) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_num_inputs(n); + }) + .def("ir.OpSetAttrsTypeKey", + [](Op op, String key) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_attrs_type_key(key); + }) + .def("ir.RegisterOpAttr", + [](String op_name, String attr_key, ffi::AnyView value, int plevel) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value.cast()); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + LOG(FATAL) << "attrs type key no longer supported"; + } else { + reg.set_attr(attr_key, value, plevel); + } + }) + .def("ir.RegisterOpLowerIntrinsic", + [](String name, ffi::Function f, String target, int plevel) { + tvm::OpRegEntry::RegisterOrGet(name).set_attr( + target + ".FLowerIntrinsic", f, plevel); + }); }); -TVM_FFI_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); - -TVM_FFI_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> ffi::Any { - auto op_map = Op::GetAttrMap(attr_name); - ffi::Any rv; - if (op_map.count(op)) { - rv = op_map[op]; - } - return rv; -}); - -TVM_FFI_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool { - return Op::HasAttrMap(attr_name); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.OpSetAttr") - .set_body_typed([](Op op, String attr_name, ffi::AnyView value, int plevel) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.set_attr(attr_name, 
value, plevel); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); - reg.reset_attr(attr_name); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { - const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); - ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; - auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); - op.describe(descr); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.OpAddArgument") - .set_body_typed([](Op op, String name, String type, String description) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.add_argument(name, type, description); - }); - -TVM_FFI_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.set_support_level(level); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.set_num_inputs(n); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.set_attrs_type_key(key); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.RegisterOpAttr") - .set_body_typed([](String op_name, String attr_key, ffi::AnyView value, int plevel) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value.cast()); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - LOG(FATAL) << "attrs type key no longer supported"; - } else { - reg.set_attr(attr_key, value, plevel); - } - }); - -TVM_FFI_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") - .set_body_typed([](String name, ffi::Function f, String target, int plevel) { - tvm::OpRegEntry::RegisterOrGet(name).set_attr(target + ".FLowerIntrinsic", f, - plevel); - }); - ObjectPtr CreateOp(const std::string& name) { // Hack use ffi::Any as exchange auto op = Op::Get(name); diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc index 0dca97302470..65fe02e833c3 100644 --- a/src/ir/replace_global_vars.cc +++ b/src/ir/replace_global_vars.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -62,7 +63,10 @@ IRModule ReplaceGlobalVars(IRModule mod, Map replacements) return mod; } -TVM_FFI_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("transform.ReplaceGlobalVars", ReplaceGlobalVars); +}); IRModule ModuleReplaceGlobalVars( IRModule mod, Map, Variant> replacements) { @@ -93,7 +97,10 @@ IRModule ModuleReplaceGlobalVars( return ReplaceGlobalVars(mod, gvar_replacements); } -TVM_FFI_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.Module_ReplaceGlobalVars", ModuleReplaceGlobalVars); +}); } // namespace transform } // namespace tvm diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 95b9e83f77a3..46801b255902 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -21,6 +21,7 @@ * \brief The implementation of the source map data structure. 
*/ #include +#include #include #include @@ -58,7 +59,10 @@ ObjectPtr GetSourceNameNodeByStr(const std::string& name) { SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } -TVM_FFI_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.SourceName", SourceName::Get); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -137,13 +141,14 @@ SequentialSpan::SequentialSpan(std::initializer_list init) { TVM_REGISTER_NODE_TYPE(SequentialSpanNode); -TVM_FFI_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, - int column, int end_column) { - return Span(source_name, line, end_line, column, end_column); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array spans) { - return SequentialSpan(spans); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.Span", + [](SourceName source_name, int line, int end_line, int column, int end_column) { + return Span(source_name, line, end_line, column, end_column); + }) + .def("ir.SequentialSpan", [](tvm::Array spans) { return SequentialSpan(spans); }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -226,12 +231,14 @@ SourceMap::SourceMap(Map source_map) { void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } -TVM_FFI_REGISTER_GLOBAL("SourceMapAdd") - .set_body_typed([](SourceMap map, String name, String content) { - auto src_name = SourceName::Get(name); - Source source(src_name, content); - map.Add(source); - return src_name; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("SourceMapAdd", [](SourceMap map, String name, String content) { + auto src_name = SourceName::Get(name); + Source source(src_name, content); + map.Add(source); + return src_name; + }); +}); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 9e1c95d7f624..ea49cf5c59be 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -500,14 +501,17 @@ Pass CreateModulePass(std::function pass_func, TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_FFI_REGISTER_GLOBAL("transform.PassInfo") - .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { - return PassInfo(opt_level, name, required, traceable); - }); - -TVM_FFI_REGISTER_GLOBAL("transform.Info").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - Pass pass = args[0].cast(); - *ret = pass->Info(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("transform.PassInfo", + [](int opt_level, String name, tvm::Array required, bool traceable) { + return PassInfo(opt_level, name, required, traceable); + }) + .def_packed("transform.Info", [](ffi::PackedArgs args, ffi::Any* ret) { + Pass pass = args[0].cast(); + *ret = pass->Info(); + }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -537,18 +541,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_FFI_REGISTER_GLOBAL("transform.MakeModulePass") - .set_body_typed( - [](ffi::TypedFunction, PassContext)> pass_func, - PassInfo pass_info) { - auto wrapped_pass_func = [pass_func](IRModule mod, PassContext ctx) { - return pass_func(ffi::RValueRef(std::move(mod)), ctx); - }; - return 
ModulePass(wrapped_pass_func, pass_info); - }); - -TVM_FFI_REGISTER_GLOBAL("transform.RunPass") - .set_body_typed([](Pass pass, ffi::RValueRef mod) { return pass(*std::move(mod)); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("transform.MakeModulePass", + [](ffi::TypedFunction, PassContext)> pass_func, + PassInfo pass_info) { + auto wrapped_pass_func = [pass_func](IRModule mod, PassContext ctx) { + return pass_func(ffi::RValueRef(std::move(mod)), ctx); + }; + return ModulePass(wrapped_pass_func, pass_info); + }) + .def("transform.RunPass", + [](Pass pass, ffi::RValueRef mod) { return pass(*std::move(mod)); }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -560,16 +566,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_FFI_REGISTER_GLOBAL("transform.Sequential") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto passes = args[0].cast>(); - int opt_level = args[1].cast(); - std::string name = args[2].cast(); - auto required = args[3].cast>(); - bool traceable = args[4].cast(); - PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); - *ret = Sequential(passes, pass_info); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("transform.Sequential", [](ffi::PackedArgs args, ffi::Any* ret) { + auto passes = args[0].cast>(); + int opt_level = args[1].cast(); + std::string name = args[2].cast(); + auto required = args[3].cast>(); + bool traceable = args[4].cast(); + PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); + *ret = Sequential(passes, pass_info); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -587,23 +595,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_FFI_REGISTER_GLOBAL("transform.PassContext") - .set_body_typed([](int opt_level, Array required, Array disabled, - Array instruments, - Optional> config) { - auto pctx = PassContext::Create(); - pctx->opt_level = opt_level; - - pctx->required_pass = std::move(required); - pctx->disabled_pass = std::move(disabled); - pctx->instruments = std::move(instruments); - - if (config.defined()) { - pctx->config = config.value(); - } - PassConfigManager::Global()->Legalize(&(pctx->config)); - return pctx; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "transform.PassContext", + [](int opt_level, Array required, Array disabled, + Array instruments, Optional> config) { + auto pctx = PassContext::Create(); + pctx->opt_level = opt_level; + + pctx->required_pass = std::move(required); + pctx->disabled_pass = std::move(disabled); + pctx->instruments = std::move(instruments); + + if (config.defined()) { + pctx->config = config.value(); + } + PassConfigManager::Global()->Legalize(&(pctx->config)); + return pctx; + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -626,20 +637,19 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); - -TVM_FFI_REGISTER_GLOBAL("transform.EnterPassContext") - .set_body_typed(PassContext::Internal::EnterScope); - 
-TVM_FFI_REGISTER_GLOBAL("transform.ExitPassContext") - .set_body_typed(PassContext::Internal::ExitScope); - -TVM_FFI_REGISTER_GLOBAL("transform.OverrideInstruments") - .set_body_typed([](PassContext pass_ctx, Array instruments) { - pass_ctx.InstrumentExitPassContext(); - pass_ctx->instruments = instruments; - pass_ctx.InstrumentEnterPassContext(); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("transform.GetCurrentPassContext", PassContext::Current) + .def("transform.EnterPassContext", PassContext::Internal::EnterScope) + .def("transform.ExitPassContext", PassContext::Internal::ExitScope) + .def("transform.OverrideInstruments", + [](PassContext pass_ctx, Array instruments) { + pass_ctx.InstrumentExitPassContext(); + pass_ctx->instruments = instruments; + pass_ctx.InstrumentEnterPassContext(); + }); +}); Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { @@ -649,9 +659,12 @@ Pass PrintIR(String header, bool show_meta_data) { return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } -TVM_FFI_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); - -TVM_FFI_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("transform.PrintIR", PrintIR) + .def("transform.ListConfigs", PassContext::ListConfigs); +}); } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 83cbd962404a..37bfd540c5dd 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -22,6 +22,7 @@ * \brief Common type system AST nodes throughout the IR. */ #include +#include #include namespace tvm { @@ -42,8 +43,9 @@ PrimType::PrimType(runtime::DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_FFI_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { - return PrimType(dtype); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.PrimType", [](runtime::DataType dtype) { return PrimType(dtype); }); }); PointerType::PointerType(Type element_type, String storage_scope) { @@ -55,10 +57,12 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); -TVM_FFI_REGISTER_GLOBAL("ir.PointerType") - .set_body_typed([](Type element_type, String storage_scope = "") { - return PointerType(element_type, storage_scope); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.PointerType", [](Type element_type, String storage_scope = "") { + return PointerType(element_type, storage_scope); + }); +}); FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { ObjectPtr n = make_object(); @@ -70,10 +74,12 @@ FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { TVM_REGISTER_NODE_TYPE(FuncTypeNode); -TVM_FFI_REGISTER_GLOBAL("ir.FuncType") - .set_body_typed([](tvm::Array arg_types, Type ret_type) { - return FuncType(arg_types, ret_type); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ir.FuncType", [](tvm::Array arg_types, Type ret_type) { + return FuncType(arg_types, ret_type); + }); +}); TupleType::TupleType(Array fields, Span span) { ObjectPtr n = make_object(); @@ -86,12 +92,11 @@ TupleType TupleType::Empty() { return TupleType(Array()); } 
TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { - return TupleType(fields); -}); - -TVM_FFI_REGISTER_GLOBAL("ir.TensorMapType").set_body_typed([](Span span) { - return TensorMapType(span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ir.TupleType", [](Array fields) { return TupleType(fields); }) + .def("ir.TensorMapType", [](Span span) { return TensorMapType(span); }); }); TensorMapType::TensorMapType(Span span) { diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index eb1e52a17d2c..34206a82846a 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -163,15 +165,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ TensorInfoNode::RegisterReflection(); }); TVM_REGISTER_OBJECT_TYPE(ArgInfoNode); TVM_REGISTER_NODE_TYPE(TensorInfoNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc") - .set_body_typed(ArgInfo::FromEntryFunc); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TensorInfo") - .set_body_typed([](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { - return TensorInfo(dtype, shape); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.ArgInfoAsJSON", &ArgInfoNode::AsJSON) + .def("meta_schedule.ArgInfoFromPrimFunc", ArgInfo::FromPrimFunc) + .def("meta_schedule.ArgInfoFromEntryFunc", ArgInfo::FromEntryFunc) + .def("meta_schedule.ArgInfoFromJSON", ArgInfo::FromJSON) + .def("meta_schedule.TensorInfo", [](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { + return TensorInfo(dtype, shape); + }); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index 68c5f4c9c1ab..9623c5002862 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -58,21 +60,19 @@ TVM_REGISTER_NODE_TYPE(BuilderResultNode); TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_NODE_TYPE(PyBuilderNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderInput") - .set_body_typed([](IRModule mod, Target target, - Optional> params) -> BuilderInput { - return BuilderInput(mod, target, params); - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderResult") - .set_body_typed([](Optional artifact_path, - Optional error_msg) -> BuilderResult { - return BuilderResult(artifact_path, error_msg); - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method(&BuilderNode::Build); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.BuilderInput", + [](IRModule mod, Target target, Optional> params) + -> BuilderInput { return BuilderInput(mod, target, params); }) + .def("meta_schedule.BuilderResult", + [](Optional artifact_path, Optional error_msg) -> BuilderResult { + return BuilderResult(artifact_path, error_msg); + }) + .def_method("meta_schedule.BuilderBuild", &BuilderNode::Build) + .def("meta_schedule.BuilderPyBuilder", Builder::PyBuilder); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 5c1c7a568580..0ce259bc6c46 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -71,19 +73,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_NODE_TYPE(PyCostModelNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelUpdate").set_body_method(&CostModelNode::Update); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelPredict") - .set_body_typed([](CostModel model, // - const TuneContext& context, // - Array candidates, // - void* p_addr) -> void { - std::vector result = model->Predict(context, candidates); - std::copy(result.begin(), result.end(), static_cast(p_addr)); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel") - .set_body_typed(CostModel::PyCostModel); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.CostModelLoad", &CostModelNode::Load) + .def_method("meta_schedule.CostModelSave", &CostModelNode::Save) + .def_method("meta_schedule.CostModelUpdate", &CostModelNode::Update) + .def("meta_schedule.CostModelPredict", + [](CostModel model, // + const TuneContext& context, // + Array candidates, // + void* p_addr) -> void { + std::vector result = model->Predict(context, candidates); + std::copy(result.begin(), result.end(), static_cast(p_addr)); + }) + .def("meta_schedule.CostModelPyCostModel", CostModel::PyCostModel); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index a9c04409c530..fe4401fa4c9a 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -16,6 +16,8 @@ * specific 
language governing permissions and limitations * under the License. */ +#include + #include "../module_equality.h" #include "../utils.h" @@ -288,46 +290,36 @@ TVM_REGISTER_NODE_TYPE(WorkloadNode); TVM_REGISTER_NODE_TYPE(TuningRecordNode); TVM_REGISTER_OBJECT_TYPE(DatabaseNode); TVM_REGISTER_NODE_TYPE(PyDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { - return Workload(mod); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.Workload", [](IRModule mod) { return Workload(mod); }) + .def_method("meta_schedule.WorkloadAsJSON", &WorkloadNode::AsJSON) + .def("meta_schedule.WorkloadFromJSON", &Workload::FromJSON) + .def("meta_schedule.TuningRecord", + [](tir::Trace trace, Workload workload, Optional> run_secs, + Optional target, Optional> args_info) { + return TuningRecord(trace, workload, run_secs, target, args_info); + }) + .def_method("meta_schedule.TuningRecordAsMeasureCandidate", + &TuningRecordNode::AsMeasureCandidate) + .def_method("meta_schedule.TuningRecordAsJSON", &TuningRecordNode::AsJSON) + .def("meta_schedule.TuningRecordFromJSON", TuningRecord::FromJSON) + .def_method("meta_schedule.DatabaseEnterWithScope", &Database::EnterWithScope) + .def_method("meta_schedule.DatabaseExitWithScope", &Database::ExitWithScope) + .def("meta_schedule.DatabaseCurrent", Database::Current) + .def_method("meta_schedule.DatabaseHasWorkload", &DatabaseNode::HasWorkload) + .def_method("meta_schedule.DatabaseCommitWorkload", &DatabaseNode::CommitWorkload) + .def_method("meta_schedule.DatabaseCommitTuningRecord", &DatabaseNode::CommitTuningRecord) + .def_method("meta_schedule.DatabaseGetTopK", &DatabaseNode::GetTopK) + .def_method("meta_schedule.DatabaseGetAllTuningRecords", &DatabaseNode::GetAllTuningRecords) + .def_method("meta_schedule.DatabaseSize", &DatabaseNode::Size) + .def_method("meta_schedule.DatabaseQueryTuningRecord", &DatabaseNode::QueryTuningRecord) + .def_method("meta_schedule.DatabaseQuerySchedule", &DatabaseNode::QuerySchedule) + .def_method("meta_schedule.DatabaseQueryIRModule", &DatabaseNode::QueryIRModule) + .def_method("meta_schedule.DatabaseDumpPruned", &DatabaseNode::DumpPruned) + .def("meta_schedule.DatabasePyDatabase", Database::PyDatabase); }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON").set_body_method(&WorkloadNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecord") - .set_body_typed([](tir::Trace trace, Workload workload, Optional> run_secs, - Optional target, Optional> args_info) { - return TuningRecord(trace, workload, run_secs, target, args_info); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") - .set_body_method(&TuningRecordNode::AsMeasureCandidate); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") - .set_body_method(&TuningRecordNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON") - .set_body_typed(TuningRecord::FromJSON); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope") - .set_body_method(&Database::EnterWithScope); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope") - .set_body_method(&Database::ExitWithScope); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") - .set_body_method(&DatabaseNode::HasWorkload); 
-TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") - .set_body_method(&DatabaseNode::CommitWorkload); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") - .set_body_method(&DatabaseNode::CommitTuningRecord); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") - .set_body_method(&DatabaseNode::GetAllTuningRecords); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord") - .set_body_method(&DatabaseNode::QueryTuningRecord); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") - .set_body_method(&DatabaseNode::QuerySchedule); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") - .set_body_method(&DatabaseNode::QueryIRModule); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned") - .set_body_method(&DatabaseNode::DumpPruned); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index f8660453e37e..aca011e55b1b 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -217,8 +217,10 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, TVM_FFI_STATIC_INIT_BLOCK({ JSONDatabaseNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase") - .set_body_typed(Database::JSONDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.DatabaseJSONDatabase", Database::JSONDatabase); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index eb1b0d19d49f..9889463728f4 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -101,8 +101,10 @@ Database Database::MemoryDatabase(String mod_eq_name) { } TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") - .set_body_typed(Database::MemoryDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.DatabaseMemoryDatabase", Database::MemoryDatabase); +}); TVM_FFI_STATIC_INIT_BLOCK({ MemoryDatabaseNode::RegisterReflection(); }); diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index 8f8d2370d982..8a1e1f5a71f6 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -85,8 +85,11 @@ Database Database::OrderedUnionDatabase(Array databases) { } TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") - .set_body_typed(Database::OrderedUnionDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.DatabaseOrderedUnionDatabase", + Database::OrderedUnionDatabase); +}); TVM_FFI_STATIC_INIT_BLOCK({ OrderedUnionDatabaseNode::RegisterReflection(); }); diff --git a/src/meta_schedule/database/schedule_fn_database.cc 
b/src/meta_schedule/database/schedule_fn_database.cc index 10ac141bc390..242c797a45ad 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -103,8 +103,10 @@ Database Database::ScheduleFnDatabase(ffi::TypedFunction sc } TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") - .set_body_typed(Database::ScheduleFnDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.DatabaseScheduleFnDatabase", Database::ScheduleFnDatabase); +}); TVM_FFI_STATIC_INIT_BLOCK({ ScheduleFnDatabaseNode::RegisterReflection(); }); diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index a275b95721f8..fa6a5d4b3835 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -87,8 +87,10 @@ Database Database::UnionDatabase(Array databases) { } TVM_REGISTER_NODE_TYPE(UnionDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase") - .set_body_typed(Database::UnionDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.DatabaseUnionDatabase", Database::UnionDatabase); +}); TVM_FFI_STATIC_INIT_BLOCK({ UnionDatabaseNode::RegisterReflection(); }); diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index da8a61eb8603..21c1e2024a31 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -41,11 +42,14 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, TVM_FFI_STATIC_INIT_BLOCK({ ExtractedTaskNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ExtractedTask") - .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, - int weight) -> ExtractedTask { - return ExtractedTask(task_name, mod, target, dispatched, weight); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ExtractedTask", + [](String task_name, IRModule mod, Target target, + Array dispatched, int weight) -> ExtractedTask { + return ExtractedTask(task_name, mod, target, dispatched, weight); + }); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index eda856d1bdcf..146d32d6a366 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -53,10 +55,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") - .set_body_method(&FeatureExtractorNode::ExtractFrom); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") - .set_body_typed(FeatureExtractor::PyFeatureExtractor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.FeatureExtractorExtractFrom", &FeatureExtractorNode::ExtractFrom) + .def("meta_schedule.FeatureExtractorPyFeatureExtractor", + FeatureExtractor::PyFeatureExtractor); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index d74f1c369e0d..60f2bee7ea40 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -1449,8 +1449,11 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, TVM_FFI_STATIC_INIT_BLOCK({ PerStoreFeatureNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") - .set_body_typed(FeatureExtractor::PerStoreFeature); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.FeatureExtractorPerStoreFeature", + FeatureExtractor::PerStoreFeature); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index becd9d2110df..04fac7b3ba3b 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -65,8 +67,11 @@ MeasureCallback MeasureCallback::AddToDatabase() { } TVM_REGISTER_NODE_TYPE(AddToDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") - .set_body_typed(MeasureCallback::AddToDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.MeasureCallbackAddToDatabase", + MeasureCallback::AddToDatabase); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 76e5fcf7276c..e1c32f8e0adf 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -64,12 +66,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") - .set_body_method(&MeasureCallbackNode::Apply); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") - .set_body_typed(MeasureCallback::PyMeasureCallback); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackDefault") - .set_body_typed(MeasureCallback::Default); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.MeasureCallbackApply", &MeasureCallbackNode::Apply) + .def("meta_schedule.MeasureCallbackPyMeasureCallback", MeasureCallback::PyMeasureCallback) + .def("meta_schedule.MeasureCallbackDefault", MeasureCallback::Default); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc index da74e85cac07..c1500308ee3c 100644 --- a/src/meta_schedule/measure_callback/remove_build_artifact.cc +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -46,8 +48,11 @@ MeasureCallback MeasureCallback::RemoveBuildArtifact() { } TVM_REGISTER_NODE_TYPE(RemoveBuildArtifactNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") - .set_body_typed(MeasureCallback::RemoveBuildArtifact); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.MeasureCallbackRemoveBuildArtifact", + MeasureCallback::RemoveBuildArtifact); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 1969d7fc83a9..ada68bf29a6e 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -63,8 +65,11 @@ MeasureCallback MeasureCallback::UpdateCostModel() { } TVM_REGISTER_NODE_TYPE(UpdateCostModelNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") - .set_body_typed(MeasureCallback::UpdateCostModel); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.MeasureCallbackUpdateCostModel", + MeasureCallback::UpdateCostModel); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index c1baa9b26c3f..5442337ac542 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -135,8 +135,11 @@ Mutator Mutator::MutateComputeLocation() { TVM_FFI_STATIC_INIT_BLOCK({ MutateComputeLocationNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") - .set_body_typed(Mutator::MutateComputeLocation); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.MutatorMutateComputeLocation", + Mutator::MutateComputeLocation); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 53ba13d13992..a923ddb6fd1a 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -315,8 +315,10 @@ Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { TVM_FFI_STATIC_INIT_BLOCK({ MutateParallelNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateParallelNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel") - .set_body_typed(Mutator::MutateParallel); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.MutatorMutateParallel", Mutator::MutateParallel); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 27bf5b334ae3..c9419679a48b 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -173,8 +173,10 @@ Mutator Mutator::MutateThreadBinding() { return Mutator(make_object() TVM_FFI_STATIC_INIT_BLOCK({ MutateUnrollNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateUnrollNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.MutatorMutateUnroll", Mutator::MutateUnroll); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 5f2ccf24fe77..b9e300ef22bd 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -93,21 +95,24 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OBJECT_TYPE(MutatorNode); TVM_REGISTER_NODE_TYPE(PyMutatorNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") - .set_body_method(&MutatorNode::InitializeWithTuneContext); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorApply") - .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional { - TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); - return self->Apply(trace, &seed_); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore") - .set_body_typed(Mutator::DefaultCUDATensorCore); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon") - .set_body_typed(Mutator::DefaultHexagon); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.MutatorInitializeWithTuneContext", + &MutatorNode::InitializeWithTuneContext) + .def("meta_schedule.MutatorApply", + [](Mutator self, tir::Trace trace, TRandState seed) -> Optional { + TRandState seed_ = + (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); + return self->Apply(trace, &seed_); + }) + .def_method("meta_schedule.MutatorClone", &MutatorNode::Clone) + .def("meta_schedule.MutatorPyMutator", Mutator::PyMutator) + .def("meta_schedule.MutatorDefaultLLVM", Mutator::DefaultLLVM) + .def("meta_schedule.MutatorDefaultCUDA", Mutator::DefaultCUDA) + .def("meta_schedule.MutatorDefaultCUDATensorCore", Mutator::DefaultCUDATensorCore) + .def("meta_schedule.MutatorDefaultHexagon", Mutator::DefaultHexagon); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 01a75a5bfb36..22c9fd9a8ffa 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -184,8 +186,11 @@ Postproc Postproc::DisallowAsyncStridedMemCopy() { } TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy") - .set_body_typed(Postproc::DisallowAsyncStridedMemCopy); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocDisallowAsyncStridedMemCopy", + Postproc::DisallowAsyncStridedMemCopy); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index fd099ac5dd38..a2021085a5b3 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -83,8 +85,10 @@ Postproc Postproc::DisallowDynamicLoop() { } TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") - .set_body_typed(Postproc::DisallowDynamicLoop); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocDisallowDynamicLoop", Postproc::DisallowDynamicLoop); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 8434cbf808e8..43a8fc981630 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -117,17 +119,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OBJECT_TYPE(PostprocNode); TVM_REGISTER_NODE_TYPE(PyPostprocNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") - .set_body_method(&PostprocNode::InitializeWithTuneContext); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore") - .set_body_typed(Postproc::DefaultCUDATensorCore); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon") - .set_body_typed(Postproc::DefaultHexagon); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.PostprocInitializeWithTuneContext", + &PostprocNode::InitializeWithTuneContext) + .def_method("meta_schedule.PostprocApply", &PostprocNode::Apply) + .def_method("meta_schedule.PostprocClone", &PostprocNode::Clone) + .def("meta_schedule.PostprocPyPostproc", Postproc::PyPostproc) + .def("meta_schedule.PostprocDefaultLLVM", Postproc::DefaultLLVM) + .def("meta_schedule.PostprocDefaultCUDA", Postproc::DefaultCUDA) + .def("meta_schedule.PostprocDefaultCUDATensorCore", Postproc::DefaultCUDATensorCore) + .def("meta_schedule.PostprocDefaultHexagon", Postproc::DefaultHexagon); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index 2e0ebf985e11..b76fe1d832e4 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -234,8 +234,11 @@ Postproc Postproc::RewriteCooperativeFetch() { TVM_FFI_STATIC_INIT_BLOCK({ RewriteCooperativeFetchNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") - .set_body_typed(Postproc::RewriteCooperativeFetch); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocRewriteCooperativeFetch", + Postproc::RewriteCooperativeFetch); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_layout.cc 
b/src/meta_schedule/postproc/rewrite_layout.cc index 84dc33ec98c8..867b6dfa66e1 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include #include @@ -273,8 +275,10 @@ Postproc Postproc::RewriteLayout() { } TVM_REGISTER_NODE_TYPE(RewriteLayoutNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout") - .set_body_typed(Postproc::RewriteLayout); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocRewriteLayout", Postproc::RewriteLayout); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 3f665cd8d82a..9cf856a52801 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -464,8 +466,11 @@ Postproc Postproc::RewriteParallelVectorizeUnroll() { } TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") - .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocRewriteParallelVectorizeUnroll", + Postproc::RewriteParallelVectorizeUnroll); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 0aa3d640cc09..02d6edc1b2e7 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -177,8 +177,11 @@ Postproc Postproc::RewriteReductionBlock() { } TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") - .set_body_typed(Postproc::RewriteReductionBlock); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocRewriteReductionBlock", + Postproc::RewriteReductionBlock); +}); TVM_FFI_STATIC_INIT_BLOCK({ RewriteReductionBlockNode::RegisterReflection(); }); diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 9456defeed55..897ea30876ba 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -112,8 +112,10 @@ Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { TVM_FFI_STATIC_INIT_BLOCK({ RewriteTensorizeNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") - .set_body_typed(Postproc::RewriteTensorize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocRewriteTensorize", Postproc::RewriteTensorize); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 29f9280afaba..8b620bfb68c2 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ 
b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -149,8 +149,10 @@ Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { TVM_FFI_STATIC_INIT_BLOCK({ RewriteUnboundBlockNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") - .set_body_typed(Postproc::RewriteUnboundBlock); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocRewriteUnboundBlock", Postproc::RewriteUnboundBlock); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 5b7b637ef242..390fe407a5d5 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include "../utils.h" @@ -215,8 +216,10 @@ Postproc Postproc::VerifyGPUCode() { } TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode") - .set_body_typed(Postproc::VerifyGPUCode); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocVerifyGPUCode", Postproc::VerifyGPUCode); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index 7da2f8546b9e..4e310e16a2ce 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include "../utils.h" @@ -69,8 +70,10 @@ Postproc Postproc::VerifyVTCMLimit() { } TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit") - .set_body_typed(Postproc::VerifyVTCMLimit); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.PostprocVerifyVTCMLimit", Postproc::VerifyVTCMLimit); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index d92991fcbc34..8540c8188b14 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include #include @@ -123,17 +125,17 @@ Optional Profiler::Current() { TVM_FFI_STATIC_INIT_BLOCK({ ProfilerNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ProfilerNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { - return Profiler(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.Profiler", []() -> Profiler { return Profiler(); }) + .def_method("meta_schedule.ProfilerEnterWithScope", &Profiler::EnterWithScope) + .def_method("meta_schedule.ProfilerExitWithScope", &Profiler::ExitWithScope) + .def("meta_schedule.ProfilerCurrent", Profiler::Current) + .def_method("meta_schedule.ProfilerGet", &ProfilerNode::Get) + .def_method("meta_schedule.ProfilerTable", &ProfilerNode::Table) + .def("meta_schedule.ProfilerTimedScope", ProfilerTimedScope); }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope") - .set_body_method(&Profiler::EnterWithScope); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope") - .set_body_method(&Profiler::ExitWithScope); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent").set_body_typed(Profiler::Current); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method(&ProfilerNode::Get); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerTable").set_body_method(&ProfilerNode::Table); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope").set_body_typed(ProfilerTimedScope); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 009a2786a983..088468a201e3 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -63,25 +65,26 @@ TVM_REGISTER_NODE_TYPE(RunnerResultNode); TVM_REGISTER_NODE_TYPE(RunnerFutureNode); TVM_REGISTER_OBJECT_TYPE(RunnerNode); TVM_REGISTER_NODE_TYPE(PyRunnerNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerInput") - .set_body_typed([](String artifact_path, String device_type, - Array args_info) -> RunnerInput { - return RunnerInput(artifact_path, device_type, args_info); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerResult") - .set_body_typed([](Optional> run_secs, - Optional error_msg) -> RunnerResult { - return RunnerResult(run_secs, error_msg); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFuture") - .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { - return RunnerFuture(f_done, f_result); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone").set_body_method(&RunnerFutureNode::Done); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult") - .set_body_method(&RunnerFutureNode::Result); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.RunnerInput", + [](String artifact_path, String device_type, Array args_info) -> RunnerInput { + return RunnerInput(artifact_path, device_type, args_info); + }) + .def("meta_schedule.RunnerResult", + [](Optional> run_secs, Optional error_msg) -> RunnerResult { + return RunnerResult(run_secs, error_msg); + }) + .def("meta_schedule.RunnerFuture", + [](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { + return RunnerFuture(f_done, f_result); + }) + .def_method("meta_schedule.RunnerFutureDone", &RunnerFutureNode::Done) + .def_method("meta_schedule.RunnerFutureResult", &RunnerFutureNode::Result) + .def_method("meta_schedule.RunnerRun", &RunnerNode::Run) + .def("meta_schedule.RunnerPyRunner", Runner::PyRunner); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index 4e09fa729b3c..fca4b11078a3 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include #include #include "../../utils.h" @@ -59,43 +60,44 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); - sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), - /*preserve_unit_loops=*/true); - sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), - /*preserve_unit_loops=*/true); - return {sch}; - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse") - .set_body_typed([](Schedule sch, BlockRV block) -> Array { - GetWinogradProducerAndInlineConst(sch, block); - ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); - return {sch}; - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); - sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), - /*preserve_unit_loops=*/true); - sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), - /*preserve_unit_loops=*/true); - return {sch}; - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse") - .set_body_typed([](Schedule sch, BlockRV block) -> Array { - GetWinogradProducerAndInlineConst(sch, block); - ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); - return {sch}; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack", + [](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), + /*preserve_unit_loops=*/true); + sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), + /*preserve_unit_loops=*/true); + return {sch}; + }) + .def("meta_schedule.cpu.conv2d_nhwc_winograd_inverse", + [](Schedule sch, BlockRV block) -> Array { + GetWinogradProducerAndInlineConst(sch, block); + ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); + return {sch}; + }) + .def("meta_schedule.cpu.conv2d_nchw_winograd_data_pack", + [](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), + /*preserve_unit_loops=*/true); + sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), + /*preserve_unit_loops=*/true); + return {sch}; + }) + .def("meta_schedule.cpu.conv2d_nchw_winograd_inverse", + [](Schedule sch, BlockRV block) -> Array { + GetWinogradProducerAndInlineConst(sch, block); + ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); + return {sch}; + }); +}); } 
// namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index c80141f5288d..871e0565fbaf 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -63,102 +64,104 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); - { - BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); - sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true); - } - { - sch->ComputeAt(input_tile, /*loop_rv=*/loops.back(), /*preserve_unit_loops=*/true); - sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); - sch->ComputeInline(data_pad); - } - { - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - Array loops = sch->GetLoops(data_pack); - ICHECK_EQ(loops.size(), 8); - BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, - max_threads_per_block); - } - return {sch}; - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") - .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { - GetWinogradProducerAndInlineConst(sch, inverse); - ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - Array loops = sch->GetLoops(inverse); - ICHECK_EQ(loops.size(), 8); - BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, - max_threads_per_block); - return {sch}; - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - LoopRV outer{nullptr}; - { - Array loops = sch->GetLoops(data_pack); - ICHECK_EQ(loops.size(), 6); - sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]}); - sch->Unroll(loops[0]); - sch->Unroll(loops[1]); - sch->Unroll(loops[4]); - sch->Unroll(loops[5]); - outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks, - max_threads_per_block, /*get_factor=*/nullptr) - .back(); - } - { - BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); - sch->ReverseComputeAt(data_pack_local, outer, /*preserve_unit_loops=*/true); - } - { - sch->ComputeAt(input_tile, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); - sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); - sch->ComputeInline(data_pad); - } - return {sch}; - }); - -TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") - .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { - GetWinogradProducerAndInlineConst(sch, inverse); - // loops on top of the inverse block: [CO, P, tile_size, tile_size, 
alpha, alpha] - int64_t tile_size = Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; - LoopRV outer{nullptr}; - { - BlockRV output = sch->GetConsumers(inverse)[0]; - Array nchw = sch->GetLoops(output); - ICHECK_EQ(nchw.size(), 4); - Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); - Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); - sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); - outer = ws[0]; - } - { - sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); - sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); - Array loops = sch->GetLoops(inverse); - ICHECK_EQ(loops.size(), 10); - sch->Unroll(loops[6]); - sch->Unroll(loops[7]); - sch->Unroll(loops[8]); - sch->Unroll(loops[9]); - } - return {sch}; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack", + [](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + { + BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); + sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true); + } + { + sch->ComputeAt(input_tile, /*loop_rv=*/loops.back(), /*preserve_unit_loops=*/true); + sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); + sch->ComputeInline(data_pad); + } + { + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + Array loops = sch->GetLoops(data_pack); + ICHECK_EQ(loops.size(), 8); + BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), + max_threadblocks, max_threads_per_block); + } + return {sch}; + }) + .def("meta_schedule.cuda.conv2d_nhwc_winograd_inverse", + [](Schedule sch, BlockRV inverse) -> Array { + GetWinogradProducerAndInlineConst(sch, inverse); + ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + Array loops = sch->GetLoops(inverse); + ICHECK_EQ(loops.size(), 8); + BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), + max_threadblocks, max_threads_per_block); + return {sch}; + }) + .def("meta_schedule.cuda.conv2d_nchw_winograd_data_pack", + [](Schedule sch, BlockRV data_pack) -> Array { + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + LoopRV outer{nullptr}; + { + Array loops = sch->GetLoops(data_pack); + ICHECK_EQ(loops.size(), 6); + sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]}); + sch->Unroll(loops[0]); + sch->Unroll(loops[1]); + sch->Unroll(loops[4]); + sch->Unroll(loops[5]); + outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks, + max_threads_per_block, /*get_factor=*/nullptr) + .back(); + } + { + BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); + sch->ReverseComputeAt(data_pack_local, outer, /*preserve_unit_loops=*/true); + } + { + sch->ComputeAt(input_tile, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); + sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); + sch->ComputeInline(data_pad); + } + return {sch}; + }) + 
.def("meta_schedule.cuda.conv2d_nchw_winograd_inverse", + [](Schedule sch, BlockRV inverse) -> Array { + GetWinogradProducerAndInlineConst(sch, inverse); + // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] + int64_t tile_size = + Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; + LoopRV outer{nullptr}; + { + BlockRV output = sch->GetConsumers(inverse)[0]; + Array nchw = sch->GetLoops(output); + ICHECK_EQ(nchw.size(), 4); + Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); + Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); + sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); + outer = ws[0]; + } + { + sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); + sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); + Array loops = sch->GetLoops(inverse); + ICHECK_EQ(loops.size(), 10); + sch->Unroll(loops[6]); + sch->Unroll(loops[7]); + sch->Unroll(loops[8]); + sch->Unroll(loops[9]); + } + return {sch}; + }); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 190749c09aa9..e5f68485127f 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -123,8 +123,10 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: TVM_FFI_STATIC_INIT_BLOCK({ AddRFactorNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(AddRFactorNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") - .set_body_typed(ScheduleRule::AddRFactor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleAddRFactor", ScheduleRule::AddRFactor); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 234b9fd13239..6d8995a3bf10 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -93,8 +93,10 @@ bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { TVM_FFI_STATIC_INIT_BLOCK({ ApplyCustomRuleNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") - .set_body_typed(ScheduleRule::ApplyCustomRule); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleApplyCustomRule", ScheduleRule::ApplyCustomRule); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index d2077306f70a..6e15e2a7500f 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -84,8 +84,10 @@ ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_ TVM_FFI_STATIC_INIT_BLOCK({ AutoBindNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(AutoBindNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind") - .set_body_typed(ScheduleRule::AutoBind); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleAutoBind", ScheduleRule::AutoBind); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc 
index eb7b2e6e207d..8f8a186065f6 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -196,8 +196,10 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // TVM_FFI_STATIC_INIT_BLOCK({ AutoInlineNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(AutoInlineNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") - .set_body_typed(ScheduleRule::AutoInline); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleAutoInline", ScheduleRule::AutoInline); +}); /*! \brief Inline blocks that produce a constant scalar. */ class InlineConstantScalarsNode : public ScheduleRuleNode { @@ -243,7 +245,10 @@ ScheduleRule ScheduleRule::InlineConstantScalars() { TVM_FFI_STATIC_INIT_BLOCK({ InlineConstantScalarsNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") - .set_body_typed(ScheduleRule::InlineConstantScalars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleInlineConstantScalars", + ScheduleRule::InlineConstantScalars); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index df0d190003a5..ac1e43d2b6f2 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -296,8 +296,11 @@ ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { TVM_FFI_STATIC_INIT_BLOCK({ CrossThreadReductionNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") - .set_body_typed(ScheduleRule::CrossThreadReduction); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleCrossThreadReduction", + ScheduleRule::CrossThreadReduction); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 0d17477e2b94..a642cc62c1ba 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -18,6 +18,7 @@ */ #include "./multi_level_tiling.h" +#include #include #include @@ -407,8 +408,11 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional #include #include @@ -923,8 +924,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( } TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore") - .set_body_typed(ScheduleRule::MultiLevelTilingTensorCore); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore", + ScheduleRule::MultiLevelTilingTensorCore); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 0da8ee35cf76..415636949a8d 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -17,6 +17,8 @@ * under 
the License. */ +#include + #include "../../tir/schedule/analysis.h" #include "../../tir/schedule/transform.h" #include "../utils.h" @@ -124,8 +126,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingWideVector(String structure, } TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector") - .set_body_typed(ScheduleRule::MultiLevelTilingWideVector); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTilingWideVector", + ScheduleRule::MultiLevelTilingWideVector); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 731860e8d6f0..df20f39fbc42 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -17,6 +17,8 @@ * under the License. */ +#include + #include "../../tir/schedule/analysis.h" #include "../../tir/schedule/transform.h" #include "../utils.h" @@ -106,8 +108,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(String intrin_name, String } TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin") - .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin", + ScheduleRule::MultiLevelTilingWithIntrin); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 48a607f23019..a04d1591b4ee 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -138,8 +138,11 @@ ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, TVM_FFI_STATIC_INIT_BLOCK({ ParallelizeVectorizeUnrollNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") - .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll", + ScheduleRule::ParallelizeVectorizeUnroll); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index a0e602e6573e..18c01e912548 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -129,7 +129,10 @@ ScheduleRule ScheduleRule::RandomComputeLocation() { TVM_FFI_STATIC_INIT_BLOCK({ RandomComputeLocationNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") - .set_body_typed(ScheduleRule::RandomComputeLocation); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleRuleRandomComputeLocation", + ScheduleRule::RandomComputeLocation); +}); } // namespace meta_schedule } // namespace tvm diff --git 
a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index e72be72520e7..0086564ee429 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -407,24 +409,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") - .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") - .set_body_method(&ScheduleRuleNode::Apply); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone") - .set_body_method(&ScheduleRuleNode::Clone); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") - .set_body_typed(ScheduleRule::PyScheduleRule); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultLLVM") - .set_body_typed(ScheduleRule::DefaultLLVM); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDA") - .set_body_typed(ScheduleRule::DefaultCUDA); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore") - .set_body_typed(ScheduleRule::DefaultCUDATensorCore); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon") - .set_body_typed(ScheduleRule::DefaultHexagon); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM") - .set_body_typed(ScheduleRule::DefaultARM); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.ScheduleRuleInitializeWithTuneContext", + &ScheduleRuleNode::InitializeWithTuneContext) + .def_method("meta_schedule.ScheduleRuleApply", &ScheduleRuleNode::Apply) + .def_method("meta_schedule.ScheduleRuleClone", &ScheduleRuleNode::Clone) + .def("meta_schedule.ScheduleRulePyScheduleRule", ScheduleRule::PyScheduleRule) + .def("meta_schedule.ScheduleRuleDefaultLLVM", ScheduleRule::DefaultLLVM) + .def("meta_schedule.ScheduleRuleDefaultCUDA", ScheduleRule::DefaultCUDA) + .def("meta_schedule.ScheduleRuleDefaultCUDATensorCore", ScheduleRule::DefaultCUDATensorCore) + .def("meta_schedule.ScheduleRuleDefaultHexagon", ScheduleRule::DefaultHexagon) + .def("meta_schedule.ScheduleRuleDefaultARM", ScheduleRule::DefaultARM); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 61ad46172f61..b50db43edb24 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -801,12 +801,15 @@ Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, TVM_FFI_STATIC_INIT_BLOCK({ EvolutionarySearchNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") - .set_body_typed(SearchStrategy::EvolutionarySearch); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation") - .set_body_typed(EvolutionarySearchSampleInitPopulation); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel") - .set_body_typed(EvolutionarySearchEvolveWithCostModel); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + 
.def("meta_schedule.SearchStrategyEvolutionarySearch", SearchStrategy::EvolutionarySearch) + .def("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation", + EvolutionarySearchSampleInitPopulation) + .def("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel", + EvolutionarySearchEvolveWithCostModel); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 6bfe2927c6e7..2c53b7b77ad9 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -161,8 +163,10 @@ SearchStrategy SearchStrategy::ReplayFunc() { TVM_FFI_STATIC_INIT_BLOCK({ ReplayFuncNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ReplayFuncNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") - .set_body_typed(SearchStrategy::ReplayFunc); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.SearchStrategyReplayFunc", SearchStrategy::ReplayFunc); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index ae55bc58f16e..2166d527f53f 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -189,8 +191,10 @@ SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { TVM_FFI_STATIC_INIT_BLOCK({ ReplayTraceNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") - .set_body_typed(SearchStrategy::ReplayTrace); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.SearchStrategyReplayTrace", SearchStrategy::ReplayTrace); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index b1ebfd784951..171def21ee9d 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -91,24 +93,24 @@ TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") - .set_body_typed([](tir::Schedule sch, Optional> args_info) -> MeasureCandidate { - return MeasureCandidate(sch, args_info.value_or({})); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") - .set_body_typed(SearchStrategy::PySearchStrategy); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") - .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") - .set_body_method(&SearchStrategyNode::PreTuning); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") - .set_body_method(&SearchStrategyNode::PostTuning); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") - .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") - .set_body_method(&SearchStrategyNode::NotifyRunnerResults); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") - .set_body_method(&SearchStrategyNode::Clone); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.MeasureCandidate", + [](tir::Schedule sch, Optional> args_info) -> MeasureCandidate { + return MeasureCandidate(sch, args_info.value_or({})); + }) + .def("meta_schedule.SearchStrategyPySearchStrategy", SearchStrategy::PySearchStrategy) + .def_method("meta_schedule.SearchStrategyInitializeWithTuneContext", + &SearchStrategyNode::InitializeWithTuneContext) + .def_method("meta_schedule.SearchStrategyPreTuning", &SearchStrategyNode::PreTuning) + .def_method("meta_schedule.SearchStrategyPostTuning", &SearchStrategyNode::PostTuning) + .def_method("meta_schedule.SearchStrategyGenerateMeasureCandidates", + &SearchStrategyNode::GenerateMeasureCandidates) + .def_method("meta_schedule.SearchStrategyNotifyRunnerResults", + &SearchStrategyNode::NotifyRunnerResults) + .def_method("meta_schedule.SearchStrategyClone", &SearchStrategyNode::Clone); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index a6baaa78d53d..9702b2c60268 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -118,8 +118,11 @@ SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, TVM_FFI_STATIC_INIT_BLOCK({ PostOrderApplyNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") - .set_body_typed(SpaceGenerator::PostOrderApply); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.SpaceGeneratorPostOrderApply", + SpaceGenerator::PostOrderApply); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 0b52c58449b4..2562c0d6010d 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -100,8 +100,10 @@ SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, 
TVM_FFI_STATIC_INIT_BLOCK({ ScheduleFnNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ScheduleFnNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") - .set_body_typed(SpaceGenerator::ScheduleFn); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.SpaceGeneratorScheduleFn", SpaceGenerator::ScheduleFn); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index bd94d6804f2c..e3b9e8fcb7b3 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../../target/parsers/aprofile.h" #include "../utils.h" @@ -195,14 +197,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode); TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") - .set_body_method(&SpaceGeneratorNode::InitializeWithTuneContext); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") - .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") - .set_body_typed(SpaceGenerator::PySpaceGenerator); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") - .set_body_method(&SpaceGeneratorNode::Clone); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("meta_schedule.SpaceGeneratorInitializeWithTuneContext", + &SpaceGeneratorNode::InitializeWithTuneContext) + .def_method("meta_schedule.SpaceGeneratorGenerateDesignSpace", + &SpaceGeneratorNode::GenerateDesignSpace) + .def("meta_schedule.SpaceGeneratorPySpaceGenerator", SpaceGenerator::PySpaceGenerator) + .def_method("meta_schedule.SpaceGeneratorClone", &SpaceGeneratorNode::Clone); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index d1f22e013b0d..8af12e879d75 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -87,8 +87,11 @@ SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_g TVM_FFI_STATIC_INIT_BLOCK({ SpaceGeneratorUnionNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") - .set_body_typed(SpaceGenerator::SpaceGeneratorUnion); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.SpaceGeneratorSpaceGeneratorUnion", + SpaceGenerator::SpaceGeneratorUnion); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 30bfc611cbbb..baf579e18d23 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -147,8 +147,10 @@ TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, i TVM_FFI_STATIC_INIT_BLOCK({ GradientBasedNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(GradientBasedNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") - 
.set_body_typed(TaskScheduler::GradientBased); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.TaskSchedulerGradientBased", TaskScheduler::GradientBased); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index e8d16c0cf421..9523749ddf0c 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -67,8 +67,10 @@ TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { TVM_FFI_STATIC_INIT_BLOCK({ RoundRobinNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RoundRobinNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") - .set_body_typed(TaskScheduler::RoundRobin); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.TaskSchedulerRoundRobin", TaskScheduler::RoundRobin); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index a787c4456b82..d2f5f05f22ae 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -370,20 +372,18 @@ void PyTaskSchedulerNode::Tune(Array tasks, Array task_we TVM_REGISTER_NODE_TYPE(TaskRecordNode); TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode); TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") - .set_body_typed(TaskScheduler::PyTaskScheduler); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") - .set_body_method(&TaskSchedulerNode::Tune); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") - .set_body_method(&TaskSchedulerNode::JoinRunningTask); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") - .set_body_method(&TaskSchedulerNode::NextTaskId); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") - .set_body_method(&TaskSchedulerNode::TerminateTask); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") - .set_body_method(&TaskSchedulerNode::TouchTask); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics") - .set_body_method(&TaskSchedulerNode::PrintTuningStatistics); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.TaskSchedulerPyTaskScheduler", TaskScheduler::PyTaskScheduler) + .def_method("meta_schedule.TaskSchedulerTune", &TaskSchedulerNode::Tune) + .def_method("meta_schedule.TaskSchedulerJoinRunningTask", &TaskSchedulerNode::JoinRunningTask) + .def_method("meta_schedule.TaskSchedulerNextTaskId", &TaskSchedulerNode::NextTaskId) + .def_method("meta_schedule.TaskSchedulerTerminateTask", &TaskSchedulerNode::TerminateTask) + .def_method("meta_schedule.TaskSchedulerTouchTask", &TaskSchedulerNode::TouchTask) + .def_method("meta_schedule.TaskSchedulerPrintTuningStatistics", + &TaskSchedulerNode::PrintTuningStatistics); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 9d22554d912f..3608d25ceaaa 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -18,6 +18,7 @@ */ #include 
"trace_apply.h" +#include #include #include @@ -254,8 +255,10 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm } } -TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") - .set_body_typed(ScheduleUsingAnchorTrace); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("meta_schedule.ScheduleUsingAnchorTrace", ScheduleUsingAnchorTrace); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 179a7ac1d6ff..939b90c93bc2 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include #include "./utils.h" @@ -65,19 +67,21 @@ void TuneContextNode::Initialize() { TVM_FFI_STATIC_INIT_BLOCK({ TuneContextNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(TuneContextNode); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContext") - .set_body_typed([](Optional mod, Optional target, - Optional space_generator, - Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, - ffi::Function logger) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, - rand_state, logger); - }); -TVM_FFI_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") - .set_body_method(&TuneContextNode::Initialize); -TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContextClone").set_body_method(&TuneContextNode::Clone); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("meta_schedule.TuneContext", + [](Optional mod, Optional target, + Optional space_generator, Optional search_strategy, + Optional task_name, int num_threads, TRandState rand_state, + ffi::Function logger) -> TuneContext { + return TuneContext(mod, target, space_generator, search_strategy, task_name, + num_threads, rand_state, logger); + }) + .def("meta_schedule._SHash2Hex", SHash2Hex) + .def_method("meta_schedule.TuneContextInitialize", &TuneContextNode::Initialize) + .def_method("meta_schedule.TuneContextClone", &TuneContextNode::Clone); +}); } // namespace meta_schedule } // namespace tvm diff --git a/src/node/object_path.cc b/src/node/object_path.cc index a99835ea17ad..7edd690b7bc2 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -40,13 +41,19 @@ Optional ObjectPathNode::GetParent() const { return Downcast>(parent_); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_method(&ObjectPathNode::GetParent); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathGetParent", &ObjectPathNode::GetParent); +}); // --- Length --- int32_t ObjectPathNode::Length() const { return length_; } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathLength", &ObjectPathNode::Length); +}); // --- GetPrefix --- @@ -63,7 +70,10 @@ ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { return GetRef(node); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathGetPrefix").set_body_method(&ObjectPathNode::GetPrefix); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathGetPrefix", &ObjectPathNode::GetPrefix); +}); // --- IsPrefixOf --- @@ -75,7 +85,10 @@ bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { return this->PathsEqual(other->GetPrefix(this_len)); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf").set_body_method(&ObjectPathNode::IsPrefixOf); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathIsPrefixOf", &ObjectPathNode::IsPrefixOf); +}); // --- Attr --- @@ -95,10 +108,13 @@ ObjectPath ObjectPathNode::Attr(Optional attr_key) const { } } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathAttr") - .set_body_typed([](const ObjectPath& object_path, Optional attr_key) { - return object_path->Attr(attr_key); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("node.ObjectPathAttr", + [](const ObjectPath& object_path, Optional attr_key) { + return object_path->Attr(attr_key); + }); +}); // --- ArrayIndex --- @@ -106,7 +122,10 @@ ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { return ObjectPath(make_object(this, index)); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathArrayIndex").set_body_method(&ObjectPathNode::ArrayIndex); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathArrayIndex", &ObjectPathNode::ArrayIndex); +}); // --- MissingArrayElement --- @@ -114,8 +133,11 @@ ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { return ObjectPath(make_object(this, index)); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") - .set_body_method(&ObjectPathNode::MissingArrayElement); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathMissingArrayElement", + &ObjectPathNode::MissingArrayElement); +}); // --- MapValue --- @@ -123,7 +145,10 @@ ObjectPath ObjectPathNode::MapValue(Any key) const { return ObjectPath(make_object(this, std::move(key))); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMapValue").set_body_method(&ObjectPathNode::MapValue); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathMapValue", &ObjectPathNode::MapValue); +}); // --- MissingMapEntry --- @@ -131,8 +156,10 @@ ObjectPath ObjectPathNode::MissingMapEntry() const { return ObjectPath(make_object(this)); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") - .set_body_method(&ObjectPathNode::MissingMapEntry); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathMissingMapEntry", &ObjectPathNode::MissingMapEntry); +}); // --- PathsEqual ---- @@ -158,7 +185,10 @@ bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { return lhs == nullptr && rhs == nullptr; } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathEqual").set_body_method(&ObjectPathNode::PathsEqual); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("node.ObjectPathEqual", &ObjectPathNode::PathsEqual); +}); // --- Repr --- @@ -191,7 +221,10 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const { return ObjectPath(make_object(name)); } -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("node.ObjectPathRoot", ObjectPath::Root); +}); // ============== Individual 
path classes ============== diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 8dfb1ebfc7bc..baed335e6e45 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -175,11 +175,13 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { *rv = ReflectionVTable::Global()->CreateObject(args[0].cast(), args.Slice(1)); } -TVM_FFI_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr); - -TVM_FFI_REGISTER_GLOBAL("node.NodeListAttrNames").set_body_packed(NodeListAttrNames); - -TVM_FFI_REGISTER_GLOBAL("node.MakeNode").set_body_packed(MakeNode); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("node.NodeGetAttr", NodeGetAttr) + .def_packed("node.NodeListAttrNames", NodeListAttrNames) + .def_packed("node.MakeNode", MakeNode); +}); Optional GetAttrKeyByAddress(const Object* object, const void* attr_address) { const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(object->type_index()); diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 69cb05c12106..ff16406dd9dd 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -22,6 +22,7 @@ * \file node/repr_printer.cc */ #include +#include #include #include @@ -101,9 +102,12 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { - std::ostringstream os; - os << obj; - return os.str(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) { + std::ostringstream os; + os << obj; + return os.str(); + }); }); } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 518da92baf65..3cc1f018f05a 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include #include #include @@ -140,9 +141,12 @@ Array PrinterConfigNode::GetBuiltinKeywords() { } TVM_REGISTER_NODE_TYPE(PrinterConfigNode); -TVM_FFI_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { - return PrinterConfig(config_dict); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("node.PrinterConfig", + [](Map config_dict) { return PrinterConfig(config_dict); }) + .def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script); }); -TVM_FFI_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); } // namespace tvm diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 30eed9d817de..e5618ea2faf6 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -785,7 +785,8 @@ Any LoadJSON(std::string json_str) { return nodes.at(jgraph.root); } -TVM_FFI_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); - -TVM_FFI_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("node.SaveJSON", SaveJSON).def("node.LoadJSON", LoadJSON); +}); } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index d1163269a8b3..c0a3992010b7 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -20,6 +20,7 @@ * \file src/node/structural_equal.cc */ #include +#include #include #include #include @@ -36,15 +37,14 @@ namespace tvm { TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode); -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathPairLhsPath") - .set_body_typed([](const ObjectPathPair& object_path_pair) { - return object_path_pair->lhs_path; - }); - -TVM_FFI_REGISTER_GLOBAL("node.ObjectPathPairRhsPath") - .set_body_typed([](const ObjectPathPair& object_path_pair) { - return object_path_pair->rhs_path; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("node.ObjectPathPairLhsPath", + [](const ObjectPathPair& object_path_pair) { return object_path_pair->lhs_path; }) + .def("node.ObjectPathPairRhsPath", + [](const ObjectPathPair& object_path_pair) { return object_path_pair->rhs_path; }); +}); ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path) : lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {} @@ -599,27 +599,30 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); } -TVM_FFI_REGISTER_GLOBAL("node.StructuralEqual") - .set_body_typed([](const Any& lhs, const Any& rhs, bool assert_mode, bool map_free_vars) { - // If we are asserting on failure, then the `defer_fails` option - // should be enabled, to provide better error messages. For - // example, if the number of bindings in a `relax::BindingBlock` - // differs, highlighting the first difference rather than the - // entire block. 
- bool defer_fails = assert_mode; - Optional first_mismatch; - return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails) - .Equal(lhs, rhs, map_free_vars); - }); - -TVM_FFI_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") - .set_body_typed([](const Any& lhs, const Any& rhs, bool map_free_vars) { - Optional first_mismatch; - bool equal = - SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); - ICHECK(equal == !first_mismatch.defined()); - return first_mismatch; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("node.StructuralEqual", + [](const Any& lhs, const Any& rhs, bool assert_mode, bool map_free_vars) { + // If we are asserting on failure, then the `defer_fails` option + // should be enabled, to provide better error messages. For + // example, if the number of bindings in a `relax::BindingBlock` + // differs, highlighting the first difference rather than the + // entire block. + bool defer_fails = assert_mode; + Optional first_mismatch; + return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails) + .Equal(lhs, rhs, map_free_vars); + }) + .def("node.GetFirstStructuralMismatch", + [](const Any& lhs, const Any& rhs, bool map_free_vars) { + Optional first_mismatch; + bool equal = + SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); + ICHECK(equal == !first_mismatch.defined()); + return first_mismatch; + }); +}); bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_params) const { diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 1f1d476d5cf3..5f11448f85d2 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -21,6 +21,7 @@ */ #include #include +#include #include #include #include @@ -291,11 +292,14 @@ void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars impl->DispatchSHash(key, map_free_vars); } -TVM_FFI_REGISTER_GLOBAL("node.StructuralHash") - .set_body_typed([](const Any& object, bool map_free_vars) -> int64_t { - uint64_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); - return static_cast(hashed_value); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("node.StructuralHash", + [](const Any& object, bool map_free_vars) -> int64_t { + uint64_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); + return static_cast(hashed_value); + }); +}); uint64_t StructuralHash::operator()(const ObjectRef& object) const { return SHashHandlerDefault().Hash(object, false); diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index 98122d1e1ec8..8580ec61128a 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -24,6 +24,7 @@ * \brief Analysis functions for Relax. 
*/ +#include #include #include #include @@ -197,15 +198,15 @@ bool ContainsImpureCall(const Expr& expr, const Optional& own_name) { return FindImpureCall(expr, own_name).defined(); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); - -TVM_FFI_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); - -TVM_FFI_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); - -TVM_FFI_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); - -TVM_FFI_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.analysis.free_vars", FreeVars) + .def("relax.analysis.bound_vars", BoundVars) + .def("relax.analysis.all_vars", AllVars) + .def("relax.analysis.all_global_vars", AllGlobalVars) + .def("relax.analysis.contains_impure_call", ContainsImpureCall); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 5825895db7d6..bd9b56633584 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -23,6 +23,7 @@ * \brief Utilities for identifying potentially compile-time variables */ +#include #include #include @@ -92,8 +93,10 @@ Array ComputableAtCompileTime(const Function& func) { return CompileTimeCollector::Collect(func); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time") - .set_body_typed(ComputableAtCompileTime); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.computable_at_compile_time", ComputableAtCompileTime); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc index 48ec7880b172..d3af3a78de27 100644 --- a/src/relax/analysis/detect_recursion.cc +++ b/src/relax/analysis/detect_recursion.cc @@ -24,6 +24,7 @@ * \brief Analysis to detect global recursive or mutually recursive functions. 
*/ +#include #include #include #include @@ -392,7 +393,10 @@ tvm::Array> DetectRecursion(const IRModule& m) { return ret; } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.detect_recursion").set_body_typed(DetectRecursion); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.detect_recursion", DetectRecursion); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index ab32abab5bea..f0d5620dd2f6 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include #include @@ -614,10 +615,13 @@ Map> SuggestLayoutTransforms( return analyzer.GetSuggestedTransforms(); } -TVM_FFI_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) - .set_body_typed([](PrimFunc fn, Array write_buffer_transformations) { - return SuggestLayoutTransforms(fn, write_buffer_transformations); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.suggest_layout_transforms", + [](PrimFunc fn, Array write_buffer_transformations) { + return SuggestLayoutTransforms(fn, write_buffer_transformations); + }); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index e09f061001f9..1888a7b27df7 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -23,6 +23,7 @@ * * \note Update this file when you added a new StructInfo. */ +#include #include #include #include @@ -72,8 +73,10 @@ class StaticTypeDeriver : public StructInfoFunctor { Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { - return GetStaticType(info); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.GetStaticType", + [](const StructInfo& info) { return GetStaticType(info); }); }); //-------------------------- @@ -285,11 +288,14 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") - .set_body_typed([](const StructInfo& info, Map shape_var_map, - Map var_map) { - return EraseToWellDefined(info, shape_var_map, var_map); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.analysis.EraseToWellDefined", + [](const StructInfo& info, Map shape_var_map, Map var_map) { + return EraseToWellDefined(info, shape_var_map, var_map); + }); +}); //-------------------------- // IsBaseOf @@ -595,19 +601,24 @@ BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& de } } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") - .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int { - return static_cast(StructInfoBaseCheck(base, derived)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.StructInfoBaseCheck", + [](const StructInfo& base, const StructInfo& derived) -> int { + return static_cast(StructInfoBaseCheck(base, derived)); + }); +}); bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) { 
return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; } -TVM_FFI_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") - .set_body_typed([](const StructInfo& base, const StructInfo& derived) { - return IsBaseOf(base, derived); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.StructInfoIsBaseOf", + [](const StructInfo& base, const StructInfo& derived) { return IsBaseOf(base, derived); }); +}); class StructInfoBasePreconditionCollector : public StructInfoFunctor { @@ -955,10 +966,13 @@ StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call } } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") - .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { - return DeriveCallRetStructInfo(finfo, call, ctx); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.DeriveCallRetStructInfo", + [](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + return DeriveCallRetStructInfo(finfo, call, ctx); + }); +}); //-------------------------- // UnifyToLCA @@ -1158,10 +1172,12 @@ StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::An } } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") - .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { - return StructInfoLCA(lhs, rhs); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.analysis.StructInfoLCA", + [](const StructInfo& lhs, const StructInfo& rhs) { return StructInfoLCA(lhs, rhs); }); +}); //-------------------------- // TIRVarsInStructInfo @@ -1241,10 +1257,12 @@ Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { return detector.GetTIRVars(); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo); - -TVM_FFI_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") - .set_body_typed(DefinableTIRVarsInStructInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.analysis.TIRVarsInStructInfo", TIRVarsInStructInfo) + .def("relax.analysis.DefinableTIRVarsInStructInfo", DefinableTIRVarsInStructInfo); +}); class NonNegativeExpressionCollector : relax::StructInfoVisitor { public: @@ -1288,8 +1306,11 @@ Array CollectNonNegativeExpressions(const StructInfo& sinfo) { return NonNegativeExpressionCollector::Collect(sinfo); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") - .set_body_typed(CollectNonNegativeExpressions); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.CollectNonNegativeExpressions", + CollectNonNegativeExpressions); +}); class SymbolicVarCollector : public relax::ExprVisitor, public relax::StructInfoVisitor, @@ -1436,9 +1457,12 @@ Array DefinedSymbolicVars(const Expr& expr) { } Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars); - -TVM_FFI_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.analysis.DefinedSymbolicVars", DefinedSymbolicVars) + .def("relax.analysis.FreeSymbolicVars", FreeSymbolicVars); +}); } // namespace relax } // namespace tvm diff 
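// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch above or below): every hunk in
// this series applies the same mechanical rewrite, replacing one
// TVM_FFI_REGISTER_GLOBAL(...).set_body_*() statement per symbol with a single
// TVM_FFI_STATIC_INIT_BLOCK that chains registrations through
// refl::GlobalDef(). The symbol names ("example.*"), the demo functions, and
// the header paths below are assumptions for illustration, not code taken from
// this patch.
// ---------------------------------------------------------------------------
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>

namespace demo {

// Plain typed function: registered with .def(), the replacement for
// .set_body_typed().
int AddOne(int x) { return x + 1; }

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("example.AddOne", AddOne)
      // Packed-args lambda: .def_packed() replaces .set_body_packed().
      .def_packed("example.AddOnePacked",
                  [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) {
                    *rv = args[0].cast<int>() + 1;
                  });
  // Member functions on node classes are registered with .def_method(...),
  // which replaces .set_body_method(...), as in the ScheduleRuleNode and
  // TaskSchedulerNode hunks elsewhere in this patch.
});

}  // namespace demo

// Lookup by name is unchanged; a symbol registered this way is still reachable
// via, e.g., tvm::ffi::Function::GetGlobal("example.AddOne").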
--git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 0845ec092fe2..2f896263e950 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -537,7 +538,10 @@ bool HasReshapePattern(const PrimFunc& func) { return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.has_reshape_pattern", HasReshapePattern); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 2f04d8659405..e2c201abeb7f 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -22,6 +22,7 @@ * \brief Implementation of use-def analysis. */ +#include #include #include #include @@ -118,7 +119,10 @@ Map> DataflowBlockUseDef(const DataflowBlock& dfb) { return usage.downstream_usage; } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef); +}); VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); } diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index a367d33ca4ff..d24067c8f460 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -58,8 +59,10 @@ Map AnalyzeVar2Value(const IRModule& m) { return std::move(var2val_analysis.var2value_); } -TVM_FFI_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { - return AnalyzeVar2Value(f); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.get_var2val", + [](const Function& f) { return AnalyzeVar2Value(f); }); }); class Name2BindingAnalysis : public relax::ExprVisitor { @@ -85,7 +88,10 @@ Map> NameToBinding(const Function& fn) { std::make_move_iterator(analysis.name2bindings_.end())); } -TVM_FFI_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.name_to_binding", NameToBinding); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 315281cc007f..bd6d727c21dc 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -64,6 +64,7 @@ * 17. If the kForcePure attribute is set for a function, * that function's is_pure field must be true. 
*/ +#include #include #include #include @@ -645,7 +646,10 @@ bool WellFormed(Variant obj, bool check_struct_info) { return WellFormedChecker::Check(obj, check_struct_info); } -TVM_FFI_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed(WellFormed); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.well_formed", WellFormed); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index faca48c4a61f..cd038597ee31 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -21,6 +21,7 @@ * \file src/relax/backend/contrib/clml/codegen.cc * \brief Implementation of the OpenCLML JSON serializer. */ +#include #include #include #include @@ -328,7 +329,10 @@ Array OpenCLMLCompiler(Array functions, Map #include #include @@ -125,7 +126,10 @@ Array CublasCompiler(Array functions, Map #include #include @@ -149,7 +150,10 @@ Array cuDNNCompiler(Array functions, Map headers) { - return CodegenResult(code, headers); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("contrib.cutlass.CodegenResult", [](String code, Array headers) { + return CodegenResult(code, headers); + }); +}); GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& ext_func_id, const std::vector& output_types, @@ -391,7 +393,10 @@ Array CUTLASSCompiler(Array functions, Map #include #include @@ -97,7 +98,10 @@ Array DNNLCompiler(Array functions, Map #include #include @@ -102,7 +103,10 @@ Array HipblasCompiler(Array functions, Map +#include #include #include #include @@ -264,7 +265,10 @@ Array NNAPICompiler(Array functions, Map #include #include // TODO(sunggg): add operator attribute when it's ready @@ -243,7 +244,10 @@ Array TensorRTCompiler(Array functions, Map GetTensorRTVersion() { #endif // TVM_GRAPH_EXECUTOR_TENSORRT } -TVM_FFI_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled") - .set_body_typed(IsTensorRTRuntimeEnabled); -TVM_FFI_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.is_tensorrt_runtime_enabled", IsTensorRTRuntimeEnabled) + .def("relax.get_tensorrt_version", GetTensorRTVersion); +}); } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index 6574ccc37a15..41bd11d60e7f 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -18,6 +18,7 @@ */ #include "utils.h" +#include #include #include #include @@ -75,7 +76,10 @@ bool EndsWithPattern(const std::string& str, const std::string& pattern) { return str.compare(str.length() - pattern.length(), pattern.length(), pattern) == 0; } -TVM_FFI_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.contrib.extract_arg_idx", ExtractArgIdx); +}); } // namespace backend } // namespace relax diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc index 840b44c12838..90b5afbd4d38 100644 --- a/src/relax/backend/pattern_registry.cc +++ b/src/relax/backend/pattern_registry.cc @@ -19,6 +19,8 @@ #include "./pattern_registry.h" +#include + #include "../../support/utils.h" namespace tvm { @@ -67,11 +69,14 @@ Optional 
GetPattern(const String& pattern_name) { return std::nullopt; } -TVM_FFI_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); -TVM_FFI_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); -TVM_FFI_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix") - .set_body_typed(GetPatternsWithPrefix); -TVM_FFI_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.backend.RegisterPatterns", RegisterPatterns) + .def("relax.backend.RemovePatterns", RemovePatterns) + .def("relax.backend.GetPatternsWithPrefix", GetPatternsWithPrefix) + .def("relax.backend.GetPattern", GetPattern); +}); } // namespace backend } // namespace relax diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index 686d24de62b2..7b476b88b746 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -139,10 +140,13 @@ class TaskExtractor : public ExprVisitor { std::optional normalize_mod_func_; }; -TVM_FFI_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") - .set_body_typed([](IRModule mod, Target target, String mod_eq_name) { - return TaskExtractor::ExtractTask(std::move(mod), std::move(target), std::move(mod_eq_name)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.backend.MetaScheduleExtractTask", [](IRModule mod, Target target, + String mod_eq_name) { + return TaskExtractor::ExtractTask(std::move(mod), std::move(target), std::move(mod_eq_name)); + }); +}); } // namespace backend } // namespace relax diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 4cf2811922c8..4a76905f3fd4 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -21,6 +21,7 @@ * \file src/relax/backend/vm/codegen_vm.cc * \brief A codegen to generate VM executable from a Relax IRModule. */ +#include #include #include #include @@ -425,7 +426,10 @@ IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVM::Run(exec_builder, mod); } -TVM_FFI_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.VMCodeGen", VMCodeGen); +}); /*! * \brief Link the modules together, possibly create a constant module. @@ -490,7 +494,10 @@ Module VMLink(ExecBuilder builder, Target target, Optional lib, Array #include #include #include @@ -530,7 +531,10 @@ IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVMTIR::Run(exec_builder, mod); } -TVM_FFI_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.VMTIRCodeGen", VMTIRCodeGen); +}); } // namespace codegen_vm } // namespace relax diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index b13f4da6dae0..86ac536a9f8e 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -20,6 +20,7 @@ /*! 
* \file src/relax/backend/vm/exec_builder.cc */ +#include #include #include @@ -329,74 +330,63 @@ void ExecBuilderNode::Formalize() { } } -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ExecBuilder builder = args[0].cast(); - ffi::Any rt; - rt = args[1]; - *ret = builder->ConvertConstant(rt).data(); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") - .set_body_typed([](ExecBuilder builder, String func, int64_t num_inputs, - Optional> param_names) { - builder->EmitFunction(func, num_inputs, param_names); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEndFunction") - .set_body_method(&ExecBuilderNode::EndFunction); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") - .set_body_typed([](ExecBuilder builder, String name, int32_t kind) { - builder->DeclareFunction(name, static_cast(kind)); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") - .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { - std::vector args_; - for (size_t i = 0; i < args.size(); ++i) { - args_.push_back(Instruction::Arg::FromData(args[i]->value)); - } - auto dst_ = Instruction::Arg::Register(dst); - builder->EmitCall(name, args_, dst_.value()); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") - .set_body_typed([](ExecBuilder builder, int64_t data) { - builder->EmitRet(Instruction::Arg::FromData(data)); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto").set_body_method(&ExecBuilderNode::EmitGoto); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") - .set_body_typed([](ExecBuilder builder, int64_t data, vm::Index false_offset) { - builder->EmitIf(Instruction::Arg::FromData(data), false_offset); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderR") - .set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::Register(value).data(); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderImm") - .set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::Immediate(value).data(); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderC") - .set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::ConstIdx(value).data(); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { - return builder->GetFunction(value).data(); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { - ObjectPtr p_exec = builder->Get(); - return runtime::Module(p_exec); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.ExecBuilderCreate", ExecBuilderNode::Create) + .def_packed("relax.ExecBuilderConvertConstant", + [](ffi::PackedArgs args, ffi::Any* ret) { + ExecBuilder builder = args[0].cast(); + ffi::Any rt; + rt = args[1]; + *ret = builder->ConvertConstant(rt).data(); + }) + .def("relax.ExecBuilderEmitFunction", + [](ExecBuilder builder, String func, int64_t num_inputs, + Optional> param_names) { + builder->EmitFunction(func, num_inputs, param_names); + }) + .def_method("relax.ExecBuilderEndFunction", &ExecBuilderNode::EndFunction) + .def("relax.ExecBuilderDeclareFunction", + [](ExecBuilder builder, String name, int32_t kind) { + builder->DeclareFunction(name, static_cast(kind)); + }) + .def("relax.ExecBuilderEmitCall", + [](ExecBuilder builder, String name, Array args, int64_t 
dst) { + std::vector args_; + for (size_t i = 0; i < args.size(); ++i) { + args_.push_back(Instruction::Arg::FromData(args[i]->value)); + } + auto dst_ = Instruction::Arg::Register(dst); + builder->EmitCall(name, args_, dst_.value()); + }) + .def("relax.ExecBuilderEmitRet", + [](ExecBuilder builder, int64_t data) { + builder->EmitRet(Instruction::Arg::FromData(data)); + }) + .def_method("relax.ExecBuilderEmitGoto", &ExecBuilderNode::EmitGoto) + .def("relax.ExecBuilderEmitIf", + [](ExecBuilder builder, int64_t data, vm::Index false_offset) { + builder->EmitIf(Instruction::Arg::FromData(data), false_offset); + }) + .def("relax.ExecBuilderR", + [](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Register(value).data(); + }) + .def("relax.ExecBuilderImm", + [](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Immediate(value).data(); + }) + .def("relax.ExecBuilderC", + [](ExecBuilder builder, int64_t value) { + return Instruction::Arg::ConstIdx(value).data(); + }) + .def("relax.ExecBuilderF", + [](ExecBuilder builder, String value) { return builder->GetFunction(value).data(); }) + .def("relax.ExecBuilderGet", [](ExecBuilder builder) { + ObjectPtr p_exec = builder->Get(); + return runtime::Module(p_exec); + }); }); } // namespace relax diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 7757195bcb1d..c569920c57f7 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -20,6 +20,7 @@ * \file src/relax/backend/vm/lower_runtime_builtin.cc * \brief Lowers most builtin functions and packed calls. */ +#include #include #include #include @@ -231,7 +232,10 @@ Pass LowerRuntimeBuiltin() { return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LowerRuntimeBuiltin", LowerRuntimeBuiltin); +}); } // namespace transform } // namespace relax diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 986626a6eae0..a964ca1e1a98 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -20,6 +20,7 @@ * \file src/relax/backend/vm/vm_shape_lower.cc * \brief Lower the function boundary type checks and symbolic shape computations. 
*/ +#include #include #include #include @@ -813,8 +814,10 @@ Pass VMShapeLower(bool emit_err_ctx) { return CreateModulePass(pass_func, 0, "VMShapeLower", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { - return VMShapeLower(emit_err_ctx); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.VMShapeLower", + [](bool emit_err_ctx) { return VMShapeLower(emit_err_ctx); }); }); } // namespace transform diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index a20c25102734..57d2e652dc27 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -60,13 +60,17 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { } TVM_REGISTER_NODE_TYPE(DeviceMeshNode); -TVM_FFI_REGISTER_GLOBAL("relax.distributed.DeviceMesh") - .set_body_typed([](ffi::Shape shape, Array device_ids, Optional device_range) { - if (device_range.defined()) - return DeviceMesh(shape, device_range.value()); - else - return DeviceMesh(shape, device_ids); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.distributed.DeviceMesh", + [](ffi::Shape shape, Array device_ids, Optional device_range) { + if (device_range.defined()) + return DeviceMesh(shape, device_range.value()); + else + return DeviceMesh(shape, device_ids); + }); +}); } // namespace distributed } // namespace relax diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 93c5d75b5de1..4b014d2a4a63 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -22,6 +22,7 @@ * \brief Relax dtensor struct info. */ +#include #include namespace tvm { namespace relax { @@ -49,12 +50,11 @@ PlacementSpec PlacementSpec::Replica() { TVM_REGISTER_NODE_TYPE(PlacementSpecNode); -TVM_FFI_REGISTER_GLOBAL("relax.distributed.Sharding").set_body_typed([](int axis) { - return PlacementSpec::Sharding(axis); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.distributed.Replica").set_body_typed([]() { - return PlacementSpec::Replica(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.distributed.Sharding", [](int axis) { return PlacementSpec::Sharding(axis); }) + .def("relax.distributed.Replica", []() { return PlacementSpec::Replica(); }); }); String PlacementNode::ToString() const { @@ -112,9 +112,13 @@ Placement Placement::FromText(String text_repr) { } TVM_REGISTER_NODE_TYPE(PlacementNode); -TVM_FFI_REGISTER_GLOBAL("relax.distributed.PlacementFromText").set_body_typed(Placement::FromText); -TVM_FFI_REGISTER_GLOBAL("relax.distributed.Placement") - .set_body_typed([](Array dim_specs) { return Placement(dim_specs); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.distributed.PlacementFromText", Placement::FromText) + .def("relax.distributed.Placement", + [](Array dim_specs) { return Placement(dim_specs); }); +}); // DTensor DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, @@ -136,11 +140,14 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d TVM_REGISTER_NODE_TYPE(DTensorStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.distributed.DTensorStructInfo") - .set_body_typed([](TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, - Span span) { - return 
DTensorStructInfo(tensor_sinfo, device_mesh, placement, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.distributed.DTensorStructInfo", + [](TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span) { + return DTensorStructInfo(tensor_sinfo, device_mesh, placement, span); + }); +}); } // namespace distributed } // namespace relax diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 1df1d2110ba9..cc5a26bc4733 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -22,6 +22,7 @@ * \brief Pass for legalizing redistribute op to ccl op. */ +#include #include #include #include @@ -115,8 +116,10 @@ Pass LegalizeRedistribute() { }; return CreateModulePass(pass_func, 1, "LegalizeRedistribute", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute") - .set_body_typed(LegalizeRedistribute); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.distributed.transform.LegalizeRedistribute", LegalizeRedistribute); +}); } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index e4f811b83d42..73ed61d45cf9 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -25,6 +25,7 @@ * inserting necessary broadcast and scatter for inputs. */ +#include #include #include #include @@ -262,7 +263,10 @@ Pass LowerDistIR() { auto pass_func = [=](IRModule m, PassContext pc) { return DistIRSharder::LowerDistIR(m); }; return CreateModulePass(pass_func, 1, "LowerDistIR", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LowerDistIR").set_body_typed(LowerDistIR); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.distributed.transform.LowerDistIR", LowerDistIR); +}); } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index c8abe2b1d1b5..e2951fd76484 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -21,6 +21,7 @@ * \file tvm/relax/distributed/transform/lower_global_view_to_local_view.cc * \brief Pass for lowering global view TensorIR into local view */ +#include #include #include #include @@ -432,8 +433,11 @@ Pass LowerGlobalViewToLocalView() { auto pass_func = [=](IRModule m, PassContext pc) { return LowerTIRToLocalView(m).Lower(); }; return CreateModulePass(pass_func, 1, "LowerGlobalViewToLocalView", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LowerGlobalViewToLocalView") - .set_body_typed(LowerGlobalViewToLocalView); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.distributed.transform.LowerGlobalViewToLocalView", + LowerGlobalViewToLocalView); +}); } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index f5f276c2b873..5e0972be7741 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ 
b/src/relax/distributed/transform/propagate_sharding.cc @@ -21,6 +21,7 @@ * \file tvm/relax/distributed/transform/propagate_sharding.cc * \brief Pass for propagating sharding information. */ +#include #include #include #include @@ -615,8 +616,10 @@ Pass PropagateSharding() { }; return CreateModulePass(pass_func, 1, "PropagateSharding", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.PropagateSharding") - .set_body_typed(PropagateSharding); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.distributed.transform.PropagateSharding", PropagateSharding); +}); } // namespace transform } // namespace distributed diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index c65fe1d0ddeb..72f1e1acbfd9 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -54,10 +54,12 @@ DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.DataflowBlockRewrite") - .set_body_typed([](DataflowBlock dfb, Function root_fn) { - return DataflowBlockRewrite(dfb, root_fn); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.DataflowBlockRewrite", [](DataflowBlock dfb, Function root_fn) { + return DataflowBlockRewrite(dfb, root_fn); + }); +}); void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { class ReplaceAllUsePass : public ExprMutator { @@ -113,10 +115,13 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { } } -TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") - .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { - rwt->ReplaceAllUses(old_var, new_var); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dfb_rewrite_replace_all_uses", + [](DataflowBlockRewrite rwt, Var old_var, Var new_var) { + rwt->ReplaceAllUses(old_var, new_var); + }); +}); class UpdateDFB : public ExprMutator { private: @@ -181,17 +186,20 @@ void DataflowBlockRewriteNode::Add(Binding binding) { } } -TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") - .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); - -TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_add") - .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { - if (name.has_value()) { - rwt->Add(name.value(), expr, is_dfvar); - } else { - rwt->Add(expr, is_dfvar); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.dfb_rewrite_add_binding", + [](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }) + .def("relax.dfb_rewrite_add", + [](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { + if (name.has_value()) { + rwt->Add(name.value(), expr, is_dfvar); + } else { + rwt->Add(expr, is_dfvar); + } + }); +}); std::set GetUnusedVars(Map> users_map, Array fn_outputs) { std::vector unused; @@ -295,10 +303,13 @@ void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { to_users_.erase(unused); // update use-def chain. 
} -TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") - .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { - rwt->RemoveUnused(unused, allow_undef); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dfb_rewrite_remove_unused", + [](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { + rwt->RemoveUnused(unused, allow_undef); + }); +}); void DataflowBlockRewriteNode::RemoveAllUnused() { RemoveUnusedVars remover(to_users_, fn_outputs_); @@ -317,8 +328,11 @@ void DataflowBlockRewriteNode::RemoveAllUnused() { for (const auto& unused : remover.unused_vars) to_users_.erase(unused); } -TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") - .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dfb_rewrite_remove_all_unused", + [](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); +}); Expr RemoveAllUnused(Expr expr) { auto var_usage = CollectVarUsage(expr); @@ -337,7 +351,10 @@ Expr RemoveAllUnused(Expr expr) { return remover.VisitExpr(std::move(expr)); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.remove_all_unused", RemoveAllUnused); +}); IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { BlockBuilder builder = BlockBuilder::Create(irmod); @@ -352,10 +369,12 @@ IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { return builder->GetContextIRModule(); } -TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") - .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { - return rwt->MutateIRModule(irmod); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.dfb_rewrite_mutate_irmodule", + [](DataflowBlockRewrite rwt, IRModule irmod) { return rwt->MutateIRModule(irmod); }); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 8f4aee382be4..7990dc0d39de 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -1054,67 +1055,43 @@ BlockBuilder BlockBuilder::Create(Optional mod, //--------------------------------------- TVM_REGISTER_OBJECT_TYPE(BlockBuilderNode); -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { - return BlockBuilder::Create(mod); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.BlockBuilderCreate", + [](Optional mod) { return BlockBuilder::Create(mod); }) + .def_method("relax.BlockBuilderBeginDataflowBlock", &BlockBuilderNode::BeginDataflowBlock) + .def_method("relax.BlockBuilderBeginBindingBlock", &BlockBuilderNode::BeginBindingBlock) + .def_method("relax.BlockBuilderEndBlock", &BlockBuilderNode::EndBlock) + .def_method("relax.BlockBuilderNormalize", &BlockBuilderNode::Normalize) + .def("relax.BlockBuilderEmit", + [](BlockBuilder builder, Expr expr, String name_hint) { + return builder->Emit(expr, name_hint); + }) + .def("relax.BlockBuilderEmitMatchCast", + [](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) { + return builder->EmitMatchCast(value, struct_info, name_hint); + }) + 
.def("relax.BlockBuilderEmitOutput", + [](BlockBuilder builder, const Expr& output, String name_hint) { + return builder->EmitOutput(output, name_hint); + }) + .def("relax.BlockBuilderEmitNormalized", + [](BlockBuilder builder, Binding binding) { return builder->EmitNormalized(binding); }) + .def("relax.BlockBuilderGetUniqueName", + [](BlockBuilder builder, String name_hint) { + return builder->name_supply()->FreshName(name_hint, /*add_prefix*/ false, + /*add_underscore*/ false); + }) + .def_method("relax.BlockBuilderAddFunction", &BlockBuilderNode::AddFunction) + .def_method("relax.BlockBuilderUpdateFunction", &BlockBuilderNode::UpdateFunction) + .def_method("relax.BlockBuilderGetContextIRModule", &BlockBuilderNode::GetContextIRModule) + .def_method("relax.BlockBuilderFinalize", &BlockBuilderNode::Finalize) + .def_method("relax.BlockBuilderCurrentBlockIsDataFlow", + &BlockBuilderNode::CurrentBlockIsDataFlow) + .def_method("relax.BlockBuilderLookupBinding", &BlockBuilderNode::LookupBinding) + .def_method("relax.BlockBuilderBeginScope", &BlockBuilderNode::BeginScope) + .def_method("relax.BlockBuilderEndScope", &BlockBuilderNode::EndScope); }); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") - .set_body_method(&BlockBuilderNode::BeginDataflowBlock); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") - .set_body_method(&BlockBuilderNode::BeginBindingBlock); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEndBlock").set_body_method(&BlockBuilderNode::EndBlock); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderNormalize") - .set_body_method(&BlockBuilderNode::Normalize); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmit") - .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) { - return builder->Emit(expr, name_hint); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") - .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) { - return builder->EmitMatchCast(value, struct_info, name_hint); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") - .set_body_typed([](BlockBuilder builder, const Expr& output, String name_hint) { - return builder->EmitOutput(output, name_hint); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") - .set_body_typed([](BlockBuilder builder, Binding binding) { - return builder->EmitNormalized(binding); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") - .set_body_typed([](BlockBuilder builder, String name_hint) { - return builder->name_supply()->FreshName(name_hint, /*add_prefix*/ false, - /*add_underscore*/ false); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") - .set_body_method(&BlockBuilderNode::AddFunction); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") - .set_body_method(&BlockBuilderNode::UpdateFunction); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") - .set_body_method(&BlockBuilderNode::GetContextIRModule); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderFinalize").set_body_method(&BlockBuilderNode::Finalize); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") - .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") - .set_body_method(&BlockBuilderNode::LookupBinding); - -TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") - .set_body_method(&BlockBuilderNode::BeginScope); - 
-TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEndScope").set_body_method(&BlockBuilderNode::EndScope); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index cdb97b30c1f6..d7934c142272 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -363,10 +363,12 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.match_dfb") - .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.dpl.match_dfb", + [](const PatternContext& ctx, const DataflowBlock& dfb) { return MatchGraph(ctx, dfb); }); +}); class PatternContextRewriterNode : public PatternMatchingRewriterNode { public: @@ -449,7 +451,10 @@ Function RewriteBindings( return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.rewrite_bindings", RewriteBindings); +}); TVM_FFI_STATIC_INIT_BLOCK({ PatternContextRewriterNode::RegisterReflection(); }); diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index c105180e31a4..8646612777ab 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -194,26 +194,28 @@ void RewriteSpec::Append(RewriteSpec other) { TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") - .set_body_typed([](DFPattern pattern, - ffi::TypedFunction(Expr, Map)> func) { - return PatternMatchingRewriter::FromPattern(pattern, func); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule") - .set_body_typed([](IRModule mod) { return PatternMatchingRewriter::FromModule(mod); }); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") - .set_body_typed([](PatternMatchingRewriter rewriter, - Variant obj) -> Variant { - if (auto expr = obj.as()) { - return rewriter(expr.value()); - } else if (auto mod = obj.as()) { - return rewriter(mod.value()); - } else { - LOG(FATAL) << "Unreachable: object does not contain either variant type"; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.dpl.PatternMatchingRewriterFromPattern", + [](DFPattern pattern, + ffi::TypedFunction(Expr, Map)> func) { + return PatternMatchingRewriter::FromPattern(pattern, func); + }) + .def("relax.dpl.PatternMatchingRewriterFromModule", + [](IRModule mod) { return PatternMatchingRewriter::FromModule(mod); }) + .def("relax.dpl.PatternMatchingRewriterApply", + [](PatternMatchingRewriter rewriter, + Variant obj) -> Variant { + if (auto expr = obj.as()) { + return rewriter(expr.value()); + } else if (auto mod = obj.as()) { + return rewriter(mod.value()); + } else { + LOG(FATAL) << "Unreachable: object does not contain either variant type"; + } + }); +}); TVM_REGISTER_NODE_TYPE(ExprPatternRewriterNode); @@ -259,11 +261,14 @@ Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, return std::nullopt; } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternRewriter") - .set_body_typed([](DFPattern pattern, - ffi::TypedFunction(Expr, Map)> 
func) { - return ExprPatternRewriter(pattern, func); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.dpl.PatternRewriter", + [](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) { + return ExprPatternRewriter(pattern, func); + }); +}); ExprPatternRewriter::ExprPatternRewriter( DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, @@ -308,10 +313,13 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons return lhs_match; } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.OrRewriter") - .set_body_typed([](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { - return OrRewriter(lhs, rhs); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.OrRewriter", + [](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { + return OrRewriter(lhs, rhs); + }); +}); OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { auto node = make_object(); @@ -603,11 +611,14 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( return rewrites; } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.TupleRewriter") - .set_body_typed([](Array patterns, - ffi::TypedFunction(Expr, Map)> func) { - return TupleRewriter(patterns, func); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.TupleRewriter", + [](Array patterns, + ffi::TypedFunction(Expr, Map)> func) { + return TupleRewriter(patterns, func); + }); +}); TupleRewriter::TupleRewriter(Array patterns, ffi::TypedFunction(Expr, Map)> func, @@ -795,13 +806,19 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, return matcher.GetMemo(); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.extract_matched_expr", ExtractMatchedExpr); +}); bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.match_expr", MatchExpr); +}); /*! 
* \brief Apply pattern matching to each expression, replacing @@ -1073,7 +1090,10 @@ Function RewriteCall(const DFPattern& pat, return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.rewrite_call", RewriteCall); +}); TVM_FFI_STATIC_INIT_BLOCK({ PatternMatchingRewriterNode::RegisterReflection(); diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 48332de25f3a..f2861eb68489 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -22,6 +22,7 @@ * \brief The dataflow pattern language for Relax */ +#include #include #include @@ -68,8 +69,10 @@ ExternFuncPattern::ExternFuncPattern(String global_symbol) { n->global_symbol_ = std::move(global_symbol); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { - return ExternFuncPattern(global_symbol); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.ExternFuncPattern", + [](String global_symbol) { return ExternFuncPattern(global_symbol); }); }); RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { p->stream << "ExternFuncPattern(" << node->global_symbol() << ")"; @@ -81,16 +84,20 @@ VarPattern::VarPattern(String name_hint) { n->name = std::move(name_hint); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { - return VarPattern(name_hint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.VarPattern", + [](String name_hint) { return VarPattern(name_hint); }); }); RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { p->stream << "VarPattern(" << node->name_hint() << ")"; }); TVM_REGISTER_NODE_TYPE(DataflowVarPatternNode); -TVM_FFI_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { - return DataflowVarPattern(name_hint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.DataflowVarPattern", + [](String name_hint) { return DataflowVarPattern(name_hint); }); }); DataflowVarPattern::DataflowVarPattern(String name_hint) { ObjectPtr n = make_object(); @@ -107,8 +114,10 @@ GlobalVarPattern::GlobalVarPattern(String name_hint) { n->name = std::move(name_hint); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { - return GlobalVarPattern(name_hint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.GlobalVarPattern", + [](String name_hint) { return GlobalVarPattern(name_hint); }); }); RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { p->stream << "GlobalVarPattern(" << node->name_hint() << ")"; @@ -120,15 +129,19 @@ ExprPattern::ExprPattern(Expr expr) { n->expr = std::move(expr); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { - return ExprPattern(e); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.ExprPattern", [](Expr e) { return ExprPattern(e); }); }); RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); TVM_REGISTER_NODE_TYPE(ConstantPatternNode); 
-TVM_FFI_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { - auto c = ConstantPattern(make_object()); - return c; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.ConstantPattern", []() { + auto c = ConstantPattern(make_object()); + return c; + }); }); RELAX_PATTERN_PRINTER_DEF(ConstantPatternNode, [](auto p, auto node) { p->stream << "ConstantPattern()"; }); @@ -141,10 +154,13 @@ CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_ n->varg_default_wildcard = varg_default_wildcard; data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.CallPattern") - .set_body_typed([](DFPattern op, Array args, bool varg_default_wildcard) { - return CallPattern(op, args, varg_default_wildcard); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.CallPattern", + [](DFPattern op, Array args, bool varg_default_wildcard) { + return CallPattern(op, args, varg_default_wildcard); + }); +}); RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { p->stream << node->op << "("; for (size_t i = 0; i < node->args.size(); ++i) { @@ -164,8 +180,10 @@ PrimArrPattern::PrimArrPattern(Array arr) { n->fields = std::move(arr); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { - return PrimArrPattern(std::move(arr)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.PrimArrPattern", + [](Array arr) { return PrimArrPattern(std::move(arr)); }); }); RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { p->stream << "PrimArrPattern(" << node->fields << ")"; @@ -178,10 +196,12 @@ FunctionPattern::FunctionPattern(Array params, DFPattern body) { n->body = std::move(body); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.FunctionPattern") - .set_body_typed([](Array params, DFPattern body) { - return FunctionPattern(params, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.FunctionPattern", [](Array params, DFPattern body) { + return FunctionPattern(params, body); + }); +}); RELAX_PATTERN_PRINTER_DEF(FunctionPatternNode, [](auto p, auto node) { p->stream << "FunctionPattern(" << node->params << ", " << node->body << ")"; }); @@ -192,8 +212,10 @@ TuplePattern::TuplePattern(tvm::Array fields) { n->fields = std::move(fields); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { - return TuplePattern(fields); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.TuplePattern", + [](tvm::Array fields) { return TuplePattern(fields); }); }); RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { p->stream << "TuplePattern(" << node->fields << ")"; @@ -205,8 +227,11 @@ UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { n->fields = std::move(fields); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") - .set_body_typed([](tvm::Array fields) { return UnorderedTuplePattern(fields); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.UnorderedTuplePattern", + [](tvm::Array fields) { return UnorderedTuplePattern(fields); }); +}); RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { p->stream << "UnorderedTuplePattern(" << 
node->fields << ")"; }); @@ -218,8 +243,12 @@ TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { n->index = index; data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern") - .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.TupleGetItemPattern", [](DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); + }); +}); RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; }); @@ -231,8 +260,10 @@ AndPattern::AndPattern(DFPattern left, DFPattern right) { n->right = std::move(right); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { - return AndPattern(left, right); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.AndPattern", + [](DFPattern left, DFPattern right) { return AndPattern(left, right); }); }); RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { p->stream << "AndPattern(" << node->left << " & " << node->right << ")"; @@ -245,8 +276,10 @@ OrPattern::OrPattern(DFPattern left, DFPattern right) { n->right = std::move(right); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { - return OrPattern(left, right); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.OrPattern", + [](DFPattern left, DFPattern right) { return OrPattern(left, right); }); }); RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { p->stream << "OrPattern(" << node->left << " | " << node->right << ")"; @@ -258,16 +291,19 @@ NotPattern::NotPattern(DFPattern reject) { n->reject = std::move(reject); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { - return NotPattern(reject); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.NotPattern", + [](DFPattern reject) { return NotPattern(reject); }); }); RELAX_PATTERN_PRINTER_DEF(NotPatternNode, [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); TVM_REGISTER_NODE_TYPE(WildcardPatternNode); WildcardPattern::WildcardPattern() { data_ = make_object(); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { - return WildcardPattern(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.WildcardPattern", []() { return WildcardPattern(); }); }); RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); @@ -278,10 +314,13 @@ StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) n->struct_info = std::move(struct_info); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.StructInfoPattern") - .set_body_typed([](DFPattern pattern, StructInfo struct_info) { - return StructInfoPattern(pattern, struct_info); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.StructInfoPattern", + [](DFPattern pattern, StructInfo struct_info) { + return StructInfoPattern(pattern, struct_info); + }); +}); RELAX_PATTERN_PRINTER_DEF(StructInfoPatternNode, [](auto p, auto node) { p->stream << 
"StructInfoPattern(" << node->pattern << " has relax StructInfo " << node->struct_info << ")"; @@ -294,10 +333,12 @@ ShapePattern::ShapePattern(DFPattern pattern, Array shape) { n->shape = std::move(shape); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.ShapePattern") - .set_body_typed([](DFPattern pattern, Array shape) { - return ShapePattern(pattern, shape); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.ShapePattern", [](DFPattern pattern, Array shape) { + return ShapePattern(pattern, shape); + }); +}); RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) { p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; }); @@ -312,8 +353,10 @@ SameShapeConstraint::SameShapeConstraint(Array args) { ctx.value().add_constraint(*this); } } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array args) { - return SameShapeConstraint(args); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.SameShapeConstraint", + [](Array args) { return SameShapeConstraint(args); }); }); RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { p->stream << "SameShapeConstraint("; @@ -333,10 +376,12 @@ DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { n->dtype = std::move(dtype); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.DataTypePattern") - .set_body_typed([](DFPattern pattern, DataType dtype) { - return DataTypePattern(pattern, dtype); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.DataTypePattern", [](DFPattern pattern, DataType dtype) { + return DataTypePattern(pattern, dtype); + }); +}); RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) { p->stream << "DataTypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; }); @@ -348,8 +393,12 @@ AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { n->attrs = std::move(attrs); data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.AttrPattern") - .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.AttrPattern", [](DFPattern pattern, DictAttrs attrs) { + return AttrPattern(pattern, attrs); + }); +}); RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; }); @@ -520,10 +569,12 @@ PatternSeq PatternSeq::dup() const { return ret; } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternSeq") - .set_body_typed([](Array patterns, bool only_used_by) { - return PatternSeq(std::move(patterns), only_used_by); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.dpl.PatternSeq", [](Array patterns, bool only_used_by) { + return PatternSeq(std::move(patterns), only_used_by); + }); +}); RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "["; for (size_t i = 0; i < node->patterns.size(); ++i) { @@ -534,15 +585,14 @@ RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "]"; }); -TVM_FFI_REGISTER_GLOBAL("relax.dpl.used_by") - .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { - return lhs.UsedBy(rhs, index); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.only_used_by") - 
.set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { - return lhs.OnlyUsedBy(rhs, index); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.dpl.used_by", + [](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.UsedBy(rhs, index); }) + .def("relax.dpl.only_used_by", + [](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.OnlyUsedBy(rhs, index); }); +}); PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { PatternSeq ret; @@ -652,28 +702,15 @@ DFPattern DFPattern::dup() const { return pattern; } -TVM_FFI_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { - return pattern.dup(); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { - return seq.dup(); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { - return PatternContext(incre); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { - return PatternContext::Current(); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const PatternContext& ctx) { - ctx.EnterWithScope(); -}); - -TVM_FFI_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const PatternContext& ctx) { - ctx.ExitWithScope(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.dpl.dup_pattern", [](DFPattern pattern) { return pattern.dup(); }) + .def("relax.dpl.dup_seq", [](PatternSeq seq) { return seq.dup(); }) + .def("relax.dpl.PatternContext", [](bool incre) { return PatternContext(incre); }) + .def("relax.dpl.current_context", [] { return PatternContext::Current(); }) + .def("relax.dpl.enter_context", [](const PatternContext& ctx) { ctx.EnterWithScope(); }) + .def("relax.dpl.exit_context", [](const PatternContext& ctx) { ctx.ExitWithScope(); }); }); } // namespace relax diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 518ca7c0488c..264e88af3916 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -22,6 +22,7 @@ */ #include "./emit_te.h" +#include #include #include @@ -74,7 +75,10 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string return te::PlaceholderOp(n).output(0); } -TVM_FFI_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.TETensor", TETensor); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 4585c45ef401..07dae6b29eb8 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -121,9 +121,12 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args TVM_REGISTER_NODE_TYPE(CallNode); -TVM_FFI_REGISTER_GLOBAL("relax.Call") - .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, - Span span) { return Call(op, args, attrs, sinfo_args, span); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.Call", + [](Expr op, Array args, Attrs attrs, Array sinfo_args, + Span span) { return Call(op, args, attrs, sinfo_args, span); }); +}); If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ObjectPtr n = make_object(); @@ -156,10 +159,12 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc TVM_REGISTER_NODE_TYPE(IfNode); -TVM_FFI_REGISTER_GLOBAL("relax.If") - .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { - 
return If(cond, true_branch, false_branch, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.If", [](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); + }); +}); Tuple::Tuple(tvm::Array fields, Span span) { Optional tuple_sinfo = [&]() -> Optional { @@ -183,8 +188,10 @@ Tuple::Tuple(tvm::Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_FFI_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { - return Tuple(fields, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.Tuple", + [](tvm::Array fields, Span span) { return Tuple(fields, span); }); }); Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { @@ -246,8 +253,11 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_FFI_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { - return TupleGetItem(tuple, index, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.TupleGetItem", [](Expr tuple, int index, Span span) { + return TupleGetItem(tuple, index, span); + }); }); TVM_REGISTER_NODE_TYPE(ShapeExprNode); @@ -268,8 +278,10 @@ ShapeExpr::ShapeExpr(Array values, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { - return ShapeExpr(values, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ShapeExpr", + [](Array values, Span span) { return ShapeExpr(values, span); }); }); TVM_REGISTER_NODE_TYPE(VarNode); @@ -301,15 +313,15 @@ VarNode* Var::CopyOnWrite() { return static_cast(data_.get()); } -TVM_FFI_REGISTER_GLOBAL("relax.Var") - .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { - return Var(name_hint, struct_info_annotation, span); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.VarFromId") - .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { - return Var(vid, struct_info_annotation, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.Var", [](String name_hint, Optional struct_info_annotation, + Span span) { return Var(name_hint, struct_info_annotation, span); }) + .def("relax.VarFromId", [](Id vid, Optional struct_info_annotation, Span span) { + return Var(vid, struct_info_annotation, span); + }); +}); TVM_REGISTER_NODE_TYPE(DataflowVarNode); @@ -322,15 +334,18 @@ DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Sp data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.DataflowVar") - .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { - return DataflowVar(name_hint, struct_info_annotation, span); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.DataflowVarFromId") - .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { - return DataflowVar(vid, struct_info_annotation, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.DataflowVar", + [](String name_hint, Optional struct_info_annotation, Span span) { + return DataflowVar(name_hint, struct_info_annotation, span); + }) + .def("relax.DataflowVarFromId", + [](Id vid, Optional struct_info_annotation, Span span) { + return 
DataflowVar(vid, struct_info_annotation, span); + }); +}); Constant::Constant(runtime::NDArray data, Optional struct_info_annotation, Span span) { ObjectPtr n = make_object(); @@ -355,12 +370,13 @@ Constant::Constant(runtime::NDArray data, Optional struct_info_annot TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_FFI_REGISTER_GLOBAL("relax.Constant") - .set_body_typed([](runtime::NDArray data, - Optional struct_info_annotation = std::nullopt, - Span span = Span()) { - return Constant(data, struct_info_annotation, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.Constant", + [](runtime::NDArray data, Optional struct_info_annotation = std::nullopt, + Span span = Span()) { return Constant(data, struct_info_annotation, span); }); +}); PrimValue::PrimValue(PrimExpr value, Span span) { ObjectPtr n = make_object(); @@ -376,8 +392,10 @@ PrimValue PrimValue::Int64(int64_t value, Span span) { TVM_REGISTER_NODE_TYPE(PrimValueNode); -TVM_FFI_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { - return PrimValue(value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.PrimValue", + [](PrimExpr value, Span span) { return PrimValue(value, span); }); }); StringImm::StringImm(String value, Span span) { @@ -390,8 +408,10 @@ StringImm::StringImm(String value, Span span) { TVM_REGISTER_NODE_TYPE(StringImmNode); -TVM_FFI_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { - return StringImm(value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.StringImm", + [](String value, Span span) { return StringImm(value, span); }); }); DataTypeImm::DataTypeImm(DataType value, Span span) { @@ -404,8 +424,10 @@ DataTypeImm::DataTypeImm(DataType value, Span span) { TVM_REGISTER_NODE_TYPE(DataTypeImmNode); -TVM_FFI_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { - return DataTypeImm(value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.DataTypeImm", + [](DataType value, Span span) { return DataTypeImm(value, span); }); }); TVM_REGISTER_NODE_TYPE(MatchCastNode); @@ -420,10 +442,13 @@ MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.MatchCast") - .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { - return MatchCast(var, value, struct_info, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.MatchCast", + [](Var var, Expr value, StructInfo struct_info, Span span) { + return MatchCast(var, value, struct_info, span); + }); +}); bool MatchCastNode::SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { if (value->IsInstance()) { @@ -462,8 +487,11 @@ VarBinding::VarBinding(Var var, Expr value, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { - return VarBinding(var, value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.VarBinding", [](Var var, Expr value, Span span) { + return VarBinding(var, value, span); + }); }); bool VarBindingNode::SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { @@ -517,10 +545,12 @@ BindingBlockNode* BindingBlock::CopyOnWrite() 
{ return static_cast(data_.get()); } -TVM_FFI_REGISTER_GLOBAL("relax.BindingBlock") - .set_body_typed([](Array bindings, Span span) { - return BindingBlock(bindings, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.BindingBlock", [](Array bindings, Span span) { + return BindingBlock(bindings, span); + }); +}); TVM_REGISTER_NODE_TYPE(DataflowBlockNode); @@ -531,10 +561,12 @@ DataflowBlock::DataflowBlock(Array bindings, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.DataflowBlock") - .set_body_typed([](Array bindings, Span span) { - return DataflowBlock(bindings, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.DataflowBlock", [](Array bindings, Span span) { + return DataflowBlock(bindings, span); + }); +}); TVM_REGISTER_NODE_TYPE(SeqExprNode); @@ -554,10 +586,12 @@ SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.SeqExpr") - .set_body_typed([](Array blocks, Expr body, Span span) { - return SeqExpr(blocks, body, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.SeqExpr", [](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); +}); TVM_REGISTER_NODE_TYPE(FunctionNode); @@ -629,11 +663,14 @@ Function::Function(Array params, Expr body, Optional ret_struct data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.Function") - .set_body_typed([](Array params, Expr body, Optional ret_struct_info, - bool is_pure, DictAttrs attrs, Span span) { - return Function(params, body, ret_struct_info, is_pure, attrs, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.Function", + [](Array params, Expr body, Optional ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { + return Function(params, body, ret_struct_info, is_pure, attrs, span); + }); +}); Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { @@ -666,25 +703,32 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo return Function(std::move(n)); } -TVM_FFI_REGISTER_GLOBAL("relax.FunctionCreateEmpty") - .set_body_typed([](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, - Span span) { - return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.FunctionCreateEmpty", + [](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { + return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); + }); +}); // Special opaque derivation function for ExternFunc // Take look at sinfo_args to figure out the return StructInfo. 
-TVM_FFI_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") - .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo { - ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined"; - if (call->sinfo_args.empty()) { - return ObjectStructInfo(); - } else if (call->sinfo_args.size() == 1) { - return call->sinfo_args[0]; - } else { - return TupleStructInfo(call->sinfo_args); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tvm.relax.struct_info.infer_by_sinfo_args", + [](const Call& call, const BlockBuilder& ctx) -> StructInfo { + ICHECK(call->sinfo_args.defined()) + << "sinfo_args field of CallNode should always be defined"; + if (call->sinfo_args.empty()) { + return ObjectStructInfo(); + } else if (call->sinfo_args.size() == 1) { + return call->sinfo_args[0]; + } else { + return TupleStructInfo(call->sinfo_args); + } + }); +}); // Get the derive function. FuncStructInfo GetExternFuncStructInfo() { @@ -711,14 +755,17 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.ExternFunc") - .set_body_typed([](String global_symbol, Optional struct_info, Span span) { - if (struct_info.defined()) { - return ExternFunc(global_symbol, struct_info.value(), span); - } else { - return ExternFunc(global_symbol, span); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ExternFunc", + [](String global_symbol, Optional struct_info, Span span) { + if (struct_info.defined()) { + return ExternFunc(global_symbol, struct_info.value(), span); + } else { + return ExternFunc(global_symbol, span); + } + }); +}); Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. 
@@ -735,33 +782,31 @@ Expr GetShapeOf(const Expr& expr) { return call_shape_of; } -TVM_FFI_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { - return GetShapeOf(expr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.GetShapeOf", [](const Expr& expr) { return GetShapeOf(expr); }) + .def("relax.FuncWithAttr", + [](BaseFunc func, String key, ObjectRef value) -> Optional { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } + return std::nullopt; + }) + .def("relax.FuncWithAttrs", + [](BaseFunc func, Map attr_map) -> Optional { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } + return std::nullopt; + }) + .def("relax.FuncWithoutAttr", [](BaseFunc func, String key) -> Optional { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } + return std::nullopt; + }); }); -TVM_FFI_REGISTER_GLOBAL("relax.FuncWithAttr") - .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { - if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } - return std::nullopt; - }); - -TVM_FFI_REGISTER_GLOBAL("relax.FuncWithAttrs") - .set_body_typed([](BaseFunc func, Map attr_map) -> Optional { - if (func->IsInstance()) { - return WithAttrs(Downcast(std::move(func)), attr_map); - } - return std::nullopt; - }); - -TVM_FFI_REGISTER_GLOBAL("relax.FuncWithoutAttr") - .set_body_typed([](BaseFunc func, String key) -> Optional { - if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } - return std::nullopt; - }); - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 5e04453a1227..cdc9ab8d5082 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -24,6 +24,7 @@ * ExprMutator uses memoization and self return in order to amortize * the cost of using functional updates. 
*/ +#include #include #include #include @@ -326,10 +327,12 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_FFI_REGISTER_GLOBAL("relax.analysis.post_order_visit") - .set_body_typed([](Expr expr, ffi::Function f) { - PostOrderVisit(expr, [f](const Expr& n) { f(n); }); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.analysis.post_order_visit", [](Expr expr, ffi::Function f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); + }); +}); // ================== // ExprMutatorBase diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 8a607191796d..6b32b00d1bee 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -548,154 +548,115 @@ class PyExprMutator : public ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); }; -TVM_FFI_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") - .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") - .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { - visitor->VisitBinding(binding); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") - .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { - visitor->VisitBindingBlock(block); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") - .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") - .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { - visitor->ExprVisitor::VisitExpr(expr); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") - .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { - if (const auto* ptr = binding.as()) { - visitor->ExprVisitor::VisitBinding_(ptr); - } else if (const auto* ptr = binding.as()) { - visitor->ExprVisitor::VisitBinding_(ptr); - } else { - LOG(FATAL) << "unreachable"; - } - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") - .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { - if (const auto* ptr = block.as()) { - visitor->ExprVisitor::VisitBindingBlock_(ptr); - } else if (const auto* ptr = block.as()) { - visitor->ExprVisitor::VisitBindingBlock_(ptr); - } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") - .set_body_typed([](PyExprVisitor visitor, const Var& var) { - if (const auto* node = var.as()) { - visitor->ExprVisitor::VisitVarDef_(node); - } else if (const auto* node = var.as()) { - visitor->ExprVisitor::VisitVarDef_(node); - } else { - LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") - .set_body_typed([](PyExprVisitor visitor, const Span& span) { - visitor->ExprVisitor::VisitSpan(span); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") - .set_body_typed([](PyExprMutator mutator, const Expr& expr) { - return mutator->VisitExpr(expr); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") - 
.set_body_typed([](PyExprMutator mutator, const Binding& binding) { - mutator->VisitBinding(binding); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") - .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { - return mutator->VisitBindingBlock(block); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") - .set_body_typed([](PyExprMutator mutator, const Var& var) { - return mutator->VisitVarDef(var); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") - .set_body_typed([](PyExprMutator mutator, const Expr& expr) { - return mutator->ExprMutator::VisitExpr(expr); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") - .set_body_typed([](PyExprMutator mutator, const Binding& binding) { - if (const auto* ptr = binding.as()) { - return mutator->ExprMutator::VisitBinding_(ptr); - } else if (const auto* ptr = binding.as()) { - return mutator->ExprMutator::VisitBinding_(ptr); - } else { - LOG(FATAL) << "unreachable"; - } - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") - .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { - if (const auto* node = block.as()) { - return mutator->ExprMutator::VisitBindingBlock_(node); - } else if (const auto* node = block.as()) { - return mutator->ExprMutator::VisitBindingBlock_(node); - } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") - .set_body_typed([](PyExprMutator mutator, const Var& var) { - if (const auto* node = var.as()) { - return mutator->ExprMutator::VisitVarDef_(node); - } else if (const auto* node = var.as()) { - return mutator->ExprMutator::VisitVarDef_(node); - } else { - LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") - .set_body_typed([](PyExprMutator mutator, const Expr& expr) { - return mutator->VisitExprPostOrder(expr); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") - .set_body_typed([](PyExprMutator mutator, const Expr& expr) { - return mutator->VisitWithNewScope(expr); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") - .set_body_typed([](PyExprMutator mutator, const Var& var) { - return mutator->LookupBinding(var); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") - .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { - return mutator->WithStructInfo(var, sinfo); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") - .set_body_typed([](PyExprMutator mutator, Id id, Var var) { - return mutator->var_remap_[id] = var; - }); - -TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") - .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.MakePyExprVisitor", PyExprVisitor::MakePyExprVisitor) + .def("relax.PyExprVisitorVisitExpr", + [](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }) + .def("relax.PyExprVisitorVisitBinding", + [](PyExprVisitor visitor, const Binding& binding) { visitor->VisitBinding(binding); }) + .def("relax.PyExprVisitorVisitBindingBlock", + [](PyExprVisitor visitor, const BindingBlock& block) { + visitor->VisitBindingBlock(block); + }) + .def("relax.PyExprVisitorVisitVarDef", + [](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }) + 
.def("relax.ExprVisitorVisitExpr", + [](PyExprVisitor visitor, const Expr& expr) { visitor->ExprVisitor::VisitExpr(expr); }) + .def("relax.ExprVisitorVisitBinding", + [](PyExprVisitor visitor, const Binding& binding) { + if (const auto* ptr = binding.as()) { + visitor->ExprVisitor::VisitBinding_(ptr); + } else if (const auto* ptr = binding.as()) { + visitor->ExprVisitor::VisitBinding_(ptr); + } else { + LOG(FATAL) << "unreachable"; + } + }) + .def("relax.ExprVisitorVisitBindingBlock", + [](PyExprVisitor visitor, const BindingBlock& block) { + if (const auto* ptr = block.as()) { + visitor->ExprVisitor::VisitBindingBlock_(ptr); + } else if (const auto* ptr = block.as()) { + visitor->ExprVisitor::VisitBindingBlock_(ptr); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + }) + .def("relax.ExprVisitorVisitVarDef", + [](PyExprVisitor visitor, const Var& var) { + if (const auto* node = var.as()) { + visitor->ExprVisitor::VisitVarDef_(node); + } else if (const auto* node = var.as()) { + visitor->ExprVisitor::VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + }) + .def("relax.ExprVisitorVisitSpan", + [](PyExprVisitor visitor, const Span& span) { visitor->ExprVisitor::VisitSpan(span); }) + .def("relax.MakePyExprMutator", PyExprMutator::MakePyExprMutator) + .def("relax.PyExprMutatorVisitExpr", + [](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExpr(expr); }) + .def("relax.PyExprMutatorVisitBinding", + [](PyExprMutator mutator, const Binding& binding) { mutator->VisitBinding(binding); }) + .def("relax.PyExprMutatorVisitBindingBlock", + [](PyExprMutator mutator, const BindingBlock& block) { + return mutator->VisitBindingBlock(block); + }) + .def("relax.PyExprMutatorVisitVarDef", + [](PyExprMutator mutator, const Var& var) { return mutator->VisitVarDef(var); }) + .def("relax.ExprMutatorVisitExpr", + [](PyExprMutator mutator, const Expr& expr) { + return mutator->ExprMutator::VisitExpr(expr); + }) + .def("relax.ExprMutatorVisitBinding", + [](PyExprMutator mutator, const Binding& binding) { + if (const auto* ptr = binding.as()) { + return mutator->ExprMutator::VisitBinding_(ptr); + } else if (const auto* ptr = binding.as()) { + return mutator->ExprMutator::VisitBinding_(ptr); + } else { + LOG(FATAL) << "unreachable"; + } + }) + .def("relax.ExprMutatorVisitBindingBlock", + [](PyExprMutator mutator, const BindingBlock& block) { + if (const auto* node = block.as()) { + return mutator->ExprMutator::VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + return mutator->ExprMutator::VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + }) + .def("relax.ExprMutatorVisitVarDef", + [](PyExprMutator mutator, const Var& var) { + if (const auto* node = var.as()) { + return mutator->ExprMutator::VisitVarDef_(node); + } else if (const auto* node = var.as()) { + return mutator->ExprMutator::VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + }) + .def( + "relax.PyExprMutatorVisitExprPostOrder", + [](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExprPostOrder(expr); }) + .def("relax.PyExprMutatorVisitWithNewScope", + [](PyExprMutator mutator, const Expr& expr) { return mutator->VisitWithNewScope(expr); }) + .def("relax.PyExprMutatorLookupBinding", + [](PyExprMutator mutator, const Var& var) { return mutator->LookupBinding(var); }) + .def("relax.PyExprMutatorWithStructInfo", + 
[](PyExprMutator mutator, Var var, StructInfo sinfo) { + return mutator->WithStructInfo(var, sinfo); + }) + .def("relax.PyExprMutatorSetVarRemap", + [](PyExprMutator mutator, Id id, Var var) { return mutator->var_remap_[id] = var; }) + .def("relax.PyExprMutatorGetVarRemap", + [](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); +}); TVM_FFI_STATIC_INIT_BLOCK({ PyExprVisitorNode::RegisterReflection(); diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 8599fc52e16b..db94af28ab07 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -47,8 +47,9 @@ ObjectStructInfo::ObjectStructInfo(Span span) { TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { - return ObjectStructInfo(span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ObjectStructInfo", [](Span span) { return ObjectStructInfo(span); }); }); // Prim @@ -70,11 +71,14 @@ PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype") - .set_body_typed([](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }); - -TVM_FFI_REGISTER_GLOBAL("relax.PrimStructInfoFromValue") - .set_body_typed([](PrimExpr value, Span span) { return PrimStructInfo(value, span); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.PrimStructInfoFromDtype", + [](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }) + .def("relax.PrimStructInfoFromValue", + [](PrimExpr value, Span span) { return PrimStructInfo(value, span); }); +}); // Shape ShapeStructInfo::ShapeStructInfo(Array values, Span span) { @@ -102,15 +106,18 @@ ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.ShapeStructInfo") - .set_body_typed([](Optional> values, int ndim, Span span) { - if (values.defined()) { - CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; - return ShapeStructInfo(values.value(), span); - } else { - return ShapeStructInfo(ndim, span); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.ShapeStructInfo", [](Optional> values, int ndim, Span span) { + if (values.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; + return ShapeStructInfo(values.value(), span); + } else { + return ShapeStructInfo(ndim, span); + } + }); +}); // Tensor TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Optional vdevice, @@ -143,16 +150,18 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional v TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.TensorStructInfo") - .set_body_typed([](Optional shape, Optional dtype, int ndim, VDevice vdevice, - Span span) { - if (shape.defined()) { - CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; - return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); - } else { - return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.TensorStructInfo", [](Optional shape, Optional dtype, + int ndim, VDevice vdevice, Span span) { + if 
(shape.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); + } else { + return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span); + } + }); +}); // Tuple TupleStructInfo::TupleStructInfo(Array fields, Span span) { @@ -164,10 +173,12 @@ TupleStructInfo::TupleStructInfo(Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.TupleStructInfo") - .set_body_typed([](Array fields, Span span) { - return TupleStructInfo(fields, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.TupleStructInfo", [](Array fields, Span span) { + return TupleStructInfo(fields, span); + }); +}); // Func FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, bool purity, Span span) { @@ -199,21 +210,24 @@ FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); -TVM_FFI_REGISTER_GLOBAL("relax.FuncStructInfo") - .set_body_typed([](Array params, StructInfo ret, bool purity, Span span) { - return FuncStructInfo(params, ret, purity, span); - }); - -TVM_FFI_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") - .set_body_typed([](Optional ret, Optional derive_func, - bool purity, Span span) { - if (derive_func.defined()) { - ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; - return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); - } else { - return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.FuncStructInfo", + [](Array params, StructInfo ret, bool purity, Span span) { + return FuncStructInfo(params, ret, purity, span); + }) + .def("relax.FuncStructInfoOpaqueFunc", + [](Optional ret, Optional derive_func, bool purity, + Span span) { + if (derive_func.defined()) { + ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); + } + }); +}); // Helper functions void UpdateStructInfo(Expr expr, StructInfo struct_info) { @@ -226,11 +240,12 @@ void UpdateStructInfo(Expr expr, StructInfo struct_info) { expr->struct_info_ = struct_info; } -TVM_FFI_REGISTER_GLOBAL("relax.UpdateStructInfo") - .set_body_typed([](Expr expr, StructInfo struct_info) { UpdateStructInfo(expr, struct_info); }); - -TVM_FFI_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { - return GetStructInfo(expr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.UpdateStructInfo", + [](Expr expr, StructInfo struct_info) { UpdateStructInfo(expr, struct_info); }) + .def("ir.ExprStructInfo", [](Expr expr) { return GetStructInfo(expr); }); }); } // namespace relax diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index 73e724b8958f..d2f6612454f0 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -167,15 +167,18 @@ Pass CreateFunctionPass(std::function TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_FFI_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") - .set_body_typed( - [](ffi::TypedFunction, IRModule, PassContext)> pass_func, - PassInfo pass_info) { - 
auto wrapped_pass_func = [pass_func](Function func, IRModule mod, PassContext ctx) { - return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); - }; - return FunctionPass(wrapped_pass_func, pass_info); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.transform.MakeFunctionPass", + [](ffi::TypedFunction, IRModule, PassContext)> pass_func, + PassInfo pass_info) { + auto wrapped_pass_func = [pass_func](Function func, IRModule mod, PassContext ctx) { + return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); + }; + return FunctionPass(wrapped_pass_func, pass_info); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -390,16 +393,19 @@ Pass CreateDataflowBlockPass( TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); -TVM_FFI_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") - .set_body_typed( - [](ffi::TypedFunction, IRModule, PassContext)> - pass_func, - PassInfo pass_info) { - auto wrapped_pass_func = [pass_func](DataflowBlock func, IRModule mod, PassContext ctx) { - return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); - }; - return DataflowBlockPass(wrapped_pass_func, pass_info); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.transform.MakeDataflowBlockPass", + [](ffi::TypedFunction, IRModule, PassContext)> + pass_func, + PassInfo pass_info) { + auto wrapped_pass_func = [pass_func](DataflowBlock func, IRModule mod, PassContext ctx) { + return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); + }; + return DataflowBlockPass(wrapped_pass_func, pass_info); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 8a8aa460e80b..aa5be9294a05 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -44,8 +44,10 @@ ShapeType::ShapeType(int ndim, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { - return ShapeType(ndim, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ShapeType", + [](int ndim, Span span) { return ShapeType(ndim, span); }); }); ObjectType::ObjectType(Span span) { @@ -56,8 +58,9 @@ ObjectType::ObjectType(Span span) { TVM_REGISTER_NODE_TYPE(ObjectTypeNode); -TVM_FFI_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { - return ObjectType(span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ObjectType", [](Span span) { return ObjectType(span); }); }); TensorType::TensorType(int ndim, DataType dtype, Span span) { @@ -78,8 +81,11 @@ TensorType TensorType::CreateUnknownNDim(DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_FFI_REGISTER_GLOBAL("relax.TensorType").set_body_typed([](int ndim, DataType dtype, Span span) { - return TensorType(ndim, dtype, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.TensorType", [](int ndim, DataType dtype, Span span) { + return TensorType(ndim, dtype, span); + }); }); PackedFuncType::PackedFuncType(Span span) { @@ -90,8 +96,9 @@ PackedFuncType::PackedFuncType(Span span) { TVM_REGISTER_NODE_TYPE(PackedFuncTypeNode); -TVM_FFI_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { - return PackedFuncType(span); +TVM_FFI_STATIC_INIT_BLOCK({ 
+ namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.PackedFuncType", [](Span span) { return PackedFuncType(span); }); }); } // namespace relax diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index c73cf672abd1..e28dd7e3977e 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -19,6 +19,8 @@ #include "ccl.h" +#include + #include namespace tvm { @@ -42,7 +44,10 @@ Expr allreduce(Expr x, String op_type, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.allreduce").set_body_typed(allreduce); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.ccl.allreduce", allreduce); +}); StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -69,7 +74,10 @@ Expr allgather(Expr x, int num_workers, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.ccl.allgather", allgather); +}); StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -100,8 +108,10 @@ Expr broadcast_from_worker0(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0") - .set_body_typed(broadcast_from_worker0); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.ccl.broadcast_from_worker0", broadcast_from_worker0); +}); StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -127,7 +137,10 @@ Expr scatter_from_worker0(Expr data, int num_workers, int axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.ccl.scatter_from_worker0", scatter_from_worker0); +}); StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index 6c1af6ec9bb4..97e8a12a5b72 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -24,6 +24,7 @@ #include "distributed.h" +#include #include #include @@ -50,7 +51,10 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.dist.annotate_sharding").set_body_typed(annotate_sharding); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.dist.annotate_sharding", annotate_sharding); +}); StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -75,7 +79,10 @@ Expr redistribute(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.dist.redistribute").set_body_typed(redistribute); 
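// ---------------------------------------------------------------------------
// [Editor's note — illustrative sketch, not part of this patch.]
// The removed TVM_FFI_REGISTER_GLOBAL line above and the TVM_FFI_STATIC_INIT_BLOCK
// that follows it are one instance of the rewrite applied throughout this diff.
// The condensed before/after below uses a hypothetical "demo.add_one" global and
// an assumed header path; int64_t support in the typed signature is likewise an
// assumption made for brevity.
#include <tvm/ffi/reflection/registry.h>  // assumed location of refl::GlobalDef

// Before (style being removed):
//   TVM_FFI_REGISTER_GLOBAL("demo.add_one").set_body_typed([](int64_t x) { return x + 1; });
//
// After (style being added): registrations run from one static-init block per
// translation unit and go through the reflection-based GlobalDef builder.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("demo.add_one", [](int64_t x) { return x + 1; });
});
// [End of sketch.]
// ---------------------------------------------------------------------------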
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.dist.redistribute", redistribute); +}); StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -141,7 +148,10 @@ Expr MakeCallTIRLocalView(Expr func, Tuple args, return call; } -TVM_FFI_REGISTER_GLOBAL("relax.op.dist.call_tir_local_view").set_body_typed(MakeCallTIRLocalView); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.dist.call_tir_local_view", MakeCallTIRLocalView); +}); StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -210,8 +220,11 @@ Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) { return Call(op, {std::move(input)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard") - .set_body_typed(redistribute_replica_to_shard); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.dist.redistribute_replica_to_shard", + redistribute_replica_to_shard); +}); TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") .set_num_inputs(1) diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 846f169e4f95..b20543f6af21 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -24,6 +24,8 @@ #include "resize.h" +#include + #include namespace tvm { @@ -52,7 +54,10 @@ Expr resize2d(Expr data, Expr size, Array roi, String layout, String m return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.image.resize2d", resize2d); +}); StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1 && call->args.size() != 2) { diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index a7465db868fe..32badb28ebe9 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -24,6 +24,8 @@ #include "view.h" +#include + namespace tvm { namespace relax { @@ -40,7 +42,10 @@ Expr view(Expr x, Optional shape, Optional dtype, Optional rel }); } -TVM_FFI_REGISTER_GLOBAL("relax.op.memory.view").set_body_typed(view); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.memory.view", view); +}); StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 4) { @@ -289,8 +294,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } } -TVM_FFI_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo") - .set_body_typed(InferStructInfoView); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tvm.relax.struct_info.infer_view_sinfo", InferStructInfoView); +}); Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr data = call->args[0]; @@ -361,7 +368,10 @@ Expr ensure_zero_offset(const Expr& x) { return Call(op, {x}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.memory.ensure_zero_offset", ensure_zero_offset); +}); StructInfo 
InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index e6f410424b01..e1d2e1f7bd59 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -19,6 +19,8 @@ #include "attention.h" +#include + #include namespace tvm { @@ -56,8 +58,12 @@ Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr s {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.attention_var_len").set_body_typed(attention_var_len); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.op.nn.attention", attention) + .def("relax.op.nn.attention_var_len", attention_var_len); +}); StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index d86e6442fe42..d869eaf94c37 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -24,6 +24,8 @@ #include "convolution.h" +#include + #include namespace tvm { @@ -59,7 +61,10 @@ Expr conv1d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv1d"); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.conv1d", conv1d); +}); StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -222,7 +227,10 @@ Expr conv2d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv2d"); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.conv2d", conv2d); +}); StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -421,7 +429,10 @@ Expr conv3d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv3d"); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.conv3d", conv3d); +}); StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -601,7 +612,10 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -738,7 +752,10 @@ Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 64c849e547c2..279d6f08f80e 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -19,6 +19,8 @@ #include "nn.h" +#include + #include #include @@ -66,7 +68,10 @@ Expr leakyrelu(Expr data, double alpha) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.leakyrelu").set_body_typed(leakyrelu); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.leakyrelu", leakyrelu); +}); 
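// ---------------------------------------------------------------------------
// [Editor's note — illustrative sketch, not part of this patch.]
// Where several related globals live in the same file, the patch chains .def
// calls on a single GlobalDef() instead of opening one block per symbol (see
// the attention/attention_var_len and ones/ones_like registrations in this
// diff). The grouping below uses hypothetical names and an assumed header path.
#include <tvm/ffi/reflection/registry.h>  // assumed location of refl::GlobalDef

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  // One builder, several globals: each .def returns the builder, so adjacent
  // registrations chain instead of repeating a macro per symbol.
  refl::GlobalDef()
      .def("demo.square", [](int64_t x) { return x * x; })
      .def("demo.cube", [](int64_t x) { return x * x * x; });
});
// [End of sketch.]
// ---------------------------------------------------------------------------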
TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) @@ -87,7 +92,10 @@ Expr softplus(Expr data, double beta, double threshold) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.softplus", softplus); +}); TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) @@ -107,7 +115,10 @@ Expr prelu(Expr data, Expr alpha, int axis = 1) { return Call(op, {data, alpha}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.prelu", prelu); +}); TVM_REGISTER_OP("relax.nn.prelu") .set_num_inputs(2) @@ -128,7 +139,10 @@ Expr softmax(Expr data, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.softmax", softmax); +}); StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -186,7 +200,10 @@ Expr log_softmax(Expr data, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.log_softmax", log_softmax); +}); TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) @@ -207,7 +224,10 @@ Expr pad(Expr data, Array pad_width, String pad_mode, double pad_value) return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.pad", pad); +}); StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -250,7 +270,10 @@ Expr pixel_shuffle(Expr data, int upscale_factor) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.pixel_shuffle", pixel_shuffle); +}); StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -399,7 +422,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ std::move(moving_var)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.batch_norm", batch_norm); +}); StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -476,7 +502,10 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.layer_norm", layer_norm); +}); StructInfo InferStructInfoLayerNorm(const 
Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -544,7 +573,10 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.group_norm", group_norm); +}); StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); @@ -655,7 +687,10 @@ Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array(call->op); @@ -751,7 +786,10 @@ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon) { return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.rms_norm", rms_norm); +}); StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -810,7 +848,10 @@ Expr dropout(Expr data, double rate) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.dropout", dropout); +}); StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -878,8 +919,10 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels) { return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") - .set_body_typed(cross_entropy_with_logits); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.cross_entropy_with_logits", cross_entropy_with_logits); +}); TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") .set_num_inputs(2) @@ -912,7 +955,10 @@ Expr nll_loss(Expr predictions, Expr targets, Optional weights, String red } } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.nll_loss", nll_loss); +}); StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (call->args.size() < 2 || call->args.size() > 3) { diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 3d684b82bf42..61bda019a55b 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -19,6 +19,8 @@ #include "pooling.h" +#include + #include #include @@ -71,7 +73,10 @@ Expr max_pool1d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.max_pool1d", max_pool1d); +}); StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -184,7 +189,10 @@ Expr max_pool2d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } 
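// ---------------------------------------------------------------------------
// [Editor's note — illustrative sketch, not part of this patch.]
// Registrations whose Python-facing signature takes optional arguments keep a
// small adapter lambda in front of the underlying C++ constructor rather than
// registering it directly; the relax.TensorStructInfo and relax.ExternFunc
// entries in this diff follow that shape. The sketch below uses a made-up
// "demo.scale" global; the ffi::Optional header path is an assumption.
#include <tvm/ffi/optional.h>             // assumed header for tvm::ffi::Optional
#include <tvm/ffi/reflection/registry.h>  // assumed location of refl::GlobalDef

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  // The adapter resolves the optional argument before forwarding, mirroring
  // the value_or()/defined() handling kept by the adapters in this patch.
  refl::GlobalDef().def("demo.scale",
                        [](int64_t x, tvm::ffi::Optional<int64_t> factor) {
                          return x * factor.value_or(2);
                        });
});
// [End of sketch.]
// ---------------------------------------------------------------------------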
-TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.max_pool2d", max_pool2d); +}); StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -323,7 +331,10 @@ Expr max_pool3d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.max_pool3d", max_pool3d); +}); StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -410,7 +421,10 @@ Expr avg_pool1d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.avg_pool1d", avg_pool1d); +}); TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) @@ -429,7 +443,10 @@ Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.avg_pool2d", avg_pool2d); +}); TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) @@ -448,7 +465,10 @@ Expr avg_pool3d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.avg_pool3d", avg_pool3d); +}); TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) @@ -479,7 +499,10 @@ Expr adaptive_avg_pool1d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool1d").set_body_typed(adaptive_avg_pool1d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool1d", adaptive_avg_pool1d); +}); StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -562,7 +585,10 @@ Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool2d", adaptive_avg_pool2d); +}); StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -661,7 +687,10 @@ Expr adaptive_avg_pool3d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool3d").set_body_typed(adaptive_avg_pool3d); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool3d", 
adaptive_avg_pool3d); +}); StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 5fa7feb90e42..4c55f98b3646 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -127,7 +128,10 @@ Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs return Call(op, call_args, attrs, sinfo_args); } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_pure_packed", MakeCallPurePacked); +}); // call_inplace_packed @@ -246,7 +250,10 @@ Expr MakeCallInplacePacked(Expr func, Array args, Array inplace_i return Call(op, call_args, Attrs(attrs), sinfo_args); } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_inplace_packed", MakeCallInplacePacked); +}); // call_tir @@ -608,7 +615,10 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, return call; } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_tir", MakeCallTIR); +}); // call_tir_with_grad @@ -660,7 +670,10 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf return call; } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWithGrad); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_tir_with_grad", MakeCallTIRWithGrad); +}); // call_tir_inplace @@ -801,7 +814,10 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, return call; } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir_inplace").set_body_typed(MakeCallTIRInplace); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_tir_inplace", MakeCallTIRInplace); +}); // call_dps_packed @@ -842,7 +858,10 @@ Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_ return Call(op, {func, args}, {}, {out_sinfo}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_dps_packed", MakeCallDPSPacked); +}); // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { @@ -868,7 +887,10 @@ Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) return Call(op, {func, args}, Attrs(), sinfo_args); } -TVM_FFI_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_builtin_with_ctx", MakeCallBuiltinWithCtx); +}); TVM_REGISTER_OP("relax.null_value") .set_num_inputs(0) @@ -880,7 +902,10 @@ Expr MakeCallNullValue() { return Call(op, {}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.null_value", 
MakeCallNullValue); +}); // print @@ -903,7 +928,10 @@ Expr MakePrint(Array vals, StringImm format) { return Call(op, params); } -TVM_FFI_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.print", MakePrint); +}); // assert_op @@ -946,7 +974,10 @@ Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { return Call(op, args); } -TVM_FFI_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.assert_op", MakeAssertOp); +}); // make_closure @@ -962,7 +993,10 @@ Expr MakeClosure(Expr func, Tuple args) { return Call(op, {func, args}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.make_closure", MakeClosure); +}); // invoke_closure @@ -989,7 +1023,10 @@ Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { return Call(op, {closure, args}, {}, sinfo_args); } -TVM_FFI_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.invoke_closure", InvokeClosure); +}); // invoke_pure_closure @@ -1005,7 +1042,10 @@ Expr InvokePureClosure(Expr closure, Tuple args, Array sinfo_args) { return Call(op, {closure, args}, {}, sinfo_args); } -TVM_FFI_REGISTER_GLOBAL("relax.op.invoke_pure_closure").set_body_typed(InvokePureClosure); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.invoke_pure_closure", InvokePureClosure); +}); // shape_of @@ -1020,7 +1060,10 @@ Expr MakeShapeOf(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.shape_of", MakeShapeOf); +}); // tensor_to_shape @@ -1054,7 +1097,10 @@ Expr MakeTensorToShape(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.tensor_to_shape", MakeTensorToShape); +}); // shape_to_tensor StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { @@ -1078,7 +1124,10 @@ Expr MakeShapeToTensor(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.shape_to_tensor", MakeShapeToTensor); +}); // alloc_tensor @@ -1115,7 +1164,10 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind return Call(op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.builtin.alloc_tensor", MakeAllocTensor); +}); // memory planning alloc_storage @@ -1140,7 +1192,10 @@ Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm stora return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); } 
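// ---------------------------------------------------------------------------
// [Editor's note — illustrative sketch, not part of this patch.]
// Nothing in this diff changes how a registered global is consumed: callers
// still look the symbol up by name at runtime. The lookup below reuses the
// hypothetical "demo.add_one" global from the earlier sketch; the header path,
// the optional-like return of Function::GetGlobal, and the Any::cast call are
// assumptions about the FFI surface, not code taken from this patch.
#include <tvm/ffi/function.h>  // assumed header for tvm::ffi::Function

inline int64_t CallDemoAddOne(int64_t x) {
  // GetGlobal is assumed to hand back an optional-like handle that tests
  // false when the name was never registered.
  auto f = tvm::ffi::Function::GetGlobal("demo.add_one");
  if (!f) {
    return x;  // fall back when the hypothetical global is absent
  }
  // Invoking the handle is assumed to produce an ffi::Any convertible back
  // to the typed return value via cast<int64_t>().
  return (*f)(x).cast<int64_t>();
}
// [End of sketch.]
// ---------------------------------------------------------------------------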
-TVM_FFI_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.memory.alloc_storage", MakeAllocStorage); +}); // memory planning alloc_tensor @@ -1171,7 +1226,10 @@ Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.memory.alloc_tensor", MakeMemAllocTensor); +}); // memory planning kill_storage @@ -1187,7 +1245,10 @@ Expr MakeMemKillStorage(Expr storage) { return Call(op, {storage}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.memory.kill_storage", MakeMemKillStorage); +}); // memory planning kill_tensor @@ -1203,7 +1264,10 @@ Expr MakeMemKillTensor(Expr tensor) { return Call(op, {tensor}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.memory.kill_tensor", MakeMemKillTensor); +}); // vm alloc_storage @@ -1227,7 +1291,10 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d return Call(op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vm.alloc_storage", MakeVMAllocStorage); +}); // vm alloc_tensor @@ -1265,7 +1332,10 @@ Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm d return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vm.alloc_tensor", MakeVMAllocTensor); +}); // vm kill_object @@ -1281,7 +1351,10 @@ Expr MakeVMKillObject(Expr obj) { return Call(op, {std::move(obj)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.vm.kill_object").set_body_typed(MakeVMKillObject); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vm.kill_object", MakeVMKillObject); +}); // vm call_tir_dyn @@ -1299,7 +1372,10 @@ Expr MakeCallTIRDyn(Expr func, Tuple args) { return Call(op, {func, args}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vm.call_tir_dyn", MakeCallTIRDyn); +}); // builtin stop_lift_params StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { @@ -1317,7 +1393,10 @@ Expr MakeStopLiftParams(Expr x) { return Call(op, {x}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.builtin.stop_lift_params", MakeStopLiftParams); +}); // to_vdevice TVM_REGISTER_NODE_TYPE(ToVDeviceAttrs); @@ -1348,7 
+1427,10 @@ Expr MakeToVDevice(Expr data, VDevice dst_vdev) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.to_vdevice", MakeToVDevice); +}); // hint_on_device TVM_REGISTER_NODE_TYPE(HintOnDeviceAttrs); @@ -1375,7 +1457,10 @@ Expr MakeHintOnDevice(Expr data, Device device) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.hint_on_device", MakeHintOnDevice); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 94f927027197..3bd943e11f15 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -25,6 +25,7 @@ #include "create.h" #include +#include #include #include @@ -61,7 +62,10 @@ Expr full(Variant> shape, Expr fill_value, Optionalargs.size() != 2) { @@ -103,7 +107,10 @@ Expr full_like(Expr x, Expr fill_value, Optional dtype) { return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.full_like", full_like); +}); StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -181,8 +188,10 @@ Expr ones_like(Expr x, Optional dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); -TVM_FFI_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.ones", ones).def("relax.op.ones_like", ones_like); +}); TVM_REGISTER_OP("relax.ones") .set_attrs_type() @@ -216,8 +225,10 @@ Expr zeros_like(Expr x, Optional dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); -TVM_FFI_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.zeros", zeros).def("relax.op.zeros_like", zeros_like); +}); TVM_REGISTER_OP("relax.zeros") .set_attrs_type() @@ -249,8 +260,10 @@ Expr eye_like(Expr x, PrimValue k, Optional dtype) { return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); -TVM_FFI_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.eye", eye).def("relax.op.eye_like", eye_like); +}); StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -326,7 +339,10 @@ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.arange", arange); +}); StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) 
{ if (call->args.size() != 3) { @@ -380,7 +396,10 @@ Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.hamming_window").set_body_typed(hamming_window); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.hamming_window", hamming_window); +}); StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ctx) { DataType dtype = call->attrs.as()->dtype; @@ -438,8 +457,12 @@ Expr triu(Expr x, Expr k) { Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); } -TVM_FFI_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast(tril)); -TVM_FFI_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast(triu)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.op.tril", static_cast(tril)) + .def("relax.op.triu", static_cast(triu)); +}); StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { auto [data_sinfo, offset] = GetArgStructInfo(call, ctx); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 0be8dfce6604..7129fe58e97f 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -24,6 +24,8 @@ #include "datatype.h" +#include + #include namespace tvm { @@ -45,7 +47,10 @@ Expr astype(Expr x, DataType dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.astype", astype); +}); StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -75,7 +80,10 @@ Expr MakeWrapParam(Expr data, DataType dtype) { return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.wrap_param", MakeWrapParam); +}); StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 506a50c0e7a8..1343fd1a390b 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -24,6 +24,8 @@ #include "grad.h" +#include + #include namespace tvm { @@ -35,7 +37,10 @@ Expr no_grad(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.grad.no_grad").set_body_typed(no_grad); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.no_grad", no_grad); +}); StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -53,7 +58,10 @@ Expr start_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.grad.start_checkpoint").set_body_typed(start_checkpoint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.start_checkpoint", start_checkpoint); +}); StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -75,7 +83,10 @@ Expr end_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } 
-TVM_FFI_REGISTER_GLOBAL("relax.op.grad.end_checkpoint").set_body_typed(end_checkpoint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.end_checkpoint", end_checkpoint); +}); StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -110,7 +121,10 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optiona } } -TVM_FFI_REGISTER_GLOBAL("relax.op.grad.nll_loss_backward").set_body_typed(nll_loss_backward); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.nll_loss_backward", nll_loss_backward); +}); StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -144,7 +158,10 @@ Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.grad.max_pool2d_backward").set_body_typed(max_pool2d_backward); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.max_pool2d_backward", max_pool2d_backward); +}); StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -176,7 +193,10 @@ Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.grad.avg_pool2d_backward").set_body_typed(avg_pool2d_backward); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.avg_pool2d_backward", avg_pool2d_backward); +}); StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -201,7 +221,10 @@ Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axi return Call(op, {std::move(output_grad), std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.grad.take_backward").set_body_typed(take_backward); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.grad.take_backward", take_backward); +}); StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index bbee9a502dd1..535492984465 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -24,6 +24,7 @@ #include "index.h" +#include #include #include @@ -53,7 +54,10 @@ Expr take(Expr x, Expr indices, Optional axis, String mode) { return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.take", take); +}); StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); @@ -177,7 +181,10 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid return call; } -TVM_FFI_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.strided_slice", strided_slice); +}); /* \brief Helper function to unpack a 
relax::Tuple * @@ -485,7 +492,10 @@ Expr dynamic_strided_slice(Expr x, // return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.dynamic_strided_slice", dynamic_strided_slice); +}); StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index fd2d85a0533a..9ce22084345f 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -24,6 +24,7 @@ #include "linear_algebra.h" +#include #include #include @@ -49,7 +50,10 @@ Expr matmul(Expr x1, Expr x2, Optional out_dtype) { return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.matmul", matmul); +}); StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -181,7 +185,10 @@ Expr einsum(Expr operands, String subscripts) { return Call(op, {std::move(operands)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.einsum").set_body_typed(einsum); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.einsum", einsum); +}); StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -263,7 +270,10 @@ Expr outer(Expr x1, Expr x2) { return Call(op, {std::move(x1), std::move(x2)}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.outer", outer); +}); StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { auto input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 3eb29a82e3d1..fd1fcc67f7eb 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -24,6 +24,8 @@ #include "manipulate.h" +#include + #include #include #include @@ -62,7 +64,10 @@ Expr broadcast_to(Expr x, Expr shape) { return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.broadcast_to", broadcast_to); +}); StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -145,7 +150,10 @@ Expr concat(Expr tensors, Optional axis) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.concat", concat); +}); Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -361,7 +369,10 @@ Expr expand_dims(Expr x, Array axis) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.expand_dims", expand_dims); +}); StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -467,7 +478,10 @@ Expr flatten(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.flatten", flatten); +}); StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -502,7 +516,10 @@ Expr index_tensor(Expr first, Expr tensors) { return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.index_tensor", index_tensor); +}); StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -656,7 +673,10 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_v return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.layout_transform", layout_transform); +}); StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -723,7 +743,10 @@ Expr permute_dims(Expr x, Optional> axes) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.permute_dims", permute_dims); +}); bool IsIdentityPermutation(const std::vector& permutation) { for (int i = 0; i < static_cast(permutation.size()); ++i) { @@ -931,7 +954,10 @@ Expr reshape(Expr x, Variant> shape) { return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.reshape", reshape); +}); StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -1018,7 +1044,10 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.split", split); +}); StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1171,7 +1200,10 @@ Expr squeeze(Expr x, Optional> axis) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.squeeze", squeeze); +}); StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = 
GetUnaryInputTensorStructInfo(call, ctx); @@ -1371,7 +1403,10 @@ Expr stack(Expr tensors, Optional axis) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.stack", stack); +}); Optional> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -1575,7 +1610,10 @@ Expr collapse_sum_like(Expr data, Expr collapse_target) { return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.collapse_sum_like", collapse_sum_like); +}); StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -1621,7 +1659,10 @@ Expr collapse_sum_to(Expr data, Expr shape) { return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.collapse_sum_to", collapse_sum_to); +}); StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -1676,7 +1717,10 @@ Expr repeat(Expr data, int repeats, Optional axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.repeat", repeat); +}); StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1741,7 +1785,10 @@ Expr tile(Expr data, Array repeats) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.tile", tile); +}); StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1804,7 +1851,10 @@ Expr flip(Expr data, Integer axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.flip").set_body_typed(flip); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.flip", flip); +}); StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -1841,7 +1891,10 @@ Expr gather_elements(Expr data, Expr indices, int axis) { return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.gather_elements", gather_elements); +}); StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -1910,7 +1963,10 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims) { return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.gather_nd", gather_nd); +}); StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -2004,7 +2060,10 @@ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { return Call(op, {data, indices, values}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.index_put", index_put); +}); StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -2127,7 +2186,10 @@ Expr meshgrid(Expr tensors, Optional indexing) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.meshgrid", meshgrid); +}); StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -2231,7 +2293,10 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.scatter_elements").set_body_typed(scatter_elements); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.scatter_elements", scatter_elements); +}); StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -2345,7 +2410,10 @@ Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.scatter_nd", scatter_nd); +}); StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] @@ -2479,7 +2547,10 @@ Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue en return Call(op, {input, src, start, end, step}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.slice_scatter").set_body_typed(slice_scatter); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.slice_scatter", slice_scatter); +}); StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -2643,7 +2714,10 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } // namespace relax -TVM_FFI_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.one_hot", one_hot); +}); StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 2d6006edcf75..baf622f26969 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -24,6 +24,8 @@ #include "qdq.h" +#include + #include #include "../../transform/utils.h" @@ -46,7 
+48,10 @@ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dty return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_FFI_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.quantize", quantize); +}); StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -129,7 +134,10 @@ Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_d return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_FFI_REGISTER_GLOBAL("relax.op.dequantize").set_body_typed(dequantize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.dequantize", dequantize); +}); StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 00554da115c9..b5016c75b4ab 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -24,6 +24,7 @@ #include "sampling.h" +#include #include #include @@ -45,8 +46,10 @@ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indice Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.multinomial_from_uniform") - .set_body_typed(multinomial_from_uniform); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.multinomial_from_uniform", multinomial_from_uniform); +}); StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 3e0236fc28e5..60b9a4ce9d8b 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -24,6 +24,8 @@ #include "search.h" +#include + #include #include @@ -46,7 +48,10 @@ Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.bucketize").set_body_typed(bucketize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.bucketize", bucketize); +}); StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -89,7 +94,10 @@ Expr where(Expr condition, Expr x1, Expr x2) { return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.where", where); +}); StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index e321b326d24e..0268a5bfc768 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -24,6 +24,8 @@ #include "set.h" +#include + #include #include #include @@ -46,7 +48,10 @@ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_i return call; } -TVM_FFI_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.unique", unique); +}); StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); @@ -144,7 +149,10 @@ Expr nonzero(Expr x) { return Call(op, {std::move(x)}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nonzero", nonzero); +}); StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 73a82090acc7..5f210e2d92ce 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -24,6 +24,8 @@ #include "sorting.h" +#include + #include namespace tvm { @@ -47,7 +49,10 @@ Expr sort(Expr data, int axis, bool descending) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.sort").set_body_typed(sort); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.sort", sort); +}); StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { return GetUnaryInputTensorStructInfo(call, ctx); @@ -73,7 +78,10 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.argsort").set_body_typed(argsort); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.argsort", argsort); +}); StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -107,7 +115,10 @@ Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dt return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.topk").set_body_typed(topk); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.topk", topk); +}); StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 7efb25fb0f48..56920c896acd 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -24,6 +24,8 @@ #include "statistical.h" +#include + #include #include @@ -191,7 +193,10 @@ Expr cumprod(Expr data, Optional axis, Optional dtype, Bool e return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.cumprod").set_body_typed(cumprod); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.cumprod", cumprod); +}); TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() @@ -211,7 +216,10 @@ Expr cumsum(Expr data, Optional axis, Optional dtype, Bool ex return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.cumsum", cumsum); +}); TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index 91a6e8d0ae04..e57638c41a08 100644 --- a/src/relax/op/tensor/ternary.cc +++ 
b/src/relax/op/tensor/ternary.cc @@ -24,6 +24,8 @@ #include "ternary.h" +#include + namespace tvm { namespace relax { @@ -143,7 +145,10 @@ Expr ewise_fma(Expr x1, Expr x2, Expr x3) { return Call(op, {x1, x2, x3}, Attrs(), {}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.ewise_fma", ewise_fma); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 828a91dde21d..2482032f5e26 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -24,6 +24,8 @@ #include "unary.h" +#include + #include namespace tvm { @@ -85,7 +87,10 @@ Expr clip(Expr x, Expr min, Expr max) { return Call(op, {std::move(x), std::move(min), std::move(max)}); } -TVM_FFI_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.clip", clip); +}); /***************** Check operators *****************/ diff --git a/src/relax/testing/transform.cc b/src/relax/testing/transform.cc index c4e41d5afc1f..ea1e4dc1aa36 100644 --- a/src/relax/testing/transform.cc +++ b/src/relax/testing/transform.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include @@ -35,8 +36,10 @@ tvm::transform::Pass ApplyEmptyCppMutator() { "relax.testing.ApplyEmptyCppMutator", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator") - .set_body_typed(ApplyEmptyCppMutator); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.testing.transform.ApplyEmptyCppMutator", ApplyEmptyCppMutator); +}); } // namespace testing } // namespace relax diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index cb44339f1969..d72e2ee14044 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -24,6 +24,7 @@ #include "utils.h" +#include #include #include #include @@ -215,7 +216,10 @@ Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outpu /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.training.AppendLoss", AppendLoss); +}); } // namespace transform diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 46dc803018ea..d376fa8a751c 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -22,6 +22,7 @@ * \brief Re-order `matmul(matmul(A,B), x)` to `matmul(A, matmul(B,x))` */ +#include #include #include #include @@ -213,7 +214,10 @@ Pass AdjustMatmulOrder() { return CreateFunctionPass(pass_func, 1, "AdjustMatmulOrder", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.AdjustMatmulOrder").set_body_typed(AdjustMatmulOrder); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.AdjustMatmulOrder", AdjustMatmulOrder); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 763a009a24b2..084eec69c1c1 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -23,6 +23,7 @@ * satisfy their temporary storage requirement. 
*/ +#include #include #include #include @@ -201,7 +202,10 @@ Pass AllocateWorkspace() { return CreateModulePass(pass_func, 0, "AllocateWorkspace", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.AllocateWorkspace", AllocateWorkspace); +}); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 63521e4e8fe1..5b8067d37e30 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -24,6 +24,7 @@ * true. */ #include +#include #include #include #include @@ -438,7 +439,10 @@ Pass AlterOpImpl(const Map& op_impl_map, /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.AlterOpImpl", AlterOpImpl); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc index e2b0fc2c2877..e623643b5ea1 100644 --- a/src/relax/transform/annotate_tir_op_pattern.cc +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -22,6 +22,7 @@ * \brief Annotate Op Pattern for TIR functions. It is a pass works on TIR PrimFuncs, * but they are needed for relax fusion. So we put them in the relax namespace. */ +#include #include #include #include @@ -47,8 +48,10 @@ Pass AnnotateTIROpPattern() { return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern") - .set_body_typed(AnnotateTIROpPattern); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.AnnotateTIROpPattern", AnnotateTIROpPattern); +}); } // namespace transform diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index cef74890806d..1bab46f0cdc4 100644 --- a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -21,6 +21,7 @@ * \brief Attach layout_free_buffers for layout-free buffers. */ +#include #include #include #include @@ -105,8 +106,10 @@ Pass AttachAttrLayoutFreeBuffers() { return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") - .set_body_typed(AttachAttrLayoutFreeBuffers); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.AttachAttrLayoutFreeBuffers", AttachAttrLayoutFreeBuffers); +}); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 905d2bcd838d..73ae2c36d2cf 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -21,6 +21,7 @@ * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. 
*/ +#include #include #include #include @@ -79,7 +80,10 @@ Pass AttachGlobalSymbol() { return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.AttachGlobalSymbol", AttachGlobalSymbol); +}); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 2a5c6f525d50..582e500dabf7 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -196,7 +197,10 @@ IRModule BindParam(IRModule m, String func_name, Map bind_ return GetRef(new_module); } -TVM_FFI_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.FunctionBindParams", FunctionBindParams); +}); namespace transform { @@ -207,7 +211,10 @@ Pass BindParams(String func_name, Map params) { return CreateModulePass(pass_func, 0, "BindParams", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.BindParams", BindParams); +}); } // namespace transform diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 49af21c10755..9bac3114d4f3 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -148,7 +149,10 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_m } } // namespace -TVM_FFI_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.FunctionBindSymbolicVars", FunctionBindSymbolicVars); +}); namespace transform { @@ -170,7 +174,10 @@ Pass BindSymbolicVars(Map binding_map, Optional fun return tvm::transform::CreateModulePass(pass_func, 1, "relax.BindSymbolicVars", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.BindSymbolicVars", BindSymbolicVars); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index 982e1ac0c323..2d2a493deb4f 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -22,6 +22,7 @@ * \brief Lift local functions into global functions. 
*/ +#include #include #include #include @@ -115,7 +116,10 @@ Pass BundleModelParams(Optional param_tuple_name) { return CreateModulePass(pass_func, 1, "BundleModelParams", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.BundleModelParams", BundleModelParams); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 25b4abadc7ff..b05229fa560a 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -21,6 +21,7 @@ * \brief Perform explicit tensor allocation for call_tir, * call_tir_inplace, and call_dps_packed. */ +#include #include #include #include @@ -183,7 +184,10 @@ Pass CallTIRRewrite() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CallTIRRewrite", CallTIRRewrite); +}); } // namespace transform diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index ecbb9e77518e..9238c5eeaeb6 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -24,6 +24,7 @@ * Ideally should be used before constant folding and eliminating unused bindings. */ +#include #include #include #include @@ -591,8 +592,10 @@ Pass CanonicalizeBindings() { "CanonicalizeBindings"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings") - .set_body_typed(CanonicalizeBindings); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CanonicalizeBindings", CanonicalizeBindings); +}); } // namespace transform diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 620186320342..ef827d1dd510 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -387,8 +388,10 @@ Pass CombineParallelMatmul(FCheck check) { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul") - .set_body_typed(CombineParallelMatmul); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CombineParallelMatmul", CombineParallelMatmul); +}); } // namespace transform diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index e6db2eb73f3a..9cffabc31d78 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -17,6 +17,7 @@ * under the License. 
*/ +#include #include #include #include @@ -86,7 +87,10 @@ Pass ComputePrimValue() { return CreateModulePass(pass_func, 0, "ComputePrimValue", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ComputePrimValue", ComputePrimValue); +}); } // namespace transform diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index c359afdebc28..0afef159d9c4 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -23,6 +23,7 @@ * dataflow into dataflow blocks. */ +#include #include #include #include @@ -159,7 +160,10 @@ Pass ConvertToDataflow(int min_size) { return tvm::transform::Sequential({pass, CanonicalizeBindings()}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ConvertToDataflow").set_body_typed(ConvertToDataflow); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ConvertToDataflow", ConvertToDataflow); +}); } // namespace transform diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 0cdaa0b192c7..57329e32a895 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -21,6 +21,7 @@ * \brief Automatic layout conversion pass, especially for axis swapping. */ +#include #include #include #include @@ -350,7 +351,10 @@ Pass ConvertLayout(Map> desired_layouts) { return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ConvertLayout", ConvertLayout); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index d23cf47f75eb..3a5e490f5aa7 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -1021,23 +1021,26 @@ Array> DataflowInplaceAnalysis(const DataflowBlock& bl } // these are exposed only for testing -TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") - .set_body_typed(DataflowLivenessAnalysis); -TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") - .set_body_typed(DataflowAliasAnalysis); -TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") - .set_body_typed(DataflowInplaceAnalysis); -TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") - .set_body_typed([](const IRModule& mod, const Call& call, - const Array& inplace_indices) -> Array { - ModuleInplaceTransformer transformer(mod); - auto ret_call = transformer.CreateInplaceCall(call, inplace_indices); - return Array{ret_call, transformer.CurrentMod()}; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.testing.transform.DataflowLivenessAnalysis", DataflowLivenessAnalysis) + .def("relax.testing.transform.DataflowAliasAnalysis", DataflowAliasAnalysis) + .def("relax.testing.transform.DataflowInplaceAnalysis", DataflowInplaceAnalysis) + .def("relax.testing.transform.SingleInplaceCall", + [](const IRModule& mod, const Call& call, + const Array& inplace_indices) -> Array { + ModuleInplaceTransformer transformer(mod); + auto ret_call = transformer.CreateInplaceCall(call, 
inplace_indices); + return Array{ret_call, transformer.CurrentMod()}; + }); +}); // actually exposed -TVM_FFI_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") - .set_body_typed(DataflowUseInplaceCalls); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.DataflowUseInplaceCalls", DataflowUseInplaceCalls); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 7de1da329f88..129b61eb5483 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -32,6 +32,7 @@ * Any binding blocks that are left empty will be removed by the normalizer. */ +#include #include #include #include @@ -140,7 +141,10 @@ Pass DeadCodeElimination(Array entry_functions) { return CreateModulePass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.DeadCodeElimination", DeadCodeElimination); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index eec27f3b7888..589e8f40c572 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -19,6 +19,7 @@ /*! \file src/relax/transform/decompose_ops.cc */ +#include #include #include #include @@ -250,11 +251,12 @@ Pass DecomposeOpsForTraining(Optional func_name) { } } -TVM_FFI_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference") - .set_body_typed(DecomposeOpsForInference); - -TVM_FFI_REGISTER_GLOBAL("relax.transform.DecomposeOpsForTraining") - .set_body_typed(DecomposeOpsForTraining); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.transform.DecomposeOpsForInference", DecomposeOpsForInference) + .def("relax.transform.DecomposeOpsForTraining", DecomposeOpsForTraining); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 8a5ce1db04de..790d6c30ef16 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -24,6 +24,7 @@ * * Currently it removes common subexpressions within a Function. 
*/ +#include #include #include #include @@ -221,8 +222,10 @@ Pass EliminateCommonSubexpr(bool call_only) { return CreateFunctionPass(pass_func, 1, "EliminateCommonSubexpr", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") - .set_body_typed(EliminateCommonSubexpr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.EliminateCommonSubexpr", EliminateCommonSubexpr); +}); } // namespace transform diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index d7bf2dd95ffb..518cd2572450 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -22,6 +22,7 @@ * \brief Expand `matmul(x, A+B)` to `matmul(x, A) + matmul(x,B)` */ +#include #include #include #include @@ -104,7 +105,10 @@ Pass ExpandMatmulOfSum() { return CreateFunctionPass(pass_func, 1, "ExpandMatmulOfSum", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ExpandMatmulOfSum", ExpandMatmulOfSum); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index ec5818476a57..68921c22e322 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include @@ -178,8 +179,10 @@ Pass ExpandTupleArguments() { "ExpandTupleArguments"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments") - .set_body_typed(ExpandTupleArguments); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ExpandTupleArguments", ExpandTupleArguments); +}); } // namespace transform diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 4ccf6c25abc8..a2ba60fc6529 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include "../../meta_schedule/utils.h" @@ -172,7 +173,10 @@ Pass FewShotTuning(int valid_count, bool benchmark) { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.FewShotTuning").set_body_typed(FewShotTuning); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FewShotTuning", FewShotTuning); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 8f8cb0b18cb5..7c5b5812b897 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -17,6 +17,7 @@ * under the License. 
*/ +#include #include #include #include @@ -327,7 +328,10 @@ Pass FoldConstant() { return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant); +}); } // namespace transform diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index e21c8a30a0e9..42fb14aededf 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1401,11 +1401,15 @@ FusionPattern::FusionPattern(String name, DFPattern pattern, } TVM_REGISTER_NODE_TYPE(FusionPatternNode); -TVM_FFI_REGISTER_GLOBAL("relax.transform.FusionPattern") - .set_body_typed([](String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter) { - return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.transform.FusionPattern", + [](String name, DFPattern pattern, Map annotation_patterns, + Optional check, Optional attrs_getter) { + return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); + }); +}); PatternCheckContext::PatternCheckContext(Expr matched_expr, Map annotated_expr, Map matched_bindings, @@ -1435,7 +1439,10 @@ Pass FuseOps(int fuse_opt_level) { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FuseOps", FuseOps); +}); Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, bool annotate_codegen, const Array& entry_function_names) { @@ -1450,7 +1457,10 @@ Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_const /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FuseOpsByPattern", FuseOpsByPattern); +}); } // namespace transform diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index a774d24a6359..5de5579ec045 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -1267,7 +1268,10 @@ Pass FuseTIR() { "FuseTIR"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FuseTIR", FuseTIR); +}); } // namespace transform diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 9998b6da93f3..e42302b18776 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -25,6 +25,7 @@ * with respect to the only return value of the function, which needs to be scalar. 
*/ +#include #include #include #include @@ -787,7 +788,10 @@ Pass Gradient(String func_name, Optional> require_grads, int target_i /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.Gradient", Gradient); +}); } // namespace transform diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index c0d69ee810f0..e69b24cee16d 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -164,7 +165,10 @@ Function FunctionInlineFunctions(Function func, return Downcast(mutator(std::move(func))); } -TVM_FFI_REGISTER_GLOBAL("relax.FunctionInlineFunctions").set_body_typed(FunctionInlineFunctions); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.FunctionInlineFunctions", FunctionInlineFunctions); +}); namespace transform { @@ -219,8 +223,10 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "InlinePrivateFunctions", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.InlinePrivateFunctions") - .set_body_typed(InlinePrivateFunctions); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.InlinePrivateFunctions", InlinePrivateFunctions); +}); } // namespace transform diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 730f65f701ba..50cfc970ab34 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -21,6 +21,7 @@ * \brief Kill storage/tensor objects after last use, if not already killed */ #include +#include #include #include #include @@ -265,7 +266,10 @@ Pass KillAfterLastUse() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "KillAfterLastUse", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.KillAfterLastUse", KillAfterLastUse); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index fb0a6a6ee16a..fdd5cbe8acc7 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -22,6 +22,7 @@ * \brief Lift local functions into global functions. */ +#include #include #include #include @@ -495,7 +496,10 @@ Pass LambdaLift() { return tvm::transform::CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LambdaLift", LambdaLift); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 32f63e1e141b..9ec990c8aad7 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -19,6 +19,7 @@ /*! 
\file src/relax/transform/lazy_transform_params.cc */ +#include #include #include #include @@ -259,7 +260,10 @@ Pass LazyGetInput() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LazyGetInput", LazyGetInput); +}); Pass LazySetOutput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { @@ -274,7 +278,10 @@ Pass LazySetOutput() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LazySetOutput", LazySetOutput); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index c2a1ab47146d..57df1f5d3dfe 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -23,6 +23,7 @@ * with corresponding low-level TIR PrimFuncs. */ +#include #include #include #include @@ -404,7 +405,10 @@ Pass LegalizeOps(Optional> cmap, bool enable_warning) /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LegalizeOps", LegalizeOps); +}); } // namespace transform diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 9013737df5e4..5c4c88c6a361 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -22,6 +22,7 @@ * \brief Lift local functions into global functions. */ +#include #include #include #include @@ -867,7 +868,10 @@ Pass LiftTransformParams(Variant> shared_transform) { "LiftTransformParams"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LiftTransformParams", LiftTransformParams); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 3bdbfd0b94a9..dda7828e4103 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -20,6 +20,7 @@ * \file src/relax/transform/lower_alloc_tensor.cc * \brief Lower any relax.builtin.alloc_tensor remaining after static planning */ +#include #include #include @@ -99,7 +100,10 @@ Pass LowerAllocTensor() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "LowerAllocTensor", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.LowerAllocTensor", LowerAllocTensor); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index ffeddd08c401..a12699b672fd 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -54,6 +54,7 @@ * is important since the dependency relation is transitive. 
*/ +#include #include #include #include @@ -421,8 +422,10 @@ Pass MergeCompositeFunctions() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") - .set_body_typed(MergeCompositeFunctions); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.MergeCompositeFunctions", MergeCompositeFunctions); +}); } // namespace transform diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 08e5a100ab22..d6654b33a76e 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -21,6 +21,7 @@ * \file tvm/relax/transform/meta_schedule.cc * \brief Pass for meta_schedule tuning */ +#include #include #include #include @@ -175,11 +176,13 @@ Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { /*traceable*/ true); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") - .set_body_typed(MetaScheduleApplyDatabase); -TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod") - .set_body_typed(MetaScheduleTuneIRMod); -TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.transform.MetaScheduleApplyDatabase", MetaScheduleApplyDatabase) + .def("relax.transform.MetaScheduleTuneIRMod", MetaScheduleTuneIRMod) + .def("relax.transform.MetaScheduleTuneTIR", MetaScheduleTuneTIR); +}); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index d997ea040d60..36573ac8e2eb 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -24,6 +24,7 @@ * available. */ +#include #include #include #include @@ -279,7 +280,10 @@ Pass Normalize() { return CreateFunctionPass(pass_func, 1, "Normalize", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.Normalize", Normalize); +}); Pass NormalizeGlobalVar() { auto pass_func = [=](IRModule mod, PassContext pc) { @@ -290,7 +294,10 @@ Pass NormalizeGlobalVar() { /*pass_name=*/"NormalizeGlobalVar", /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.NormalizeGlobalVar").set_body_typed(NormalizeGlobalVar); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.NormalizeGlobalVar", NormalizeGlobalVar); +}); } // namespace transform diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index ee4773fb3a24..d557920a29f6 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -21,6 +21,7 @@ * \file tvm/relax/transform/realize_vdevice.cc * \brief Propagate virtual device information. 
*/ +#include #include #include #include @@ -415,7 +416,10 @@ Pass RealizeVDevice() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RealizeVDevice").set_body_typed(RealizeVDevice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RealizeVDevice", RealizeVDevice); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index 001ed6d339af..f6a990ec1038 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -20,6 +20,7 @@ * \file src/relax/transform/remove_purity_checking.cc * \brief Apply kForcePure in all pure functions and unwrap all calls to pure overrides */ +#include #include #include #include @@ -88,8 +89,10 @@ Pass RemovePurityChecking() { return CreateFunctionPass(pass_func, 0, "RemovePurityChecking", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RemovePurityChecking") - .set_body_typed(RemovePurityChecking); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RemovePurityChecking", RemovePurityChecking); +}); } // namespace transform diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index ea8a8fa14f29..ad1760fbbbf4 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -336,7 +337,10 @@ Pass RemoveUnusedOutputs() { "RemoveUnusedOutputs"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RemoveUnusedOutputs").set_body_typed(RemoveUnusedOutputs); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RemoveUnusedOutputs", RemoveUnusedOutputs); +}); } // namespace transform diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 5018232668b9..985ce4af22c9 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -17,6 +17,7 @@ * under the License. 
*/ +#include #include #include #include @@ -250,8 +251,10 @@ Pass RemoveUnusedParameters() { return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters") - .set_body_typed(RemoveUnusedParameters); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RemoveUnusedParameters", RemoveUnusedParameters); +}); } // namespace transform diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index 2016c6766c08..df4ba45d5f4b 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -22,6 +22,7 @@ * \brief Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B)) */ +#include #include #include #include @@ -173,8 +174,11 @@ Pass ReorderPermuteDimsAfterConcat() { return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat") - .set_body_typed(ReorderPermuteDimsAfterConcat); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ReorderPermuteDimsAfterConcat", + ReorderPermuteDimsAfterConcat); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index 4c87cbe8b7e3..6793e1830688 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -22,6 +22,7 @@ * \brief Expand `matmul(x, A+B)` to `matmul(x, A) + matmul(x,B)` */ +#include #include #include #include @@ -156,8 +157,10 @@ Pass ReorderTakeAfterMatmul() { return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul") - .set_body_typed(ReorderTakeAfterMatmul); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ReorderTakeAfterMatmul", ReorderTakeAfterMatmul); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 14e98ecad152..c26bf887e4d6 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -49,6 +49,7 @@ * 2. Lift the regions identified in step 1 to a separate function and rewrite the original function * with `CUDAGraphRewriter`. 
*/ +#include #include #include #include @@ -897,7 +898,10 @@ Pass RewriteCUDAGraph() { return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RewriteCUDAGraph", RewriteCUDAGraph); +}); } // namespace transform diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index a13c23387821..b491b96c577f 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -21,6 +21,7 @@ * \brief Transform all reshape within dataflow block to a relax.reshape operator */ #include +#include #include #include #include @@ -165,8 +166,10 @@ Pass RewriteDataflowReshape() { return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") - .set_body_typed(RewriteDataflowReshape); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RewriteDataflowReshape", RewriteDataflowReshape); +}); } // namespace transform diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 33d3f485a5e0..0dadbebbffde 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -23,6 +23,7 @@ * \brief Run codegen for annotated relax functions. */ +#include #include #include #include @@ -220,7 +221,10 @@ Pass RunCodegen(Optional>> target_options, return CreateModulePass(pass_func, 0, "RunCodegen", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.RunCodegen", RunCodegen); +}); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 276ba448cc4b..bb4142900ebf 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -21,6 +21,7 @@ * \brief Transform all dataflow structure to non-dataflow version. */ #include +#include #include #include #include @@ -774,8 +775,10 @@ Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { /*pass_name=*/"SplitCallTIRByPattern", // /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern") - .set_body_typed(SplitCallTIRByPattern); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SplitCallTIRByPattern", SplitCallTIRByPattern); +}); } // namespace transform diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 7990beb04b2e..a2f67bc65dd5 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -21,6 +21,7 @@ * \file src/relax/transform/split_tir_layout_rewrite.cc * \brief Use for rewriting the TIRs after meta_schedule layout rewrite post process. 
*/ +#include #include #include #include @@ -340,7 +341,9 @@ Pass SplitLayoutRewritePreproc() { return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, "SplitLayoutRewritePreproc"); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") - .set_body_typed(SplitLayoutRewritePreproc); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SplitLayoutRewritePreproc", SplitLayoutRewritePreproc); +}); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 7521b21d9418..58e3df3b7e3a 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -66,6 +66,7 @@ * during memory planning. */ #include +#include #include #include #include @@ -983,8 +984,10 @@ Pass StaticPlanBlockMemory() { return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory") - .set_body_typed(StaticPlanBlockMemory); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.StaticPlanBlockMemory", StaticPlanBlockMemory); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index d14b6ab14258..b22654181c06 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -21,6 +21,7 @@ * \brief Automatic mixed precision pass. */ +#include #include #include #include @@ -618,7 +619,10 @@ Pass ToMixedPrecision(const DataType& out_dtype, Optional> fp16_in return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ToMixedPrecision", ToMixedPrecision); +}); } // namespace transform diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index ef1616c83ed8..6516c794a34a 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -20,6 +20,7 @@ * \file src/relax/transform/to_non_dataflow.cc * \brief Transform all dataflow structure to non-dataflow version. 
*/ +#include #include #include #include @@ -61,7 +62,10 @@ Pass ToNonDataflow() { return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.ToNonDataflow", ToNonDataflow); +}); } // namespace transform diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index 1ba78cdc5e2c..fcb5011731f6 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -20,6 +20,7 @@ * \file src/relax/transform/topological_sort.cc * \brief Perform a topological sort of Dataflow blocks */ +#include #include #include #include @@ -342,34 +343,37 @@ Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { return relax::transform::CreateFunctionPass(pass_func, 0, "TopologicalSort", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.TopologicalSort") - .set_body_typed([](String order_str, String direction_str) -> Pass { - TraversalOrder order = [&]() { - if (order_str == "depth-first") { - return TraversalOrder::DepthFirst; - } else if (order_str == "breadth-first") { - return TraversalOrder::BreadthFirst; - } else { - LOG(FATAL) << "ValueError: " - << "Invalid value for traversal order: \"" << order_str << "\". " - << "Allowed values are \"depth-first\" or \"breadth-first\""; - } - }(); - - StartingLocation starting_location = [&]() { - if (direction_str == "from-inputs") { - return StartingLocation::FromInputs; - } else if (direction_str == "from-outputs") { - return StartingLocation::FromOutputs; - } else { - LOG(FATAL) << "ValueError: " - << "Invalid value for starting location: \"" << direction_str << "\". " - << "Allowed values are \"from-inputs\" or \"from-outputs\""; - } - }(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.transform.TopologicalSort", [](String order_str, String direction_str) -> Pass { + TraversalOrder order = [&]() { + if (order_str == "depth-first") { + return TraversalOrder::DepthFirst; + } else if (order_str == "breadth-first") { + return TraversalOrder::BreadthFirst; + } else { + LOG(FATAL) << "ValueError: " + << "Invalid value for traversal order: \"" << order_str << "\". " + << "Allowed values are \"depth-first\" or \"breadth-first\""; + } + }(); + + StartingLocation starting_location = [&]() { + if (direction_str == "from-inputs") { + return StartingLocation::FromInputs; + } else if (direction_str == "from-outputs") { + return StartingLocation::FromOutputs; + } else { + LOG(FATAL) << "ValueError: " + << "Invalid value for starting location: \"" << direction_str << "\". 
" + << "Allowed values are \"from-inputs\" or \"from-outputs\""; + } + }(); - return TopologicalSort(order, starting_location); - }); + return TopologicalSort(order, starting_location); + }); +}); } // namespace transform diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index 472f454bc11a..091448964d21 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -22,6 +22,7 @@ * \brief Mutate IRModule to accept new parameters */ +#include #include #include #include @@ -104,8 +105,10 @@ Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_f return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamStructInfo", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.UpdateParamStructInfo") - .set_body_typed(UpdateParamStructInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.UpdateParamStructInfo", UpdateParamStructInfo); +}); } // namespace transform } // namespace relax diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index d2a1f85be853..ce5e40048a93 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ -23,6 +23,7 @@ * \brief Update Virtual Device pass. */ +#include #include #include #include @@ -106,7 +107,10 @@ Pass UpdateVDevice(VDevice new_vdevice, int64_t index) { /*pass_name=*/"UpdateVDevice", /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.UpdateVDevice", UpdateVDevice); +}); } // namespace transform } // namespace relax diff --git a/src/relax/utils.cc b/src/relax/utils.cc index ab270c08a65d..88ecbbe64e8d 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -19,6 +19,7 @@ #include "transform/utils.h" +#include #include #include #include @@ -245,7 +246,10 @@ Expr GetBoundValue(const Binding& b) { */ Function CopyWithNewVars(Function func) { return FunctionCopier().Copy(func); } -TVM_FFI_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.CopyWithNewVars", CopyWithNewVars); +}); } // namespace relax } // namespace tvm diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 84cd4943c552..1ba8b9c2a849 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -247,8 +248,11 @@ Module ConstLoaderModuleCreate( return Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") - .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadbinary_const_loader", + ConstLoaderModuleNode::LoadFromBinary); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index 1eb63a10fa4c..f60759134837 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -22,6 +22,7 @@ * \brief extraction of AMX configuration on x86 platforms */ #include +#include namespace tvm { namespace runtime { @@ -75,59 +76,64 @@ void init_tile_config(__tilecfg_u* 
dst, uint16_t cols, uint8_t rows) { _tile_loadconfig(dst->a); } -TVM_FFI_REGISTER_GLOBAL("runtime.amx_tileconfig") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int rows = args[0].cast(); - int cols = args[1].cast(); - LOG(INFO) << "rows: " << rows << ", cols:" << cols; - // -----------Config for AMX tile resgister---------------------- - __tilecfg_u cfg; - init_tile_config(&cfg, cols, rows); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("runtime.amx_tileconfig", [](ffi::PackedArgs args, ffi::Any* rv) { + int rows = args[0].cast(); + int cols = args[1].cast(); + LOG(INFO) << "rows: " << rows << ", cols:" << cols; + // -----------Config for AMX tile register---------------------- + __tilecfg_u cfg; + init_tile_config(&cfg, cols, rows); + *rv = 1; + return; + }); +}); + +// register a global packed function in c++, to init the system for AMX config +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("runtime.amx_init", [](ffi::PackedArgs args, ffi::Any* rv) { + // -----------Detect and request for AMX control---------------------- + uint64_t bitmask = 0; + int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) { + *rv = 0; + LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); + LOG(FATAL) << "status[0]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, TMUL feature is not allowed."; + return; + } + if (bitmask & XFEATURE_MASK_XTILEDATA) { *rv = 1; return; - }); + } // TILE_DATA feature was not detected + + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + // if XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed + if (0 != status) { + *rv = 0; + LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); + LOG(FATAL) << "status[1]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed."; + return; + } + + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + // if XFEATURE_XTILEDATA setup is failed, can't use TMUL + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) { + *rv = 0; + LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); + LOG(FATAL) << "status[2]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, can't use TMUL."; + return; + } -// register a global packed function in c++,to init the system for AMX config -TVM_FFI_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - // -----------Detect and request for AMX control---------------------- - uint64_t bitmask = 0; - int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); - if (0 != status) { - *rv = 0; - LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); - LOG(FATAL) << "status[0]: " << status << ", bitmask: " << bitmask - << ", XFEATURE_XTILEDATA setup is failed, TMUL feature is not allowed."; - return; - } - if (bitmask & XFEATURE_MASK_XTILEDATA) { + // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed *rv = 1; return; - } // TILE_DATA feature was not detected - - status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); - // if XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed - if (0 != status) { - *rv = 0; - LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); - LOG(FATAL) << "status[1]: " << status << ", bitmask: " << bitmask - << ", XFEATURE_XTILEDATA setup is failed, 
TMUL usage is not allowed."; - return; - } - - status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); - // if XFEATURE_XTILEDATA setup is failed, can't use TMUL - if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) { - *rv = 0; - LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); - LOG(FATAL) << "status[2]: " << status << ", bitmask: " << bitmask - << ", XFEATURE_XTILEDATA setup is failed, can't use TMUL."; - return; - } - - // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed - *rv = 1; - return; + }); }); #endif diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index eeca2fcdf347..d2c62944745b 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include "../json/json_node.h" @@ -593,9 +594,13 @@ runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_ return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate); -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.arm_compute_lib_runtime_create", ACLRuntimeCreate) + .def("runtime.module.loadbinary_arm_compute_lib", + JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index aed0080589e0..5177d55319bd 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -562,10 +563,12 @@ runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") - .set_body_typed(BNNSJSONRuntime::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.BNNSJSONRuntimeCreate", BNNSJSONRuntimeCreate) + .def("runtime.module.loadbinary_bnns_json", BNNSJSONRuntime::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 4d04d8263447..23e55cd276c8 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -21,6 +21,7 @@ * \file Use external cblas library call. 
*/ #include +#include #include #include @@ -123,37 +124,39 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchOp()); - } - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.cblas.matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); + }) + .def_packed("tvm.contrib.cblas.batch_matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchOp()); + } + }) + .def_packed("tvm.contrib.cblas.batch_matmul_iterative", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index 68819d015326..dd36953d279a 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -21,6 +21,7 @@ * \file Use external cblas library call. 
*/ #include +#include #include #include @@ -46,11 +47,13 @@ struct DNNLSgemmOp { }; // matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CallGemm(args, ret, DNNLSgemmOp()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.dnnl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CallGemm(args, ret, DNNLSgemmOp()); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index 33b52e5e375d..93e2cfa53313 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -21,6 +21,7 @@ * \file Use external mkl library call. */ #include +#include #include #include @@ -154,49 +155,53 @@ struct MKLDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, MKLSgemmOp()); - else - CallGemm(args, ret, MKLDgemmOp()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.mkl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, MKLSgemmOp()); + else + CallGemm(args, ret, MKLDgemmOp()); + }); +}); // integer matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto B = args[1].cast(); - auto C = args[2].cast(); - ICHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) && - TypeMatch(C->dtype, kDLInt, 32)); - - CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, MKLSgemmBatchOp()); - } else { - CallBatchGemm(args, ret, MKLDgemmBatchOp()); - } - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.mkl.matmul_u8s8s32", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto B = args[1].cast(); + auto C = args[2].cast(); + ICHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) && + TypeMatch(C->dtype, kDLInt, 32)); + + CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); + }) + .def_packed("tvm.contrib.mkl.batch_matmul", + 
[](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, MKLSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchOp()); + } + }) + .def_packed("tvm.contrib.mkl.batch_matmul_iterative", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); + } + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 5ee90e29b009..d298eaf5d6e4 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -23,6 +23,8 @@ */ #include "clml_runtime.h" +#include + #include #ifdef TVM_GRAPH_EXECUTOR_CLML @@ -1830,9 +1832,12 @@ runtime::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.clml_runtime_create").set_body_typed(CLMLRuntimeCreate); -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_clml") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.clml_runtime_create", CLMLRuntimeCreate) + .def("runtime.module.loadbinary_clml", JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index f98c97f68b12..0e902bea9af2 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -21,6 +21,7 @@ * \file coreml_runtime.cc */ #include +#include #include "coreml_runtime.h" @@ -192,10 +193,12 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p return Module(exec); } -TVM_FFI_REGISTER_GLOBAL("tvm.coreml_runtime.create") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = CoreMLRuntimeCreate(args[0], args[1]); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.coreml_runtime.create", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = CoreMLRuntimeCreate(args[0], args[1]); + }); +}); void CoreMLRuntime::SaveToBinary(dmlc::Stream* stream) { NSURL* url = model_->url_; @@ -249,8 +252,10 @@ Module CoreMLRuntimeLoadFromBinary(void* strm) { return Module(exec); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_coreml") - .set_body_typed(CoreMLRuntimeLoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadbinary_coreml", CoreMLRuntimeLoadFromBinary); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 19d83e624d91..36607973de40 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -21,6 +21,7 @@ * \file Use external cblas library call. 
*/ #include +#include #include #include @@ -514,71 +515,80 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t } // matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto C = args[2].cast(); - - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - - CUBLASTryEnableTensorCore(entry_ptr->handle); - - if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); - - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); - else - CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); - } else { - CallGemmEx(args, ret, entry_ptr->handle); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.cublas.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto C = args[2].cast(); + + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + + CUBLASTryEnableTensorCore(entry_ptr->handle); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); + else + CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } + }); +}); #if CUDART_VERSION >= 10010 -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - - CUBLASTryEnableTensorCore(entry_ptr->handle); - - ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; - cublasLtHandle_t ltHandle; - CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); - CallLtIgemm(args, ret, ltHandle, stream); - CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.cublaslt.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + + CUBLASTryEnableTensorCore(entry_ptr->handle); + + ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); + auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(func().cast()); + CallLtIgemm(args, ret, ltHandle, stream); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); + }); +}); #endif // CUDART_VERSION >= 10010 -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto C = args[2].cast(); - - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - - CUBLASTryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, 
C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); - - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); - else - CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); - } else { - CallBatchGemmEx(args, ret, entry_ptr->handle); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.cublas.batch_matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto C = args[2].cast(); + + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + + CUBLASTryEnableTensorCore(entry_ptr->handle); + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); + else + CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemmEx(args, ret, entry_ptr->handle); + } + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 8f7b6ac1f188..7d05cf56bdd7 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -153,10 +154,13 @@ runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.CublasJSONRuntimeCreate", CublasJSONRuntimeCreate) + .def("runtime.module.loadbinary_cublas_json", + JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index a19fc192efd1..cf049407fc25 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -21,6 +21,7 @@ * \file cuDNN kernel calls for backward algorithms. 
*/ #include +#include #include #include @@ -185,85 +186,86 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c ret[0] = static_cast(best_algo); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int mode = args[0].cast(); - int format = args[1].cast(); - int algo = args[2].cast(); - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i].cast(); - stride_v[i] = args[5 + i].cast(); - dilation_v[i] = args[7 + i].cast(); - } - auto dy = args[9].cast(); - auto w = args[10].cast(); - auto dx = args[11].cast(); - auto conv_dtype = args[12].cast(); - int groups = args[13].cast(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.cudnn.conv2d.backward_data", + [](ffi::PackedArgs args, ffi::Any* ret) { + int mode = args[0].cast(); + int format = args[1].cast(); + int algo = args[2].cast(); + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i].cast(); + stride_v[i] = args[5 + i].cast(); + dilation_v[i] = args[7 + i].cast(); + } + auto dy = args[9].cast(); + auto w = args[10].cast(); + auto dx = args[11].cast(); + auto conv_dtype = args[12].cast(); + int groups = args[13].cast(); - ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx, - conv_dtype); - }); + ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, + dilation_v, dy, w, dx, conv_dtype); + }) + .def_packed("tvm.contrib.cudnn.conv.backward_data_find_algo", + [](ffi::PackedArgs args, ffi::Any* ret) { + int format = args[0].cast(); + int dims = args[1].cast(); + int* pad = static_cast(args[2].cast()); + int* stride = static_cast(args[3].cast()); + int* dilation = static_cast(args[4].cast()); + int* dy_dim = static_cast(args[5].cast()); + int* w_dim = static_cast(args[6].cast()); + int* dx_dim = static_cast(args[7].cast()); + auto data_dtype = args[8].cast(); + auto conv_dtype = args[9].cast(); + int groups = args[10].cast(); + bool verbose = args[11].cast(); -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int format = args[0].cast(); - int dims = args[1].cast(); - int* pad = static_cast(args[2].cast()); - int* stride = static_cast(args[3].cast()); - int* dilation = static_cast(args[4].cast()); - int* dy_dim = static_cast(args[5].cast()); - int* w_dim = static_cast(args[6].cast()); - int* dx_dim = static_cast(args[7].cast()); - auto data_dtype = args[8].cast(); - auto conv_dtype = args[9].cast(); - int groups = args[10].cast(); - bool verbose = args[11].cast(); + BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, + dx_dim, data_dtype, conv_dtype, verbose, ret); + }) + .def_packed("tvm.contrib.cudnn.conv2d.backward_filter", + [](ffi::PackedArgs args, ffi::Any* ret) { + int mode = args[0].cast(); + int format = args[1].cast(); + int algo = args[2].cast(); + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i].cast(); + stride_v[i] = args[5 + i].cast(); + dilation_v[i] = args[7 + i].cast(); + } + auto dy = args[9].cast(); + auto x = args[10].cast(); + auto dw = args[11].cast(); + auto conv_dtype = args[12].cast(); + int groups = args[13].cast(); - BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim, - data_dtype, conv_dtype, 
verbose, ret); - }); + ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, + dilation_v, dy, x, dw, conv_dtype); + }) + .def_packed("tvm.contrib.cudnn.conv.backward_filter_find_algo", + [](ffi::PackedArgs args, ffi::Any* ret) { + int format = args[0].cast(); + int dims = args[1].cast(); + int* pad = static_cast(args[2].cast()); + int* stride = static_cast(args[3].cast()); + int* dilation = static_cast(args[4].cast()); + int* dy_dim = static_cast(args[5].cast()); + int* x_dim = static_cast(args[6].cast()); + int* dw_dim = static_cast(args[7].cast()); + auto data_dtype = args[8].cast(); + auto conv_dtype = args[9].cast(); + int groups = args[10].cast(); + bool verbose = args[11].cast(); -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int mode = args[0].cast(); - int format = args[1].cast(); - int algo = args[2].cast(); - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i].cast(); - stride_v[i] = args[5 + i].cast(); - dilation_v[i] = args[7 + i].cast(); - } - auto dy = args[9].cast(); - auto x = args[10].cast(); - auto dw = args[11].cast(); - auto conv_dtype = args[12].cast(); - int groups = args[13].cast(); - - ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, x, - dw, conv_dtype); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int format = args[0].cast(); - int dims = args[1].cast(); - int* pad = static_cast(args[2].cast()); - int* stride = static_cast(args[3].cast()); - int* dilation = static_cast(args[4].cast()); - int* dy_dim = static_cast(args[5].cast()); - int* x_dim = static_cast(args[6].cast()); - int* dw_dim = static_cast(args[7].cast()); - auto data_dtype = args[8].cast(); - auto conv_dtype = args[9].cast(); - int groups = args[10].cast(); - bool verbose = args[11].cast(); - - BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim, - data_dtype, conv_dtype, verbose, ret); - }); + BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, + x_dim, dw_dim, data_dtype, conv_dtype, verbose, ret); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 856d796e9038..14487be359c1 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -21,6 +21,7 @@ * \file cuDNN kernel calls for the forward algorithm. 
*/ #include +#include #include #include @@ -153,89 +154,91 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid ret[0] = static_cast(best_algo); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int mode = args[0].cast(); - int format = args[1].cast(); - int algo = args[2].cast(); - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i].cast(); - stride_v[i] = args[5 + i].cast(); - dilation_v[i] = args[7 + i].cast(); - } - auto x = args[9].cast(); - auto w = args[10].cast(); - auto y = args[11].cast(); - auto conv_dtype = args[12].cast(); - int groups = args[13].cast(); - - ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y, - conv_dtype); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int mode = args[0].cast(); - int format = args[1].cast(); - int algo = args[2].cast(); - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i].cast(); - stride_v[i] = args[5 + i].cast(); - dilation_v[i] = args[7 + i].cast(); - } - int act = args[9].cast(); - double coef = args[10].cast(); - auto x = args[11].cast(); - auto w = args[12].cast(); - auto bias = args[13].cast(); - auto y = args[14].cast(); - auto conv_dtype = args[15].cast(); - int groups = args[16].cast(); - - ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v, - dilation_v, x, w, y, bias, conv_dtype); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int mode = args[0].cast(); - int format = args[1].cast(); - int algo = args[2].cast(); - int pad_v[3], stride_v[3], dilation_v[3]; - for (int i = 0; i < 3; i++) { - pad_v[i] = args[3 + i].cast(); - stride_v[i] = args[6 + i].cast(); - dilation_v[i] = args[9 + i].cast(); - } - auto x = args[12].cast(); - auto w = args[13].cast(); - auto y = args[14].cast(); - auto conv_dtype = args[15].cast(); - int groups = args[16].cast(); - - ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y, - conv_dtype); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - int format = args[0].cast(); - int dims = args[1].cast(); - int* pad = static_cast(args[2].cast()); - int* stride = static_cast(args[3].cast()); - int* dilation = static_cast(args[4].cast()); - int* x_dim = static_cast(args[5].cast()); - int* w_dim = static_cast(args[6].cast()); - int* y_dim = static_cast(args[7].cast()); - auto data_dtype = args[8].cast(); - auto conv_dtype = args[9].cast(); - int groups = args[10].cast(); - bool verbose = args[11].cast(); - FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, - conv_dtype, verbose, ret); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.cudnn.conv2d.forward", + [](ffi::PackedArgs args, ffi::Any* ret) { + int mode = args[0].cast(); + int format = args[1].cast(); + int algo = args[2].cast(); + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i].cast(); + stride_v[i] = args[5 + i].cast(); + dilation_v[i] = args[7 + i].cast(); + } + auto x = args[9].cast(); + auto w = args[10].cast(); + auto y = 
args[11].cast(); + auto conv_dtype = args[12].cast(); + int groups = args[13].cast(); + + ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, + x, w, y, conv_dtype); + }) + .def_packed("tvm.contrib.cudnn.conv2d+bias+act.forward", + [](ffi::PackedArgs args, ffi::Any* ret) { + int mode = args[0].cast(); + int format = args[1].cast(); + int algo = args[2].cast(); + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i].cast(); + stride_v[i] = args[5 + i].cast(); + dilation_v[i] = args[7 + i].cast(); + } + int act = args[9].cast(); + double coef = args[10].cast(); + auto x = args[11].cast(); + auto w = args[12].cast(); + auto bias = args[13].cast(); + auto y = args[14].cast(); + auto conv_dtype = args[15].cast(); + int groups = args[16].cast(); + + ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, + pad_v, stride_v, dilation_v, x, w, y, bias, + conv_dtype); + }) + .def_packed("tvm.contrib.cudnn.conv3d.forward", + [](ffi::PackedArgs args, ffi::Any* ret) { + int mode = args[0].cast(); + int format = args[1].cast(); + int algo = args[2].cast(); + int pad_v[3], stride_v[3], dilation_v[3]; + for (int i = 0; i < 3; i++) { + pad_v[i] = args[3 + i].cast(); + stride_v[i] = args[6 + i].cast(); + dilation_v[i] = args[9 + i].cast(); + } + auto x = args[12].cast(); + auto w = args[13].cast(); + auto y = args[14].cast(); + auto conv_dtype = args[15].cast(); + int groups = args[16].cast(); + + ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, + x, w, y, conv_dtype); + }) + .def_packed("tvm.contrib.cudnn.conv.forward_find_algo", + [](ffi::PackedArgs args, ffi::Any* ret) { + int format = args[0].cast(); + int dims = args[1].cast(); + int* pad = static_cast(args[2].cast()); + int* stride = static_cast(args[3].cast()); + int* dilation = static_cast(args[4].cast()); + int* x_dim = static_cast(args[5].cast()); + int* w_dim = static_cast(args[6].cast()); + int* y_dim = static_cast(args[7].cast()); + auto data_dtype = args[8].cast(); + auto conv_dtype = args[9].cast(); + int groups = args[10].cast(); + bool verbose = args[11].cast(); + FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, + data_dtype, conv_dtype, verbose, ret); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index eda3b694d7f0..e5268051df6e 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -237,10 +238,13 @@ runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.cuDNNJSONRuntimeCreate").set_body_typed(cuDNNJSONRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cudnn_json") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.cuDNNJSONRuntimeCreate", cuDNNJSONRuntimeCreate) + .def("runtime.module.loadbinary_cudnn_json", + JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 8e2e85c67524..f3f0f8b17547 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -25,6 +25,7 @@ #include 
#include +#include #include #include @@ -265,8 +266,10 @@ SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_des SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool { - return CuDNNThreadEntry::ThreadLocal(false)->exists(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tvm.contrib.cudnn.exists", + []() -> bool { return CuDNNThreadEntry::ThreadLocal(false)->exists(); }); }); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index aa37acd2c3a9..60eb5402a7cd 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -22,6 +22,7 @@ * \brief Use external cudnn softmax function */ #include +#include #include #include "cudnn_utils.h" @@ -77,15 +78,17 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r entry_ptr->softmax_entry.shape_desc, y->data)); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.cudnn.softmax.forward", + [](ffi::PackedArgs args, ffi::Any* ret) { + softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret); + }) + .def_packed("tvm.contrib.cudnn.log_softmax.forward", [](ffi::PackedArgs args, ffi::Any* ret) { + softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc index e31c5fdfebf8..9509f45f0753 100644 --- a/src/runtime/contrib/curand/curand.cc +++ b/src/runtime/contrib/curand/curand.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include "../../cuda/cuda_common.h" @@ -112,7 +113,10 @@ void RandomFill(DLTensor* tensor) { TVMSynchronize(tensor->device.device_type, tensor->device.device_type, nullptr); } -TVM_FFI_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.contrib.curand.RandomFill", RandomFill); +}); } // namespace curand } // namespace runtime diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu index 29efcbe088ae..8149fab0aa4f 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -46,7 +47,10 @@ void tvm_cutlass_group_gemm_sm100(NDArray x, NDArray weight, NDArray indptr, NDA tvm_cutlass_group_gemm_impl<100>(x, weight, indptr, workspace, out); } -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm100); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("cutlass.group_gemm", tvm_cutlass_group_gemm_sm100); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu index 
bbf5a453b4d0..5276ce9f820c 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu @@ -20,8 +20,8 @@ #include #include #include +#include #include -#include #include "fp16_group_gemm.cuh" #include "fp16_group_gemm_runner_sm90.cuh" @@ -46,7 +46,10 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out); } -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm90); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("cutlass.group_gemm", tvm_cutlass_group_gemm_sm90); +}); #endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index 4ee31e73abca..fc5c76282b26 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -19,9 +19,9 @@ #include #include -#include -#include #include +#include +#include #include "../cublas/cublas_utils.h" #include "gemm_runner.cuh" @@ -77,17 +77,16 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray } } -TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") - .set_body_typed( - tvm_cutlass_fp8_gemm); - -TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") - .set_body_typed( - tvm_cutlass_fp8_gemm); - -TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") - .set_body_typed( - tvm_cutlass_fp8_gemm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("cutlass.gemm_e5m2_e5m2_fp16", + tvm_cutlass_fp8_gemm) + .def("cutlass.gemm_e5m2_e4m3_fp16", + tvm_cutlass_fp8_gemm) + .def("cutlass.gemm_e4m3_e4m3_fp16", + tvm_cutlass_fp8_gemm); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index 0eaa6a1efb77..4477cccf7f33 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "fp16_group_gemm_runner_sm90.cuh" @@ -66,17 +67,19 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr static_cast(out->data), stream); } -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16") - .set_body_typed( - tvm_cutlass_fp8_group_gemm); - -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16") - .set_body_typed( - tvm_cutlass_fp8_group_gemm); - -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16") - .set_body_typed( - tvm_cutlass_fp8_group_gemm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def( + "cutlass.group_gemm_e5m2_e5m2_fp16", + tvm_cutlass_fp8_group_gemm) + .def( + "cutlass.group_gemm_e5m2_e4m3_fp16", + tvm_cutlass_fp8_group_gemm) + .def("cutlass.group_gemm_e4m3_e4m3_fp16", + tvm_cutlass_fp8_group_gemm); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu index ffa3ae6653e6..477243e62491 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -66,10 +67,14 @@ void 
tvm_cutlass_fp8_groupwise_scaled_bmm_sm100(NDArray a, NDArray b, NDArray sc a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); } -TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm100); -TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm100); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn", + tvm_cutlass_fp8_groupwise_scaled_gemm_sm100) + .def("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn", + tvm_cutlass_fp8_groupwise_scaled_bmm_sm100); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu index e445e97da364..ee9c9a4d0076 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -66,10 +67,13 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_sm90(NDArray a, NDArray b, NDArray sca a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); } -TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm90); -TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm90); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn", + tvm_cutlass_fp8_groupwise_scaled_gemm_sm90) + .def("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn", tvm_cutlass_fp8_groupwise_scaled_bmm_sm90); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index d13481e9dd3f..c1371ee90e74 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -84,8 +85,11 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray sca } } -TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn") - .set_body_typed(tvm_fp8_groupwise_scaled_group_gemm_sm100); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn", + tvm_fp8_groupwise_scaled_group_gemm_sm100); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc index 5fece6166158..642b5d01f82c 100644 --- a/src/runtime/contrib/cutlass/weight_preprocess.cc +++ b/src/runtime/contrib/cutlass/weight_preprocess.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include "cutlass_kernels/cutlass_preprocessors.h" @@ -34,27 +35,30 @@ namespace runtime { // black box. // // The preprocessing functions are defined in C++, so we need to copy the input weight to CPU. -TVM_FFI_REGISTER_GLOBAL("cutlass.ft_preprocess_weight") - .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) { - bool is_2d = packed_weight->ndim == 2; - int num_experts = is_2d ? 
1 : packed_weight->shape[0]; - int rows = packed_weight->shape[is_2d ? 0 : 1]; - int cols = packed_weight->shape[is_2d ? 1 : 2]; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("cutlass.ft_preprocess_weight", [](NDArray packed_weight, int sm, + bool is_int4) { + bool is_2d = packed_weight->ndim == 2; + int num_experts = is_2d ? 1 : packed_weight->shape[0]; + int rows = packed_weight->shape[is_2d ? 0 : 1]; + int cols = packed_weight->shape[is_2d ? 1 : 2]; - std::vector input_cpu(num_experts * rows * cols); - std::vector output_cpu(num_experts * rows * cols); - packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size()); - // multiply cols by 2 since the "col" params in preprocess_weights refers to the column of - // the unpacked weight. - if (is_int4) { - cols *= 2; - } - fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), num_experts, rows, - cols, is_int4, sm); - auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device); - out.CopyFromBytes(output_cpu.data(), output_cpu.size()); - return out; - }); + std::vector input_cpu(num_experts * rows * cols); + std::vector output_cpu(num_experts * rows * cols); + packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size()); + // multiply cols by 2 since the "col" params in preprocess_weights refers to the column of + // the unpacked weight. + if (is_int4) { + cols *= 2; + } + fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), num_experts, rows, + cols, is_int4, sm); + auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device); + out.CopyFromBytes(output_cpu.data(), output_cpu.size()); + return out; + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 9cc053ec7ca4..1b45d4ddd99b 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -348,39 +349,41 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ } // DNNL Conv2d single OP -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto input = args[0].cast(); - auto weights = args[1].cast(); - auto output = args[2].cast(); - int p_Ph0_ = args[3].cast(), p_Pw0_ = args[4].cast(), p_Ph1_ = args[5].cast(), - p_Pw1_ = args[6].cast(), p_Sh_ = args[7].cast(), p_Sw_ = args[8].cast(), - p_G_ = args[9].cast(); - bool channel_last = args[10].cast(); - bool pre_cast = args[11].cast(); - bool post_cast = args[12].cast(); - - int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2], - p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2], - p_Kw_ = weights->shape[3]; - - if (channel_last) { - p_N_ = input->shape[0]; - p_H_ = input->shape[1]; - p_W_ = input->shape[2]; - p_C_ = input->shape[3]; - p_O_ = output->shape[3]; - p_Kh_ = weights->shape[0]; - p_Kw_ = weights->shape[1]; - } - - std::vector bias(p_O_, 0); - primitive_attr attr; - return dnnl_conv2d_common( - static_cast(input->data), static_cast(weights->data), bias.data(), - static_cast(output->data), p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, - p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, channel_last, pre_cast, post_cast); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.dnnl.conv2d", 
[](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto weights = args[1].cast(); + auto output = args[2].cast(); + int p_Ph0_ = args[3].cast(), p_Pw0_ = args[4].cast(), p_Ph1_ = args[5].cast(), + p_Pw1_ = args[6].cast(), p_Sh_ = args[7].cast(), p_Sw_ = args[8].cast(), + p_G_ = args[9].cast(); + bool channel_last = args[10].cast(); + bool pre_cast = args[11].cast(); + bool post_cast = args[12].cast(); + + int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2], + p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2], + p_Kw_ = weights->shape[3]; + + if (channel_last) { + p_N_ = input->shape[0]; + p_H_ = input->shape[1]; + p_W_ = input->shape[2]; + p_C_ = input->shape[3]; + p_O_ = output->shape[3]; + p_Kh_ = weights->shape[0]; + p_Kw_ = weights->shape[1]; + } + + std::vector bias(p_O_, 0); + primitive_attr attr; + return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), + bias.data(), static_cast(output->data), p_N_, p_C_, p_H_, + p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, + p_Sw_, attr, channel_last, pre_cast, post_cast); + }); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 154ee12790f7..41309612d9b8 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -927,10 +928,12 @@ runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.DNNLJSONRuntimeCreate", DNNLJSONRuntimeCreate) + .def("runtime.module.loadbinary_dnnl_json", JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 5d706836e6ce..827997ae38b3 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -68,9 +69,11 @@ Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_FFI_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = EdgeTPURuntimeCreate(args[0], args[1]); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.edgetpu_runtime.create", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = EdgeTPURuntimeCreate(args[0], args[1]); }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index fb4e394e7fc2..81d44a441239 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -21,6 +21,7 @@ * \file Use external hipblas library call. 
*/ #include +#include #include #include @@ -407,51 +408,54 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t } // matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto C = args[2].cast(); - - HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); - - if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); - - if (TypeMatch(A->dtype, kDLFloat, 16)) { - CallGemm(args, ret, HipblasHgemmOp(entry_ptr->handle)); - } else if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallGemm(args, ret, HipblasSgemmOp(entry_ptr->handle)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.hipblas.matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto C = args[2].cast(); + + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || + TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + CallGemm(args, ret, HipblasHgemmOp(entry_ptr->handle)); + } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallGemm(args, ret, HipblasSgemmOp(entry_ptr->handle)); + } else { + CallGemm(args, ret, HipblasDgemmOp(entry_ptr->handle)); + } + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } + }) + .def_packed("tvm.contrib.hipblas.batch_matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto C = args[2].cast(); + + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + CallBatchGemm(args, ret, HipblasHgemmBatchOp(entry_ptr->handle)); + } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, HipblasSgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemm(args, ret, HipblasDgemmBatchOp(entry_ptr->handle)); + } } else { - CallGemm(args, ret, HipblasDgemmOp(entry_ptr->handle)); + CallBatchGemmEx(args, ret, entry_ptr->handle); } - } else { - CallGemmEx(args, ret, entry_ptr->handle); - } - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto C = args[2].cast(); - - HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); - - if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); - - if (TypeMatch(A->dtype, kDLFloat, 16)) { - CallBatchGemm(args, ret, HipblasHgemmBatchOp(entry_ptr->handle)); - } else if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, HipblasSgemmBatchOp(entry_ptr->handle)); - } else { - CallBatchGemm(args, ret, HipblasDgemmBatchOp(entry_ptr->handle)); - } - } else { - CallBatchGemmEx(args, ret, entry_ptr->handle); - } - }); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 60e439125c10..b9e7d1a39275 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc 
+++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -139,11 +140,13 @@ runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate") - .set_body_typed(HipblasJSONRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.HipblasJSONRuntimeCreate", HipblasJSONRuntimeCreate) + .def("runtime.module.loadbinary_hipblas_json", + JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 247863c56a99..39f52f8263c6 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -21,6 +21,7 @@ * \file Use external miopen utils function */ #include +#include #include #include @@ -34,191 +35,196 @@ namespace miopen { using namespace runtime; -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - const int mode = args[0].cast(); - const int dtype = args[1].cast(); - const int pad_h = args[2].cast(); - const int pad_w = args[3].cast(); - const int stride_h = args[4].cast(); - const int stride_w = args[5].cast(); - const int dilation_h = args[6].cast(); - const int dilation_w = args[7].cast(); - const int x_dim0 = args[8].cast(); - const int x_dim1 = args[9].cast(); - const int x_dim2 = args[10].cast(); - const int x_dim3 = args[11].cast(); - const int w_dim0 = args[12].cast(); - const int w_dim1 = args[13].cast(); - const int w_dim2 = args[14].cast(); - const int w_dim3 = args[15].cast(); - const int n_group = args[16].cast(); - void* out_shape = args[17].cast(); - - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); - assert(n_group > 0 && "Group Size > 0 is expected"); - if (n_group > 1) assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); - // Set Mode - entry_ptr->conv_entry.mode = static_cast(mode); - // Set Device - entry_ptr->conv_entry.device = Device{kDLROCM, 0}; - // Set Data Type - entry_ptr->conv_entry.data_type = - static_cast(dtype); // MIOpen supports fp32(miopenFloat), - // fp16(miopenHalf), int32, int8 at this moment. 
- // Set Desc - MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, pad_h, pad_w, - stride_h, stride_w, dilation_h, dilation_w)); - if (n_group > 1) - MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group)); - // Set Filter - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, w_dim0, - w_dim1 / n_group, w_dim2, w_dim3)); - // Set Input - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, x_dim0, x_dim1, - x_dim2, x_dim3)); - - // Set Output shape - MIOPEN_CALL(miopenGetConvolutionForwardOutputDim( - entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, static_cast(out_shape), - static_cast(out_shape) + 1, static_cast(out_shape) + 2, - static_cast(out_shape) + 3)); - - const int* oshape = static_cast(out_shape); - // Set Output - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, oshape[0], oshape[1], - oshape[2], oshape[3])); - - // Set workspace - size_t workspace_size = 0; - MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( - entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size)); - entry_ptr->conv_entry.UpdateWorkspace(workspace_size); - - const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3; - const size_t filter_size = w_dim0 * w_dim1 * w_dim2 * w_dim3; - const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3]; - - runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api; - float* input_buf = static_cast( - rocm_api->AllocWorkspace(entry_ptr->conv_entry.device, input_size * sizeof(float))); - float* filter_buf = static_cast( - rocm_api->AllocWorkspace(entry_ptr->conv_entry.device, filter_size * sizeof(float))); - float* output_buf = static_cast( - rocm_api->AllocWorkspace(entry_ptr->conv_entry.device, output_size * sizeof(float))); - - const int request_algo_count = 4; - const bool exhaustive_search = false; - void* workspace = entry_ptr->conv_entry.workspace; - if (workspace_size == 0) workspace = nullptr; - int returned_algo_count = 0; - miopenConvAlgoPerf_t perfs[4]; - - MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( - entry_ptr->handle, entry_ptr->conv_entry.input_desc, input_buf, - entry_ptr->conv_entry.filter_desc, filter_buf, entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, output_buf, request_algo_count, &returned_algo_count, - perfs, workspace, workspace_size, exhaustive_search)); - - rocm_api->FreeWorkspace(entry_ptr->conv_entry.device, input_buf); - rocm_api->FreeWorkspace(entry_ptr->conv_entry.device, filter_buf); - rocm_api->FreeWorkspace(entry_ptr->conv_entry.device, output_buf); - - const std::vector fwd_algo_names{ - "miopenConvolutionFwdAlgoGEMM", - "miopenConvolutionFwdAlgoDirect", - "miopenConvolutionFwdAlgoFFT", - "miopenConvolutionFwdAlgoWinograd", - }; - const auto best_algo = perfs[0].fwd_algo; - LOG(INFO) << "\tMIOpen Found " << returned_algo_count << " fwd algorithms, choosing " - << fwd_algo_names[best_algo]; - for (int i = 0; i < returned_algo_count; ++i) { - LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo] - << " - time: " << perfs[i].time << " ms" - << ", Memory: " << perfs[i].memory; - } - // Set Algo - ret[0] = static_cast(best_algo); - }); - 
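Note for review: every hunk in this patch applies the same mechanical rewrite, so a single minimal sketch of the target idiom may help when scanning the remaining files. Each free-standing TVM_FFI_REGISTER_GLOBAL("name").set_body_typed(f) or .set_body_packed(lambda) registration becomes one entry in a refl::GlobalDef() chain inside a TVM_FFI_STATIC_INIT_BLOCK. The include targets are elided in this patch text, so the two header paths in the sketch below are assumptions rather than lines taken from the diff; the namespace, function, and global names are likewise hypothetical.

// Sketch only -- illustrates the registration idiom this patch migrates to.
// The header paths below are assumed; they do not appear verbatim in the hunks.
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>

namespace demo {

int AddOne(int x) { return x + 1; }

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      // typed registration: replaces TVM_FFI_REGISTER_GLOBAL("demo.add_one").set_body_typed(AddOne)
      .def("demo.add_one", AddOne)
      // packed registration: replaces .set_body_packed(lambda)
      .def_packed("demo.add_one_packed", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
        *ret = AddOne(args[0].cast<int>());
      });
});

}  // namespace demo

Grouping the registrations this way puts one static-init block per translation unit in place of one registration object per global symbol, which matches the shape of every converted file above and below this point.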
-TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - const int mode = args[0].cast(); - const int dtype = args[1].cast(); - const int pad_h = args[2].cast(); - const int pad_w = args[3].cast(); - const int stride_h = args[4].cast(); - const int stride_w = args[5].cast(); - const int dilation_h = args[6].cast(); - const int dilation_w = args[7].cast(); - const int algo = args[8].cast(); - const auto x = args[9].cast(); - const auto w = args[10].cast(); - const auto y = args[11].cast(); - - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); - entry_ptr->conv_entry.fwd_algo = static_cast(algo); - // Set Mode - entry_ptr->conv_entry.mode = static_cast(mode); - // Set Device - entry_ptr->conv_entry.device = x->device; - // Set Data Type - entry_ptr->conv_entry.data_type = - static_cast(dtype); // MIOpen supports fp32(miopenFloat), - // fp16(miopenHalf) at this moment. - // Set Desc - MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, pad_h, pad_w, - stride_h, stride_w, dilation_h, dilation_w)); - // Set Filter - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, w->shape[0], - w->shape[1], w->shape[2], w->shape[3])); - // Set Input - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, x->shape[0], - x->shape[1], x->shape[2], x->shape[3])); - // Set Output - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, y->shape[0], - y->shape[1], y->shape[2], y->shape[3])); - - // Set workspace - size_t workspace_size = 0; - MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( - entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size)); - entry_ptr->conv_entry.UpdateWorkspace(workspace_size); - - const float alpha = 1.f; - const float beta = 0.f; - - const int request_algo_count = 4; - const bool exhaustive_search = true; - void* workspace = entry_ptr->conv_entry.workspace; - if (workspace_size == 0) workspace = nullptr; - int returned_algo_count = 0; - miopenConvAlgoPerf_t perfs[4]; - - MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( - entry_ptr->handle, entry_ptr->conv_entry.input_desc, x->data, - entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, y->data, request_algo_count, &returned_algo_count, - perfs, workspace, workspace_size, exhaustive_search)); - - MIOPEN_CALL(miopenConvolutionForward( - entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data, - entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data, - entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed( + "tvm.contrib.miopen.conv2d.setup", + [](ffi::PackedArgs args, ffi::Any* ret) { + const int mode = args[0].cast(); + const int dtype = args[1].cast(); + const int pad_h = args[2].cast(); + const int pad_w = args[3].cast(); + const int stride_h = args[4].cast(); + const int stride_w = args[5].cast(); + const int dilation_h = args[6].cast(); + const int dilation_w = args[7].cast(); + const int x_dim0 = 
args[8].cast(); + const int x_dim1 = args[9].cast(); + const int x_dim2 = args[10].cast(); + const int x_dim3 = args[11].cast(); + const int w_dim0 = args[12].cast(); + const int w_dim1 = args[13].cast(); + const int w_dim2 = args[14].cast(); + const int w_dim3 = args[15].cast(); + const int n_group = args[16].cast(); + void* out_shape = args[17].cast(); + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + assert(n_group > 0 && "Group Size > 0 is expected"); + if (n_group > 1) + assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + // Set Device + entry_ptr->conv_entry.device = Device{kDLROCM, 0}; + // Set Data Type + entry_ptr->conv_entry.data_type = static_cast( + dtype); // MIOpen supports fp32(miopenFloat), + // fp16(miopenHalf), int32, int8 at this moment. + // Set Desc + MIOPEN_CALL(miopenInitConvolutionDescriptor( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.mode, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w)); + if (n_group > 1) + MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group)); + // Set Filter + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.data_type, w_dim0, + w_dim1 / n_group, w_dim2, w_dim3)); + // Set Input + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.data_type, x_dim0, x_dim1, + x_dim2, x_dim3)); + + // Set Output shape + MIOPEN_CALL(miopenGetConvolutionForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, static_cast(out_shape), + static_cast(out_shape) + 1, static_cast(out_shape) + 2, + static_cast(out_shape) + 3)); + + const int* oshape = static_cast(out_shape); + // Set Output + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.data_type, oshape[0], + oshape[1], oshape[2], oshape[3])); + + // Set workspace + size_t workspace_size = 0; + MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + + const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3; + const size_t filter_size = w_dim0 * w_dim1 * w_dim2 * w_dim3; + const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3]; + + runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api; + float* input_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.device, input_size * sizeof(float))); + float* filter_buf = static_cast(rocm_api->AllocWorkspace( + entry_ptr->conv_entry.device, filter_size * sizeof(float))); + float* output_buf = static_cast(rocm_api->AllocWorkspace( + entry_ptr->conv_entry.device, output_size * sizeof(float))); + + const int request_algo_count = 4; + const bool exhaustive_search = false; + void* workspace = entry_ptr->conv_entry.workspace; + if (workspace_size == 0) workspace = nullptr; + int returned_algo_count = 0; + miopenConvAlgoPerf_t perfs[4]; + + MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, input_buf, + entry_ptr->conv_entry.filter_desc, filter_buf, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, output_buf, request_algo_count, + &returned_algo_count, perfs, 
workspace, workspace_size, exhaustive_search)); + + rocm_api->FreeWorkspace(entry_ptr->conv_entry.device, input_buf); + rocm_api->FreeWorkspace(entry_ptr->conv_entry.device, filter_buf); + rocm_api->FreeWorkspace(entry_ptr->conv_entry.device, output_buf); + + const std::vector fwd_algo_names{ + "miopenConvolutionFwdAlgoGEMM", + "miopenConvolutionFwdAlgoDirect", + "miopenConvolutionFwdAlgoFFT", + "miopenConvolutionFwdAlgoWinograd", + }; + const auto best_algo = perfs[0].fwd_algo; + LOG(INFO) << "\tMIOpen Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; + for (int i = 0; i < returned_algo_count; ++i) { + LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo] + << " - time: " << perfs[i].time << " ms" + << ", Memory: " << perfs[i].memory; + } + // Set Algo + ret[0] = static_cast(best_algo); + }) + .def_packed("tvm.contrib.miopen.conv2d.forward", [](ffi::PackedArgs args, ffi::Any* ret) { + const int mode = args[0].cast(); + const int dtype = args[1].cast(); + const int pad_h = args[2].cast(); + const int pad_w = args[3].cast(); + const int stride_h = args[4].cast(); + const int stride_w = args[5].cast(); + const int dilation_h = args[6].cast(); + const int dilation_w = args[7].cast(); + const int algo = args[8].cast(); + const auto x = args[9].cast(); + const auto w = args[10].cast(); + const auto y = args[11].cast(); + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + entry_ptr->conv_entry.fwd_algo = static_cast(algo); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + // Set Device + entry_ptr->conv_entry.device = x->device; + // Set Data Type + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), + // fp16(miopenHalf) at this moment. 
+ // Set Desc + MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.mode, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w)); + // Set Filter + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.data_type, w->shape[0], + w->shape[1], w->shape[2], w->shape[3])); + // Set Input + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.data_type, x->shape[0], + x->shape[1], x->shape[2], x->shape[3])); + // Set Output + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.data_type, y->shape[0], + y->shape[1], y->shape[2], y->shape[3])); + + // Set workspace + size_t workspace_size = 0; + MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + + const float alpha = 1.f; + const float beta = 0.f; + + const int request_algo_count = 4; + const bool exhaustive_search = true; + void* workspace = entry_ptr->conv_entry.workspace; + if (workspace_size == 0) workspace = nullptr; + int returned_algo_count = 0; + miopenConvAlgoPerf_t perfs[4]; + + MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, x->data, + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, y->data, request_algo_count, &returned_algo_count, + perfs, workspace, workspace_size, exhaustive_search)); + + MIOPEN_CALL(miopenConvolutionForward( + entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data, + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data, + entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size)); + }); +}); } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index 10289f22bdda..def72c8658b7 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -22,6 +22,7 @@ * \brief Use external miopen softmax function */ #include +#include #include #include "miopen_utils.h" @@ -79,15 +80,17 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t entry_ptr->softmax_entry.shape_desc, y->data, alg, mode)); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.miopen.softmax.forward", + [](ffi::PackedArgs args, ffi::Any* ret) { + softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); + }) + .def_packed( + "tvm.contrib.miopen.log_softmax.forward", + [](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); +}); } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index dbbb92dd05f7..d7a7ae21b4ce 
100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -17,6 +17,7 @@ * under the License. */ +#include #include "mps_utils.h" namespace tvm { @@ -24,140 +25,143 @@ using namespace runtime; -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto buf = args[0].cast(); - auto img = args[1].cast(); - // copy to temp - id mtlbuf = (__bridge id)(buf->data); - MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); - id dev = entry_ptr->metal_api->GetDevice(buf->device); - id temp = rt->GetTempBuffer(buf->device, [mtlbuf length]); - entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, - [mtlbuf length], buf -> device, buf -> device, - buf -> dtype, nullptr); - - MPSImageDescriptor* desc = - [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 - width:buf->shape[2] - height:buf->shape[1] - featureChannels:buf->shape[3]]; - - MPSImage* mpsimg = entry_ptr->AllocMPSImage(dev, desc); - - [mpsimg writeBytes:[temp contents] - dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels - imageIndex:0]; - - img->data = (__bridge void*)mpsimg; - - [mpsimg readBytes:[temp contents] - dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels - imageIndex:0]; - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto img = args[0].cast(); - auto buf = args[1].cast(); - id mtlbuf = (__bridge id)(buf->data); - MPSImage* mpsimg = (__bridge MPSImage*)(img->data); - MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); - id temp = rt->GetTempBuffer(buf->device, [mtlbuf length]); - - [mpsimg readBytes:[temp contents] - dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels - imageIndex:0]; - - entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, - [mtlbuf length], buf -> device, buf -> device, - buf -> dtype, nullptr); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - // MPS-NHWC - auto data = args[0].cast(); - auto weight = args[1].cast(); - auto output = args[2].cast(); - int pad = args[3].cast(); - int stride = args[4].cast(); - - ICHECK_EQ(data->ndim, 4); - ICHECK_EQ(weight->ndim, 4); - ICHECK_EQ(output->ndim, 4); - ICHECK(output->strides == nullptr); - ICHECK(weight->strides == nullptr); - ICHECK(data->strides == nullptr); - - ICHECK_EQ(data->shape[0], 1); - ICHECK_EQ(output->shape[0], 1); - - int oCh = weight->shape[0]; - int kH = weight->shape[1]; - int kW = weight->shape[2]; - int iCh = weight->shape[3]; - - const auto f_buf2img = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.buffer2img"); - const auto f_img2buf = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.img2buffer"); - // Get Metal device API - MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); - id dev = entry_ptr->metal_api->GetDevice(data->device); - id queue = entry_ptr->metal_api->GetCommandQueue(data->device); - id cb = [queue commandBuffer]; - // data to MPSImage - DLTensor tmp_in; - (*f_buf2img)(data, &tmp_in); - MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; - // weight to temp memory - id bufB = (__bridge 
id)(weight->data); - id tempB = rt->GetTempBuffer(weight->device, [bufB length]); - entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, - [bufB length], weight -> device, weight -> device, - tmp_in.dtype, nullptr); - float* ptr_w = (float*)[tempB contents]; - // output to MPSImage - DLTensor tmp_out; - (*f_buf2img)(output, &tmp_out); - MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; - // conv desc - - MPSCNNConvolutionDescriptor* conv_desc = - [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW - kernelHeight:kH - inputFeatureChannels:iCh - outputFeatureChannels:oCh]; - [conv_desc setStrideInPixelsX:stride]; - [conv_desc setStrideInPixelsY:stride]; - - MPSCNNConvolution* conv = - [[MPSCNNConvolution alloc] initWithDevice:dev - convolutionDescriptor:conv_desc - kernelWeights:ptr_w - biasTerms:nil - flags:MPSCNNConvolutionFlagsNone]; - if (pad == 0) { - conv.padding = [MPSNNDefaultPadding - paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | MPSNNPaddingMethodSizeSame]; - } else if (pad == 1) { - conv.padding = [MPSNNDefaultPadding - paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | MPSNNPaddingMethodSizeValidOnly]; - } - [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; - - [cb commit]; - id encoder = [cb blitCommandEncoder]; - [encoder synchronizeResource:tempC.texture]; - [encoder endEncoding]; - [cb waitUntilCompleted]; - - (*f_img2buf)(&tmp_out, output); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.mps.buffer2img", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto buf = args[0].cast(); + auto img = args[1].cast(); + // copy to temp + id mtlbuf = (__bridge id)(buf->data); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = + runtime::metal::MetalThreadEntry::ThreadLocal(); + id dev = entry_ptr->metal_api->GetDevice(buf->device); + id temp = rt->GetTempBuffer(buf->device, [mtlbuf length]); + entry_ptr->metal_api->CopyDataFromTo( + (__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, [mtlbuf length], + buf -> device, buf -> device, buf -> dtype, nullptr); + + MPSImageDescriptor* desc = [MPSImageDescriptor + imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 + width:buf->shape[2] + height:buf->shape[1] + featureChannels:buf->shape[3]]; + + MPSImage* mpsimg = entry_ptr->AllocMPSImage(dev, desc); + + [mpsimg writeBytes:[temp contents] + dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels + imageIndex:0]; + + img->data = (__bridge void*)mpsimg; + + [mpsimg readBytes:[temp contents] + dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels + imageIndex:0]; + }) + .def_packed("tvm.contrib.mps.img2buffer", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto img = args[0].cast(); + auto buf = args[1].cast(); + id mtlbuf = (__bridge id)(buf->data); + MPSImage* mpsimg = (__bridge MPSImage*)(img->data); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = + runtime::metal::MetalThreadEntry::ThreadLocal(); + id temp = rt->GetTempBuffer(buf->device, [mtlbuf length]); + + [mpsimg readBytes:[temp contents] + dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels + imageIndex:0]; + + entry_ptr->metal_api->CopyDataFromTo( + (__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, [mtlbuf length], + buf -> device, buf -> device, buf -> dtype, 
nullptr); + }) + .def_packed("tvm.contrib.mps.conv2d", [](ffi::PackedArgs args, ffi::Any* ret) { + // MPS-NHWC + auto data = args[0].cast(); + auto weight = args[1].cast(); + auto output = args[2].cast(); + int pad = args[3].cast(); + int stride = args[4].cast(); + + ICHECK_EQ(data->ndim, 4); + ICHECK_EQ(weight->ndim, 4); + ICHECK_EQ(output->ndim, 4); + ICHECK(output->strides == nullptr); + ICHECK(weight->strides == nullptr); + ICHECK(data->strides == nullptr); + + ICHECK_EQ(data->shape[0], 1); + ICHECK_EQ(output->shape[0], 1); + + int oCh = weight->shape[0]; + int kH = weight->shape[1]; + int kW = weight->shape[2]; + int iCh = weight->shape[3]; + + const auto f_buf2img = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.buffer2img"); + const auto f_img2buf = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.img2buffer"); + // Get Metal device API + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); + id dev = entry_ptr->metal_api->GetDevice(data->device); + id queue = entry_ptr->metal_api->GetCommandQueue(data->device); + id cb = [queue commandBuffer]; + // data to MPSImage + DLTensor tmp_in; + (*f_buf2img)(data, &tmp_in); + MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; + // weight to temp memory + id bufB = (__bridge id)(weight->data); + id tempB = rt->GetTempBuffer(weight->device, [bufB length]); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, + [bufB length], weight -> device, weight -> device, + tmp_in.dtype, nullptr); + float* ptr_w = (float*)[tempB contents]; + // output to MPSImage + DLTensor tmp_out; + (*f_buf2img)(output, &tmp_out); + MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; + // conv desc + + MPSCNNConvolutionDescriptor* conv_desc = + [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iCh + outputFeatureChannels:oCh]; + [conv_desc setStrideInPixelsX:stride]; + [conv_desc setStrideInPixelsY:stride]; + + MPSCNNConvolution* conv = + [[MPSCNNConvolution alloc] initWithDevice:dev + convolutionDescriptor:conv_desc + kernelWeights:ptr_w + biasTerms:nil + flags:MPSCNNConvolutionFlagsNone]; + if (pad == 0) { + conv.padding = [MPSNNDefaultPadding + paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | MPSNNPaddingMethodSizeSame]; + } else if (pad == 1) { + conv.padding = [MPSNNDefaultPadding + paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | MPSNNPaddingMethodSizeValidOnly]; + } + [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; + + [cb commit]; + id encoder = [cb blitCommandEncoder]; + [encoder synchronizeResource:tempC.texture]; + [encoder endEncoding]; + [cb waitUntilCompleted]; + + (*f_img2buf)(&tmp_out, output); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index 51285251c82e..a44931683495 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -17,6 +17,7 @@ * under the License. 
*/ +#include #include "mps_utils.h" namespace tvm { @@ -24,75 +25,77 @@ using namespace runtime; -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto B = args[1].cast(); - auto C = args[2].cast(); - bool transa = args[3].cast(); - bool transb = args[4].cast(); - // call gemm for simple compact code. - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); - // Get Metal device API - MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); - // ICHECK_EQ(A->device, B->device); - // ICHECK_EQ(A->device, C->device); - id dev = entry_ptr->metal_api->GetDevice(A->device); - id queue = entry_ptr->metal_api->GetCommandQueue(A->device); - id cb = [queue commandBuffer]; - NSUInteger M = A->shape[0 + (transa ? 1 : 0)]; - NSUInteger N = B->shape[1 - (transb ? 1 : 0)]; - NSUInteger K = B->shape[0 + (transb ? 1 : 0)]; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.mps.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto B = args[1].cast(); + auto C = args[2].cast(); + bool transa = args[3].cast(); + bool transb = args[4].cast(); + // call gemm for simple compact code. + ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + ICHECK(C->strides == nullptr); + ICHECK(B->strides == nullptr); + ICHECK(A->strides == nullptr); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); + // Get Metal device API + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + // ICHECK_EQ(A->device, B->device); + // ICHECK_EQ(A->device, C->device); + id dev = entry_ptr->metal_api->GetDevice(A->device); + id queue = entry_ptr->metal_api->GetCommandQueue(A->device); + id cb = [queue commandBuffer]; + NSUInteger M = A->shape[0 + (transa ? 1 : 0)]; + NSUInteger N = B->shape[1 - (transb ? 1 : 0)]; + NSUInteger K = B->shape[0 + (transb ? 1 : 0)]; - ICHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K); - // mps a - MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); - MPSMatrixDescriptor* descA = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:K - rowBytes:K * sizeof(MPSDataTypeFloat32) - dataType:MPSDataTypeFloat32]; - id bufA = (__bridge id)(A->data); - MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; - // mps b - MPSMatrixDescriptor* descB = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:K - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; - id bufB = (__bridge id)(B->data); - MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; - // mps c - MPSMatrixDescriptor* descC = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; - id bufC = (__bridge id)(C->data); - MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; - // kernel + ICHECK_EQ(A->shape[1 - (transa ? 
1 : 0)], K); + // mps a + MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); + MPSMatrixDescriptor* descA = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:K + rowBytes:K * sizeof(MPSDataTypeFloat32) + dataType:MPSDataTypeFloat32]; + id bufA = (__bridge id)(A->data); + MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; + // mps b + MPSMatrixDescriptor* descB = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:K + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; + id bufB = (__bridge id)(B->data); + MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; + // mps c + MPSMatrixDescriptor* descC = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; + id bufC = (__bridge id)(C->data); + MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; + // kernel - MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; - MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev - transposeLeft:transa - transposeRight:transb - resultRows:M - resultColumns:N - interiorColumns:K - alpha:1.0f - beta:0.0f]; - ICHECK(sgemm != nil); - [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; - [cb commit]; - }); + MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; + MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev + transposeLeft:transa + transposeRight:transb + resultRows:M + resultColumns:N + interiorColumns:K + alpha:1.0f + beta:0.0f]; + ICHECK(sgemm != nil); + [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; + [cb commit]; + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index 3b21ba0e5dc5..4fd1636f6fe1 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -476,10 +477,12 @@ bool MarvellHardwareModuleNode::use_dpdk_cb = false; ml_tvmc_cb MarvellHardwareModuleNode::tvmc_cb_ = {}; ml_dpdk_cb MarvellHardwareModuleNode::dpdk_cb_ = {}; -TVM_FFI_REGISTER_GLOBAL("runtime.mrvl_hw_runtime_create") - .set_body_typed(MarvellHardwareModuleRuntimeCreate); -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_hw") - .set_body_typed(MarvellHardwareModuleNode::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.mrvl_hw_runtime_create", MarvellHardwareModuleRuntimeCreate) + .def("runtime.module.loadbinary_mrvl_hw", MarvellHardwareModuleNode::LoadFromBinary); +}); } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index 701ae6ed8dcd..ee1f953974a9 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -149,10 +150,12 @@ runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.mrvl_runtime_create") - .set_body_typed(MarvellSimulatorModuleRuntimeCreate); -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_sim") - .set_body_typed(MarvellSimulatorModuleNode::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + 
namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.mrvl_runtime_create", MarvellSimulatorModuleRuntimeCreate) + .def("runtime.module.loadbinary_mrvl_sim", MarvellSimulatorModuleNode::LoadFromBinary); +}); } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 8819cfd2fc4a..99efa2085012 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -348,11 +349,13 @@ runtime::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.msc_tensorrt_runtime_create") - .set_body_typed(MSCTensorRTRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_msc_tensorrt") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.msc_tensorrt_runtime_create", MSCTensorRTRuntimeCreate) + .def("runtime.module.loadbinary_msc_tensorrt", + JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mscclpp/allreduce.cu b/src/runtime/contrib/mscclpp/allreduce.cu index 66a6a097f650..2b009c062585 100644 --- a/src/runtime/contrib/mscclpp/allreduce.cu +++ b/src/runtime/contrib/mscclpp/allreduce.cu @@ -17,9 +17,8 @@ * under the License. */ -#include -#include #include +#include #include "msccl.cuh" diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index 0fcf9fded0a8..83faf354b400 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -240,10 +241,12 @@ runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& grap return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.nnapi_runtime_create", NNAPIRuntimeCreate) + .def("runtime.module.loadbinary_nnapi", JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 6dea2281f714..494b4d5e4db1 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "../../cuda/cuda_common.h" @@ -118,14 +119,14 @@ void NVSHMEMXCumoduleInit(void* cuModule) { } } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper") - .set_body_typed(InitNVSHMEMWrapper); - -TVM_FFI_REGISTER_GLOBAL("runtime.nvshmem.cumodule_init").set_body_typed(NVSHMEMXCumoduleInit); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.nvshmem.init_nvshmem_uid", InitNVSHMEMUID) + .def("runtime.disco.nvshmem.init_nvshmem", InitNVSHMEM) + 
.def("runtime.disco.nvshmem.init_nvshmem_wrapper", InitNVSHMEMWrapper) + .def("runtime.nvshmem.cumodule_init", NVSHMEMXCumoduleInit); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu index 2dad73707df7..59bd26b5dab3 100644 --- a/src/runtime/contrib/nvshmem/kv_transfer.cu +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -19,9 +19,10 @@ #include #include #include +#include +#include #include #include -#include template __device__ int64_t calc_flattened_index(int shape[dim], int index[dim]) { @@ -329,5 +330,9 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages, return 0; } -TVM_FFI_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); -TVM_FFI_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("nvshmem.KVTransfer", _KVTransfer) + .def("nvshmem.KVTransferPageToPage", _KVTransferPageToPage); +}); diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index 770db2f90227..f651c67aaa41 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -89,14 +90,20 @@ NDArray NVSHMEMEmpty(ffi::Shape shape, DataType dtype, Device device) { return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.disco.nvshmem.empty", NVSHMEMEmpty); +}); void NVSHMEMFinalize() { NVSHMEMAllocator::Global()->Clear(); nvshmem_finalize(); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.disco.nvshmem.finalize_nvshmem", NVSHMEMFinalize); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index 882cee36b246..670bf9a27682 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include #include @@ -290,10 +291,12 @@ MetricCollector CreatePAPIMetricCollector(Map> metr TVM_REGISTER_OBJECT_TYPE(PAPIEventSetNode); TVM_REGISTER_OBJECT_TYPE(PAPIMetricCollectorNode); -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") - .set_body_typed([](Map> metrics) { - return PAPIMetricCollector(metrics); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "runtime.profiling.PAPIMetricCollector", + [](Map> metrics) { return PAPIMetricCollector(metrics); }); +}); } // namespace profiling } // namespace runtime diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 8f05a7241b02..9c1ba3e32733 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -69,78 +70,79 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { return RandomThreadLocalStore::Get(); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.randint") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); - int64_t low = args[0].cast(); - int64_t high = args[1].cast(); - auto out = args[2].cast(); - ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(out->strides == nullptr); - - DLDataType dtype = out->dtype; - int64_t size = 1; - for (int i = 0; i < out->ndim; ++i) { - size *= out->shape[i]; - } - - DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { - int64_t numeric_low = std::numeric_limits::min(); - int64_t numeric_high = std::numeric_limits::max(); - numeric_high += 1; // exclusive upper bound - low = std::max(low, numeric_low); - high = std::min(high, numeric_high); - - if (out->device.device_type == kDLCPU) { - // file the data with random byte - std::generate_n(static_cast(out->data), size, [&]() { - unsigned rint = entry->random_engine.GetRandInt(); - return low + rint % (high - low); - }); - } else { - LOG(FATAL) << "Do not support random.randint on this device yet"; - } - }) - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.uniform") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); - double low = args[0].cast(); - double high = args[1].cast(); - auto out = args[2].cast(); - entry->random_engine.SampleUniform(out, low, high); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.normal") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); - double loc = args[0].cast(); - double scale = args[1].cast(); - auto out = args[2].cast(); - entry->random_engine.SampleNormal(out, loc, scale); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.random_fill") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); - auto out = args[0].cast(); - entry->random_engine.RandomFill(out); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) -> void { - const auto curand = tvm::ffi::Function::GetGlobal("runtime.contrib.curand.RandomFill"); - auto out = args[0].cast(); - if (curand.has_value() && out->device.device_type == DLDeviceType::kDLCUDA) { - if (out->dtype.code == DLDataTypeCode::kDLFloat) { - (*curand)(out); - return; - } - } - RandomThreadLocalEntry* entry = 
RandomThreadLocalEntry::ThreadLocal(); - entry->random_engine.RandomFillForMeasure(out); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.contrib.random.randint", + [](ffi::PackedArgs args, ffi::Any* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + int64_t low = args[0].cast<int64_t>(); + int64_t high = args[1].cast<int64_t>(); + auto out = args[2].cast<DLTensor*>(); + ICHECK_GT(high, low) << "high must be bigger than low"; + ICHECK(out->strides == nullptr); + + DLDataType dtype = out->dtype; + int64_t size = 1; + for (int i = 0; i < out->ndim; ++i) { + size *= out->shape[i]; + } + + DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { + int64_t numeric_low = std::numeric_limits<DType>::min(); + int64_t numeric_high = std::numeric_limits<DType>::max(); + numeric_high += 1; // exclusive upper bound + low = std::max(low, numeric_low); + high = std::min(high, numeric_high); + + if (out->device.device_type == kDLCPU) { + // fill the data with random bytes + std::generate_n(static_cast<DType*>(out->data), size, [&]() { + unsigned rint = entry->random_engine.GetRandInt(); + return low + rint % (high - low); + }); + } else { + LOG(FATAL) << "Do not support random.randint on this device yet"; + } + }) + }) + .def_packed("tvm.contrib.random.uniform", + [](ffi::PackedArgs args, ffi::Any* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + double low = args[0].cast<double>(); + double high = args[1].cast<double>(); + auto out = args[2].cast<DLTensor*>(); + entry->random_engine.SampleUniform(out, low, high); + }) + .def_packed("tvm.contrib.random.normal", + [](ffi::PackedArgs args, ffi::Any* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + double loc = args[0].cast<double>(); + double scale = args[1].cast<double>(); + auto out = args[2].cast<DLTensor*>(); + entry->random_engine.SampleNormal(out, loc, scale); + }) + .def_packed("tvm.contrib.random.random_fill", + [](ffi::PackedArgs args, ffi::Any* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + auto out = args[0].cast<DLTensor*>(); + entry->random_engine.RandomFill(out); + }) + .def_packed("tvm.contrib.random.random_fill_for_measure", + [](ffi::PackedArgs args, ffi::Any* ret) -> void { + const auto curand = + tvm::ffi::Function::GetGlobal("runtime.contrib.curand.RandomFill"); + auto out = args[0].cast<DLTensor*>(); + if (curand.has_value() && out->device.device_type == DLDeviceType::kDLCUDA) { + if (out->dtype.code == DLDataTypeCode::kDLFloat) { + (*curand)(out); + return; + } + } + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + entry->random_engine.RandomFillForMeasure(out); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 2969d7fd0e5e..42fe8425d1df 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -65,78 +66,85 @@ struct RocBlasThreadEntry { typedef dmlc::ThreadLocalStore<RocBlasThreadEntry> RocBlasThreadStore; // matrix multiplication for row major -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast<DLTensor*>(); - auto B = args[1].cast<DLTensor*>(); - auto C = args[2].cast<DLTensor*>(); - bool transa = args[3].cast<bool>(); - bool transb = args[4].cast<bool>(); - // call gemm for simple compact code.
- ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); - - float alpha = 1.0; - float beta = 0.0; - float* A_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); - float* B_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); - float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); - - rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none; - rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none; - size_t N = transb ? B->shape[0] : B->shape[1]; - size_t M = transa ? A->shape[1] : A->shape[0]; - size_t K = transb ? B->shape[1] : B->shape[0]; - size_t lda = transa ? M : K; - size_t ldb = transb ? K : N; - size_t ldc = N; - - CHECK_ROCBLAS_ERROR(rocblas_sgemm(RocBlasThreadStore::Get()->handle, roc_trans_B, roc_trans_A, - N, M, K, &alpha, B_ptr, ldb, A_ptr, lda, &beta, C_ptr, - ldc)); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - auto B = args[1].cast(); - auto C = args[2].cast(); - bool transa = args[3].cast(); - bool transb = args[4].cast(); - // call gemm for simple compact code. - ICHECK_EQ(A->ndim, 3); - ICHECK_EQ(B->ndim, 3); - ICHECK_EQ(C->ndim, 3); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); - - float alpha = 1.0; - float beta = 0.0; - float* A_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); - float* B_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); - float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); - - rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none; - rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none; - size_t batch_size = C->shape[0]; - size_t N = transb ? B->shape[1] : B->shape[2]; - size_t M = transa ? A->shape[2] : A->shape[1]; - size_t K = transb ? B->shape[2] : B->shape[1]; - size_t lda = transa ? M : K; - size_t ldb = transb ? K : N; - size_t ldc = N; - - CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched( - RocBlasThreadStore::Get()->handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, - K * N, A_ptr, lda, M * K, &beta, C_ptr, ldc, M * N, batch_size)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed( + "tvm.contrib.rocblas.matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto B = args[1].cast(); + auto C = args[2].cast(); + bool transa = args[3].cast(); + bool transb = args[4].cast(); + // call gemm for simple compact code. 
+ ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + ICHECK(C->strides == nullptr); + ICHECK(B->strides == nullptr); + ICHECK(A->strides == nullptr); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); + + float alpha = 1.0; + float beta = 0.0; + float* A_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* B_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); + float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + rocblas_operation roc_trans_A = + transa ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation roc_trans_B = + transb ? rocblas_operation_transpose : rocblas_operation_none; + size_t N = transb ? B->shape[0] : B->shape[1]; + size_t M = transa ? A->shape[1] : A->shape[0]; + size_t K = transb ? B->shape[1] : B->shape[0]; + size_t lda = transa ? M : K; + size_t ldb = transb ? K : N; + size_t ldc = N; + + CHECK_ROCBLAS_ERROR(rocblas_sgemm(RocBlasThreadStore::Get()->handle, roc_trans_B, + roc_trans_A, N, M, K, &alpha, B_ptr, ldb, A_ptr, lda, + &beta, C_ptr, ldc)); + }) + .def_packed("tvm.contrib.rocblas.batch_matmul", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto B = args[1].cast(); + auto C = args[2].cast(); + bool transa = args[3].cast(); + bool transb = args[4].cast(); + // call gemm for simple compact code. + ICHECK_EQ(A->ndim, 3); + ICHECK_EQ(B->ndim, 3); + ICHECK_EQ(C->ndim, 3); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); + + float alpha = 1.0; + float beta = 0.0; + float* A_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* B_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); + float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + rocblas_operation roc_trans_A = + transa ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation roc_trans_B = + transb ? rocblas_operation_transpose : rocblas_operation_none; + size_t batch_size = C->shape[0]; + size_t N = transb ? B->shape[1] : B->shape[2]; + size_t M = transa ? A->shape[2] : A->shape[1]; + size_t K = transb ? B->shape[2] : B->shape[1]; + size_t lda = transa ? M : K; + size_t ldb = transb ? K : N; + size_t ldc = N; + + CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched( + RocBlasThreadStore::Get()->handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, + ldb, K * N, A_ptr, lda, M * K, &beta, C_ptr, ldc, M * N, batch_size)); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 62639e684055..54b17b9b0ea5 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -79,81 +80,84 @@ struct float16 { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). 
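// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the registration rewrite in isolation.
// Every hunk in this diff applies the same mechanical change, so a minimal,
// self-contained example may make the sort.cc hunks below easier to review.
// The function MyAdd, the global names "example.my_add" / "example.my_add_packed",
// and the two include paths are assumptions for illustration only; the macros and
// the tvm::ffi::reflection::GlobalDef API are taken from the surrounding diff.
#include <tvm/ffi/function.h>             // assumed location of ffi::Function/PackedArgs/Any
#include <tvm/ffi/reflection/registry.h>  // assumed header behind the "+#include" lines above

namespace {
int64_t MyAdd(int64_t a, int64_t b) { return a + b; }
}  // namespace

// Old style removed throughout this patch:
//   TVM_FFI_REGISTER_GLOBAL("example.my_add").set_body_typed(MyAdd);
// New style introduced by this patch: group registrations in one static-init block,
// using .def for typed functions and .def_packed for PackedArgs-style functions.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("example.my_add", MyAdd)
      .def_packed("example.my_add_packed", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
        // Packed variant: arguments arrive type-erased and are cast explicitly.
        *ret = args[0].cast<int64_t>() + args[1].cast<int64_t>();
      });
});
// The sort.cc hunks that follow additionally wrap each registration in a plain
// RegisterXxx() helper and invoke all helpers from a single TVM_FFI_STATIC_INIT_BLOCK.
// ---------------------------------------------------------------------------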
-TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto input = args[0].cast(); - auto sort_num = args[1].cast(); - auto output = args[2].cast(); - int32_t axis = args[3].cast(); - bool is_ascend = args[4].cast(); - - auto dtype = input->dtype; - auto data_ptr = static_cast(input->data); - auto sort_num_ptr = static_cast(sort_num->data); - std::vector> sorter; - int64_t axis_mul_before = 1; - int64_t axis_mul_after = 1; - - if (axis < 0) { - axis = input->ndim + axis; - } +void RegisterArgsortNMS() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.sort.argsort_nms", [](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto sort_num = args[1].cast(); + auto output = args[2].cast(); + int32_t axis = args[3].cast(); + bool is_ascend = args[4].cast(); + + auto dtype = input->dtype; + auto data_ptr = static_cast(input->data); + auto sort_num_ptr = static_cast(sort_num->data); + std::vector> sorter; + int64_t axis_mul_before = 1; + int64_t axis_mul_after = 1; + + if (axis < 0) { + axis = input->ndim + axis; + } - // Currently only supports input dtype to be float32. - ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " - "to be float."; + // Currently only supports input dtype to be float32. + ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " + "to be float."; #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) - ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " - "to be float32."; + ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " + "to be float32."; #endif - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; - - for (int i = 0; i < input->ndim; ++i) { - if (i < axis) { - axis_mul_before *= input->shape[i]; - } else if (i > axis) { - axis_mul_after *= input->shape[i]; + ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; + + for (int i = 0; i < input->ndim; ++i) { + if (i < axis) { + axis_mul_before *= input->shape[i]; + } else if (i > axis) { + axis_mul_after *= input->shape[i]; + } } - } - for (int64_t i = 0; i < axis_mul_before; ++i) { - for (int64_t j = 0; j < axis_mul_after; ++j) { - sorter.clear(); - int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j); - int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; - for (int64_t k = 0; k < current_sort_num; ++k) { - int64_t full_idx = base_idx + k * axis_mul_after; - sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); - } - if (is_ascend) { + for (int64_t i = 0; i < axis_mul_before; ++i) { + for (int64_t j = 0; j < axis_mul_after; ++j) { + sorter.clear(); + int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j); + int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; + for (int64_t k = 0; k < current_sort_num; ++k) { + int64_t full_idx = base_idx + k * axis_mul_after; + sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); + } + if (is_ascend) { #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - if (dtype.bits == 16) { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); - } else { + if (dtype.bits == 16) { + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); + } else { #endif - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - } + } 
#endif - } else { -#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - if (dtype.bits == 16) { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); } else { +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) + if (dtype.bits == 16) { + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); + } else { #endif - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - } + } #endif - } - for (int32_t k = 0; k < input->shape[axis]; ++k) { - *(static_cast(output->data) + base_idx + k * axis_mul_after) = - k < static_cast(sorter.size()) ? sorter[k].first : k; + } + for (int32_t k = 0; k < input->shape[axis]; ++k) { + *(static_cast(output->data) + base_idx + k * axis_mul_after) = + k < static_cast(sorter.size()) ? sorter[k].first : k; + } } } - } - }); + }); +} template void sort_impl( @@ -218,94 +222,96 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto input = args[0].cast(); - auto output = args[1].cast(); - int32_t axis = args[2].cast(); - bool is_ascend = args[3].cast(); - if (axis < 0) { - axis = input->ndim + axis; +void RegisterArgsort() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.sort.argsort", [](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto output = args[1].cast(); + int32_t axis = args[2].cast(); + bool is_ascend = args[3].cast(); + if (axis < 0) { + axis = input->ndim + axis; + } + ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; + + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = ffi::DLDataTypeToString(output->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; - - auto data_dtype = ffi::DLDataTypeToString(input->dtype); - auto out_dtype = ffi::DLDataTypeToString(output->dtype); - - if (data_dtype == "float32") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output 
dtype: " << out_dtype; - } - } else if (data_dtype == "float64") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - } else if (data_dtype == "float16") { - if (out_dtype == "float16") { - argsort<__fp16, __fp16>(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } + } else if (data_dtype == "float16") { + if (out_dtype == "float16") { + argsort<__fp16, __fp16>(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } #endif - } else if (data_dtype == "int32") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int64") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float16") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - }); + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << 
out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } + }); +} // Sort implemented C library sort. // Return sorted tensor. @@ -314,42 +320,44 @@ TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort") // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.sort") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto input = args[0].cast(); - auto output = args[1].cast(); - int32_t axis = args[2].cast(); - bool is_ascend = args[3].cast(); - if (axis < 0) { - axis = input->ndim + axis; - } - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; +void RegisterSort() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.sort.sort", [](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto output = args[1].cast(); + int32_t axis = args[2].cast(); + bool is_ascend = args[3].cast(); + if (axis < 0) { + axis = input->ndim + axis; + } + ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); + auto data_dtype = DLDataTypeToString(input->dtype); + auto out_dtype = DLDataTypeToString(output->dtype); - ICHECK_EQ(data_dtype, out_dtype); + ICHECK_EQ(data_dtype, out_dtype); - if (data_dtype == "float32") { - sort(input, output, axis, is_ascend); - } else if (data_dtype == "float64") { - sort(input, output, axis, is_ascend); + if (data_dtype == "float32") { + sort(input, output, axis, is_ascend); + } else if (data_dtype == "float64") { + sort(input, output, axis, is_ascend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - } else if (data_dtype == "float16") { - sort<__fp16>(input, output, axis, is_ascend); + } else if (data_dtype == "float16") { + sort<__fp16>(input, output, axis, is_ascend); #endif - } else if (data_dtype == "int32") { - sort(input, output, axis, is_ascend); - } else if (data_dtype == "int64") { - sort(input, output, axis, is_ascend); - } else if (data_dtype == "float16") { - sort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; - } - }); + } else if (data_dtype == "int32") { + sort(input, output, axis, is_ascend); + } else if (data_dtype == "int64") { + sort(input, output, axis, is_ascend); + } else if (data_dtype == "float16") { + sort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } + }); +} template void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, @@ -444,127 +452,136 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). 
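// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): calling a registered global after the change.
// Lookup and invocation are untouched by this migration; callers still go through
// tvm::ffi::Function::GetGlobal, exactly as the random_fill_for_measure hunk above does
// for "runtime.contrib.curand.RandomFill". CallTopkSketch and its concrete arguments are
// hypothetical; the argument order mirrors the packed "tvm.contrib.sort.topk" body in the
// hunk below (input, outputs..., k, axis, ret_type, is_ascend), assuming ret_type "both".
#include <dlpack/dlpack.h>
#include <tvm/ffi/function.h>  // assumed location of tvm::ffi::Function

void CallTopkSketch(DLTensor* input, DLTensor* values_out, DLTensor* indices_out) {
  const auto topk = tvm::ffi::Function::GetGlobal("tvm.contrib.sort.topk");
  if (!topk.has_value()) {
    return;  // kernel not compiled in; a real caller would report an error here
  }
  // k = 5, axis = -1 (innermost), return both values and indices, descending order.
  (*topk)(input, values_out, indices_out, 5, -1, "both", false);
}
// ---------------------------------------------------------------------------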
-TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.topk") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto input = args[0].cast(); - DLTensor* values_out = nullptr; - DLTensor* indices_out = nullptr; - int k = args[args.size() - 4].cast(); - int axis = args[args.size() - 3].cast(); - std::string ret_type = args[args.size() - 2].cast(); - bool is_ascend = args[args.size() - 1].cast(); - if (ret_type == "both") { - values_out = args[1].cast(); - indices_out = args[2].cast(); - } else if (ret_type == "values") { - values_out = args[1].cast(); - } else if (ret_type == "indices") { - indices_out = args[1].cast(); +void RegisterTopk() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.sort.topk", [](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + DLTensor* values_out = nullptr; + DLTensor* indices_out = nullptr; + int k = args[args.size() - 4].cast(); + int axis = args[args.size() - 3].cast(); + std::string ret_type = args[args.size() - 2].cast(); + bool is_ascend = args[args.size() - 1].cast(); + if (ret_type == "both") { + values_out = args[1].cast(); + indices_out = args[2].cast(); + } else if (ret_type == "values") { + values_out = args[1].cast(); + } else if (ret_type == "indices") { + indices_out = args[1].cast(); + } else { + LOG(FATAL) << "Unsupported ret type: " << ret_type; + } + if (axis < 0) { + axis = input->ndim + axis; + } + ICHECK(axis >= 0 && axis < input->ndim) + << "Axis out of boundary for input ndim " << input->ndim; + + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = + (indices_out == nullptr) ? "int64" : ffi::DLDataTypeToString(indices_out->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported ret type: " << ret_type; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - if (axis < 0) { - axis = input->ndim + axis; + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - ICHECK(axis >= 0 && axis < input->ndim) - << "Axis out of boundary for input ndim " << input->ndim; - - auto data_dtype = ffi::DLDataTypeToString(input->dtype); - auto out_dtype = - (indices_out == nullptr) ? 
"int64" : ffi::DLDataTypeToString(indices_out->dtype); - - if (data_dtype == "float32") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float64") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "uint8") { - if (out_dtype == "uint8") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int8") { - if (out_dtype == "int8") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int32") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int64") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float16") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); 
- } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } + } else if (data_dtype == "uint8") { + if (out_dtype == "uint8") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int8") { + if (out_dtype == "int8") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - }); + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } + }); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + RegisterArgsortNMS(); + RegisterArgsort(); + RegisterSort(); + RegisterTopk(); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index a8bd43127258..07ce7a016c47 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -524,10 +525,12 @@ runtime::Module TensorRTRuntimeCreate(const 
String& symbol_name, const String& g return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.tensorrt_runtime_create", TensorRTRuntimeCreate) + .def("runtime.module.loadbinary_tensorrt", JSONRuntimeBase::LoadFromBinary); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 74cfcad3e650..8df356b04716 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -183,11 +184,14 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_FFI_REGISTER_GLOBAL("tvm.tflite_runtime.create") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.tflite_runtime.create", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); + }) + .def("target.runtime.tflite", TFLiteRuntimeCreate); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 6b6b9df834ab..94651c37f136 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -233,25 +234,27 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK_GE(args.size(), 4); - auto input = args[0].cast(); - auto values_out = args[1].cast(); - auto indices_out = args[2].cast(); - bool is_ascend = args[3].cast(); - DLTensor* workspace = nullptr; - if (args.size() == 5) { - workspace = args[4].cast(); - } - - auto data_dtype = ffi::DLDataTypeToString(input->dtype); - auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype); - - int n_values = input->shape[input->ndim - 1]; - thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, - workspace); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.contrib.thrust.sort", [](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK_GE(args.size(), 4); + auto input = args[0].cast(); + auto values_out = args[1].cast(); + auto indices_out = args[2].cast(); + bool is_ascend = args[3].cast(); + DLTensor* workspace = nullptr; + if (args.size() == 5) { + workspace = args[4].cast(); + } + + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype); + + int n_values = input->shape[input->ndim - 1]; + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, + workspace); + }); +}); template void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, @@ -280,65 +283,68 @@ 
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK_GE(args.size(), 5); - auto keys_in = args[0].cast(); - auto values_in = args[1].cast(); - auto keys_out = args[2].cast(); - auto values_out = args[3].cast(); - bool for_scatter = args[4].cast(); - DLTensor* workspace = nullptr; - if (args.size() == 6) { - workspace = args[5].cast(); - } - - auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype); - auto value_dtype = ffi::DLDataTypeToString(values_in->dtype); - - if (key_dtype == "int32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter, - workspace); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter, workspace); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter, workspace); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.thrust.stable_sort_by_key", [](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK_GE(args.size(), 5); + auto keys_in = args[0].cast(); + auto values_in = args[1].cast(); + auto keys_out = args[2].cast(); + auto values_out = args[3].cast(); + bool for_scatter = args[4].cast(); + DLTensor* workspace = nullptr; + if (args.size() == 6) { + workspace = args[5].cast(); } - } else if (key_dtype == "int64") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter, workspace); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter, workspace); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter, workspace); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else if (key_dtype == "float32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + + auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype); + auto value_dtype = ffi::DLDataTypeToString(values_in->dtype); + + if (key_dtype == "int32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter, workspace); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter, workspace); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter, workspace); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "int64") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter, workspace); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter, workspace); + } else if (value_dtype == "float32") 
{ + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter, workspace); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "float32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter, workspace); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter, workspace); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter, workspace); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + LOG(FATAL) << "Unsupported key dtype: " << key_dtype; } - } else { - LOG(FATAL) << "Unsupported key dtype: " << key_dtype; - } - }); + }); +}); template void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* workspace) { @@ -395,83 +401,86 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor } } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4); - auto data = args[0].cast(); - auto output = args[1].cast(); - bool exclusive = false; - DLTensor* workspace = nullptr; - - if (args.size() >= 3) { - exclusive = args[2].cast(); - } - - if (args.size() == 4) { - workspace = args[3].cast(); - } - - auto in_dtype = ffi::DLDataTypeToString(data->dtype); - auto out_dtype = ffi::DLDataTypeToString(output->dtype); - - if (in_dtype == "bool") { - if (out_dtype == "int32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "int64") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32, and float64"; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.thrust.sum_scan", [](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4); + auto data = args[0].cast(); + auto output = args[1].cast(); + bool exclusive = false; + DLTensor* workspace = nullptr; + + if (args.size() >= 3) { + exclusive = args[2].cast(); } - } else if (in_dtype == "int32") { - if (out_dtype == "int32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "int64") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32, and float64"; - } - } else if (in_dtype == "int64") { - if (out_dtype == "int64") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". 
Supported output dtypes are int64, float32, and float64"; - } - } else if (in_dtype == "float32") { - if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are float32, and float64"; + + if (args.size() == 4) { + workspace = args[3].cast(); } - } else if (in_dtype == "float64") { - if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); + + auto in_dtype = ffi::DLDataTypeToString(data->dtype); + auto out_dtype = ffi::DLDataTypeToString(output->dtype); + + if (in_dtype == "bool") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; + } + } else if (in_dtype == "int32") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; + } + } else if (in_dtype == "int64") { + if (out_dtype == "int64") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int64, float32, and float64"; + } + } else if (in_dtype == "float32") { + if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are float32, and float64"; + } + } else if (in_dtype == "float64") { + if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtype is float64"; + } } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtype is float64"; + LOG(FATAL) << "Unsupported input dtype: " << in_dtype + << ". Supported input dtypes are bool, int32, int64, float32, and float64"; } - } else { - LOG(FATAL) << "Unsupported input dtype: " << in_dtype - << ". 
Supported input dtypes are bool, int32, int64, float32, and float64"; - } - }); + }); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index 9221f4672511..5e0fe5f24e8d 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -18,9 +18,9 @@ */ #include -#include -#include #include +#include +#include #include #include @@ -735,35 +735,42 @@ void single_query_cached_kv_attention_v2( } } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") - .set_body_typed([](const DLTensor* query, const DLTensor* key_cache, - const DLTensor* value_cache, const DLTensor* block_tables, - const DLTensor* context_lens, int block_size, - const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer - DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor* out) { - int num_seqs = query->shape[0]; - int num_heads = query->shape[1]; - int max_context_len = static_cast(max_context_len_tensor->data)[0]; - const int PARTITION_SIZE = 512; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - bool use_v1 = - max_context_len <= 8192 && (max_num_partitions == 1 || num_seqs * num_heads > 512); - if (use_v1) { - single_query_cached_kv_attention_v1(query, key_cache, value_cache, block_tables, - context_lens, block_size, max_context_len_tensor, out); - } else { - single_query_cached_kv_attention_v2(query, key_cache, value_cache, block_tables, - context_lens, block_size, max_context_len_tensor, - exp_sums, max_logits, tmp_out, out); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [](const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache, + const DLTensor* block_tables, const DLTensor* context_lens, int block_size, + const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer + DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor* out) { + int num_seqs = query->shape[0]; + int num_heads = query->shape[1]; + int max_context_len = static_cast(max_context_len_tensor->data)[0]; + const int PARTITION_SIZE = 512; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + bool use_v1 = + max_context_len <= 8192 && (max_num_partitions == 1 || num_seqs * num_heads > 512); + if (use_v1) { + single_query_cached_kv_attention_v1(query, key_cache, value_cache, block_tables, + context_lens, block_size, max_context_len_tensor, + out); + } else { + single_query_cached_kv_attention_v2(query, key_cache, value_cache, block_tables, + context_lens, block_size, max_context_len_tensor, + exp_sums, max_logits, tmp_out, out); + } + }); +}); // Expose for testing -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1") - .set_body_typed(single_query_cached_kv_attention_v1); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2") - .set_body_typed(single_query_cached_kv_attention_v2); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tvm.contrib.vllm.single_query_cached_kv_attention_v1", + single_query_cached_kv_attention_v1) + .def("tvm.contrib.vllm.single_query_cached_kv_attention_v2", + single_query_cached_kv_attention_v2); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/vllm/cache_alloc.cc 
b/src/runtime/contrib/vllm/cache_alloc.cc index dd2b7bd5bb37..042224d54874 100644 --- a/src/runtime/contrib/vllm/cache_alloc.cc +++ b/src/runtime/contrib/vllm/cache_alloc.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include namespace tvm { @@ -48,7 +49,10 @@ Array AllocateKVCache(int head_size, int num_layers, int num_heads, int return cache; } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tvm.contrib.vllm.allocate_kv_cache", AllocateKVCache); +}); } // namespace vllm } // namespace runtime diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index 01320daac650..13f21971e846 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -16,9 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include #include +#include +#include #include #include @@ -130,105 +130,107 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, int64_t* value_cache namespace tvm { namespace runtime { -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") - .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache, - NDArray slot_mapping) { - int num_tokens = key->shape[0]; - int num_heads = key->shape[1]; - int head_size = key->shape[2]; - int block_size = key_cache->shape[3]; - int vec_size = key_cache->shape[4]; - - int key_stride = key->shape[1] * key->shape[2]; - int value_stride = value->shape[1] * value->shape[2]; - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - - using scalar_t = uint16_t; - vllm::reshape_and_cache_kernel<<>>( - static_cast(key->data), static_cast(value->data), - static_cast(key_cache->data), static_cast(value_cache->data), - static_cast(slot_mapping->data), key_stride, value_stride, num_heads, - head_size, block_size, vec_size); - - return Array{key_cache, value_cache}; - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") - .set_body_typed([](NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { - int num_tokens = slot_mapping->shape[0]; - int num_heads = value_cache->shape[1]; - int head_size = value_cache->shape[2]; - int block_size = value_cache->shape[3]; - int vec_size = key_cache->shape[4]; - - DLDevice dev = key_cache->device; - auto key = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); - auto value = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); - - int key_stride = key->shape[1] * key->shape[2]; - int value_stride = value->shape[1] * value->shape[2]; - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - - using scalar_t = uint16_t; - vllm::reconstruct_from_cache_kernel - <<>>(static_cast(key_cache->data), - static_cast(value_cache->data), - static_cast(slot_mapping->data), - static_cast(key->data), static_cast(value->data), - key_stride, value_stride, num_heads, head_size, block_size, vec_size); - - return Array{key, value}; - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks") - .set_body_typed([](Array key_value_caches, NDArray block_mapping) { - auto num_layers = key_value_caches.size() / 2; - auto num_pairs = block_mapping->shape[0] / 2; - - if (num_layers == 0) { - return; - } - - std::vector key_cache_ptrs(num_layers); - std::vector value_cache_ptrs(num_layers); - for (size_t layer_idx 
= 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = - reinterpret_cast(key_value_caches[2 * layer_idx]->data); - value_cache_ptrs[layer_idx] = - reinterpret_cast(key_value_caches[2 * layer_idx + 1]->data); - } - - NDArray key_cache = key_value_caches[1]; // [num_blocks, num_heads, head_size, block_size] - DLDevice dev = key_cache->device; - - NDArray key_cache_ptrs_gpu = - NDArray::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); - NDArray value_cache_ptrs_gpu = - NDArray::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); - key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(), - sizeof(int64_t) * key_cache_ptrs.size()); - value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(), - sizeof(int64_t) * value_cache_ptrs.size()); - - NDArray block_mapping_gpu = - NDArray::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev); - block_mapping_gpu.CopyFromBytes(block_mapping->data, - sizeof(int64_t) * block_mapping->shape[0]); - - const int numel_per_block = key_cache->shape[1] * key_cache->shape[2] * key_cache->shape[3]; - dim3 grid(num_layers, num_pairs); - dim3 block(std::min(1024, numel_per_block)); - - using scalar_t = uint16_t; - vllm::copy_blocks_kernel - <<>>(static_cast(key_cache_ptrs_gpu->data), - static_cast(value_cache_ptrs_gpu->data), - static_cast(block_mapping_gpu->data), numel_per_block); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tvm.contrib.vllm.reshape_and_cache", + [](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache, + NDArray slot_mapping) { + int num_tokens = key->shape[0]; + int num_heads = key->shape[1]; + int head_size = key->shape[2]; + int block_size = key_cache->shape[3]; + int vec_size = key_cache->shape[4]; + + int key_stride = key->shape[1] * key->shape[2]; + int value_stride = value->shape[1] * value->shape[2]; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + + using scalar_t = uint16_t; + vllm::reshape_and_cache_kernel<<>>( + static_cast(key->data), static_cast(value->data), + static_cast(key_cache->data), static_cast(value_cache->data), + static_cast(slot_mapping->data), key_stride, value_stride, num_heads, + head_size, block_size, vec_size); + + return Array{key_cache, value_cache}; + }) + .def("tvm.contrib.vllm.reconstruct_from_cache", + [](NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { + int num_tokens = slot_mapping->shape[0]; + int num_heads = value_cache->shape[1]; + int head_size = value_cache->shape[2]; + int block_size = value_cache->shape[3]; + int vec_size = key_cache->shape[4]; + + DLDevice dev = key_cache->device; + auto key = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); + auto value = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); + + int key_stride = key->shape[1] * key->shape[2]; + int value_stride = value->shape[1] * value->shape[2]; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + + using scalar_t = uint16_t; + vllm::reconstruct_from_cache_kernel<<>>( + static_cast(key_cache->data), + static_cast(value_cache->data), + static_cast(slot_mapping->data), static_cast(key->data), + static_cast(value->data), key_stride, value_stride, num_heads, + head_size, block_size, vec_size); + + return Array{key, value}; + }) + .def("tvm.contrib.vllm.copy_blocks", [](Array key_value_caches, + NDArray block_mapping) { + auto num_layers = key_value_caches.size() / 2; + auto num_pairs = 
block_mapping->shape[0] / 2; + + if (num_layers == 0) { + return; + } + + std::vector key_cache_ptrs(num_layers); + std::vector value_cache_ptrs(num_layers); + for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_value_caches[2 * layer_idx]->data); + value_cache_ptrs[layer_idx] = + reinterpret_cast(key_value_caches[2 * layer_idx + 1]->data); + } + + NDArray key_cache = key_value_caches[1]; // [num_blocks, num_heads, head_size, block_size] + DLDevice dev = key_cache->device; + + NDArray key_cache_ptrs_gpu = + NDArray::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + NDArray value_cache_ptrs_gpu = + NDArray::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(), + sizeof(int64_t) * key_cache_ptrs.size()); + value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(), + sizeof(int64_t) * value_cache_ptrs.size()); + + NDArray block_mapping_gpu = + NDArray::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev); + block_mapping_gpu.CopyFromBytes(block_mapping->data, + sizeof(int64_t) * block_mapping->shape[0]); + + const int numel_per_block = key_cache->shape[1] * key_cache->shape[2] * key_cache->shape[3]; + dim3 grid(num_layers, num_pairs); + dim3 block(std::min(1024, numel_per_block)); + + using scalar_t = uint16_t; + vllm::copy_blocks_kernel + <<>>(static_cast(key_cache_ptrs_gpu->data), + static_cast(value_cache_ptrs_gpu->data), + static_cast(block_mapping_gpu->data), numel_per_block); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 68594f0769fe..7f92ac9fa814 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include @@ -150,9 +151,12 @@ void CPUDeviceAPI::FreeWorkspace(Device dev, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } -TVM_FFI_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = CPUDeviceAPI::Global(); - *rv = static_cast(ptr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("device_api.cpu", [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = CPUDeviceAPI::Global(); + *rv = static_cast(ptr); + }); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 91702d035482..b7fc28e5feb5 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -286,17 +287,20 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {} CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_FFI_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global(); - *rv = static_cast(ptr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("device_api.cuda", + [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global(); + *rv = static_cast(ptr); + }) + .def_packed("device_api.cuda_host", [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global(); + *rv = static_cast(ptr); + }); }); -TVM_FFI_REGISTER_GLOBAL("device_api.cuda_host") - 
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global(); - *rv = static_cast(ptr); - }); - class CUDATimerNode : public TimerNode { public: virtual void Start() { @@ -330,8 +334,10 @@ class CUDATimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(CUDATimerNode); -TVM_FFI_REGISTER_GLOBAL("profiling.timer.cuda").set_body_typed([](Device dev) { - return Timer(make_object()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.timer.cuda", + [](Device dev) { return Timer(make_object()); }); }); TVM_DLL String GetCudaFreeMemory() { @@ -343,10 +349,12 @@ TVM_DLL String GetCudaFreeMemory() { return ss.str(); } -TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); - -TVM_FFI_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() { - return static_cast(CUDAThreadEntry::ThreadLocal()->stream); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.GetCudaFreeMemory", GetCudaFreeMemory) + .def("runtime.get_cuda_stream", + []() { return static_cast(CUDAThreadEntry::ThreadLocal()->stream); }); }); TVM_DLL int GetCudaDeviceCount() { @@ -355,7 +363,10 @@ TVM_DLL int GetCudaDeviceCount() { return count; } -TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.GetCudaDeviceCount", GetCudaDeviceCount); +}); #if (CUDA_VERSION >= 12000) /** @@ -381,188 +392,191 @@ TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDevi * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum. * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum. 
*/ -TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - CHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments"; - size_t arg_cnt = 0; - CUtensorMap* tensor_map = static_cast(args[arg_cnt++].cast()); - runtime::DataType tensor_dtype = args[arg_cnt++].cast(); - uint32_t tensor_rank = static_cast(args[arg_cnt++].cast()); - void* tensor_ptr = static_cast(args[arg_cnt++].cast()); - - CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3) - << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments" - << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank - << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank - << "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind" - << ", l2_promotion_kind, oob_fill_kind"; - - std::vector global_shape(tensor_rank); - std::vector global_strides(tensor_rank); - std::vector shared_shape(tensor_rank); - std::vector shared_strides(tensor_rank); - for (size_t i = 0; i < tensor_rank; ++i) { - global_shape[i] = static_cast(args[arg_cnt++].cast()); - } - for (size_t i = 0; i < tensor_rank - 1; ++i) { - global_strides[i] = static_cast(args[arg_cnt++].cast()); - CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16"; - } - for (size_t i = 0; i < tensor_rank; ++i) { - shared_shape[i] = static_cast(args[arg_cnt++].cast()); - CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative"; - CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256"; - } - for (size_t i = 0; i < tensor_rank; ++i) { - shared_strides[i] = static_cast(args[arg_cnt++].cast()); - } - auto interleaved_kind = static_cast(args[arg_cnt++].cast()); - auto swizzle_kind = static_cast(args[arg_cnt++].cast()); - auto l2_promotion_kind = static_cast(args[arg_cnt++].cast()); - auto oob_fill_kind = static_cast(args[arg_cnt++].cast()); - - ICHECK_EQ(tensor_dtype.lanes(), 1) - << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype; - CUtensorMapDataType cu_dtype; - switch (tensor_dtype.code()) { - case DataType::kInt: - // int - switch (tensor_dtype.bits()) { - case 8: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 32: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32; - break; - case 64: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64; - break; - default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); - } - break; - case DataType::kUInt: - // unsigned int - switch (tensor_dtype.bits()) { - case 8: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 16: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 32: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - case 64: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64; - break; - default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); - } - break; - case DataType::kFloat: - // float - switch (tensor_dtype.bits()) { - case 16: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - break; - case 32: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - break; - case 64: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64; - break; - default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); - } - break; - case DataType::kBFloat: - // bfloat - switch (tensor_dtype.bits()) { - case 16: - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - break; - default: - LOG(FATAL) << "Unsupported data type " << 
runtime::DLDataTypeToString(tensor_dtype); - } - break; - case DataType::kFloat8_e4m3fn: - // NV float8 e4m3 - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case DataType::kFloat8_e5m2: - // NV float8 e5m2 - cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); - } - - // sanity checks per cuTensorMapEncodeTiled requirements - // see - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - CHECK_EQ((reinterpret_cast(tensor_ptr) & 0b1111), 0); // 16-byte alignment - CHECK_EQ((reinterpret_cast(tensor_map) & 0b111111), 0); // 64-byte alignment - CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors"; - - if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) { - CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32) - << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32."; - } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) { - CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64) - << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64."; - } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) { - CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128) - << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= " - "128."; - } - - const cuuint64_t* global_shape_ptr = global_shape.data(); - const cuuint64_t* global_strides_ptr = global_strides.data(); - const uint32_t* shared_shape_ptr = shared_shape.data(); - const uint32_t* shared_strides_ptr = shared_strides.data(); - - CUresult res = - cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr, - global_strides_ptr, shared_shape_ptr, shared_strides_ptr, - interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind); - const char* errstr; - cuGetErrorString(res, &errstr); - if (res != CUDA_SUCCESS) { - // get error string - const char* error_string = nullptr; - cuGetErrorString(res, &error_string); - std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << std::endl; - std::cout << "cu_dtype: " << cu_dtype << "\n"; - std::cout << "TMA Desc Addr: " << tensor_map << "\n"; - std::cout << "TMA Interleave: " << interleaved_kind << "\n"; - std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n"; - std::cout << "TMA OOBFill: " << oob_fill_kind << "\n"; - std::cout << "SMEM Swizzle: " << swizzle_kind << "\n"; - std::cout << "tensor rank: " << tensor_rank << "\n"; - std::cout << "global prob shape: "; - for (size_t i = 0; i < tensor_rank; i++) { - std::cout << global_shape[i] << " "; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("runtime.cuTensorMapEncodeTiled", [](ffi::PackedArgs args, + ffi::Any* rv) { + CHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments"; + size_t arg_cnt = 0; + CUtensorMap* tensor_map = static_cast(args[arg_cnt++].cast()); + runtime::DataType tensor_dtype = args[arg_cnt++].cast(); + uint32_t tensor_rank = static_cast(args[arg_cnt++].cast()); + void* tensor_ptr = static_cast(args[arg_cnt++].cast()); + + CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3) + << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments" + << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank + << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank + << "), 
shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind" + << ", l2_promotion_kind, oob_fill_kind"; + + std::vector global_shape(tensor_rank); + std::vector global_strides(tensor_rank); + std::vector shared_shape(tensor_rank); + std::vector shared_strides(tensor_rank); + for (size_t i = 0; i < tensor_rank; ++i) { + global_shape[i] = static_cast(args[arg_cnt++].cast()); + } + for (size_t i = 0; i < tensor_rank - 1; ++i) { + global_strides[i] = static_cast(args[arg_cnt++].cast()); + CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16"; + } + for (size_t i = 0; i < tensor_rank; ++i) { + shared_shape[i] = static_cast(args[arg_cnt++].cast()); + CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative"; + CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256"; + } + for (size_t i = 0; i < tensor_rank; ++i) { + shared_strides[i] = static_cast(args[arg_cnt++].cast()); + } + auto interleaved_kind = static_cast(args[arg_cnt++].cast()); + auto swizzle_kind = static_cast(args[arg_cnt++].cast()); + auto l2_promotion_kind = static_cast(args[arg_cnt++].cast()); + auto oob_fill_kind = static_cast(args[arg_cnt++].cast()); + + ICHECK_EQ(tensor_dtype.lanes(), 1) + << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype; + CUtensorMapDataType cu_dtype; + switch (tensor_dtype.code()) { + case DataType::kInt: + // int + switch (tensor_dtype.bits()) { + case 8: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 32: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32; + break; + case 64: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64; + break; + default: + LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } - std::cout << "\n"; - std::cout << "global prob stride: "; - for (size_t i = 0; i < tensor_rank; i++) { - std::cout << global_strides[i] << " "; + break; + case DataType::kUInt: + // unsigned int + switch (tensor_dtype.bits()) { + case 8: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 16: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 32: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + case 64: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64; + break; + default: + LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } - std::cout << "\n"; - std::cout << "smem box shape: "; - for (size_t i = 0; i < tensor_rank; i++) { - std::cout << shared_shape[i] << " "; + break; + case DataType::kFloat: + // float + switch (tensor_dtype.bits()) { + case 16: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + break; + case 32: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + break; + case 64: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64; + break; + default: + LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } - std::cout << "\n"; - std::cout << "smem box stride: "; - for (size_t i = 0; i < tensor_rank; i++) { - std::cout << shared_strides[i] << " "; + break; + case DataType::kBFloat: + // bfloat + switch (tensor_dtype.bits()) { + case 16: + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + break; + default: + LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } - std::cout << "\n"; - CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr; + break; + case DataType::kFloat8_e4m3fn: + // NV float8 e4m3 + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case DataType::kFloat8_e5m2: + // NV float8 e5m2 + cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + default: + 
LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + } + + // sanity checks per cuTensorMapEncodeTiled requirements + // see + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + CHECK_EQ((reinterpret_cast(tensor_ptr) & 0b1111), 0); // 16-byte alignment + CHECK_EQ((reinterpret_cast(tensor_map) & 0b111111), 0); // 64-byte alignment + CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors"; + + if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) { + CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32) + << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32."; + } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) { + CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64) + << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64."; + } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) { + CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128) + << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= " + "128."; + } + + const cuuint64_t* global_shape_ptr = global_shape.data(); + const cuuint64_t* global_strides_ptr = global_strides.data(); + const uint32_t* shared_shape_ptr = shared_shape.data(); + const uint32_t* shared_strides_ptr = shared_strides.data(); + + CUresult res = + cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr, + global_strides_ptr, shared_shape_ptr, shared_strides_ptr, + interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind); + const char* errstr; + cuGetErrorString(res, &errstr); + if (res != CUDA_SUCCESS) { + // get error string + const char* error_string = nullptr; + cuGetErrorString(res, &error_string); + std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << std::endl; + std::cout << "cu_dtype: " << cu_dtype << "\n"; + std::cout << "TMA Desc Addr: " << tensor_map << "\n"; + std::cout << "TMA Interleave: " << interleaved_kind << "\n"; + std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n"; + std::cout << "TMA OOBFill: " << oob_fill_kind << "\n"; + std::cout << "SMEM Swizzle: " << swizzle_kind << "\n"; + std::cout << "tensor rank: " << tensor_rank << "\n"; + std::cout << "global prob shape: "; + for (size_t i = 0; i < tensor_rank; i++) { + std::cout << global_shape[i] << " "; + } + std::cout << "\n"; + std::cout << "global prob stride: "; + for (size_t i = 0; i < tensor_rank; i++) { + std::cout << global_strides[i] << " "; + } + std::cout << "\n"; + std::cout << "smem box shape: "; + for (size_t i = 0; i < tensor_rank; i++) { + std::cout << shared_shape[i] << " "; } - }); + std::cout << "\n"; + std::cout << "smem box stride: "; + for (size_t i = 0; i < tensor_rank; i++) { + std::cout << shared_strides[i] << " "; + } + std::cout << "\n"; + CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr; + } + }); +}); #endif // CUDA_VERSION >= 12000 } // namespace runtime diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 6d69fde5cdba..c2f6adf1a440 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -298,10 +299,12 @@ Module CUDAModuleLoadBinary(void* strm) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); - 
-TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.module.loadfile_cubin", CUDAModuleLoadFile) + .def("runtime.module.loadfile_ptx", CUDAModuleLoadFile) + .def("runtime.module.loadbinary_cuda", CUDAModuleLoadBinary); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 726df80de8bc..67ed6884e173 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -20,6 +20,7 @@ #include #include +#include #include #include "cuda_common.h" @@ -32,12 +33,14 @@ typedef dmlc::ThreadLocalStore L2FlushStore; L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); } -TVM_FFI_REGISTER_GLOBAL("l2_cache_flush_cuda") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; - cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; - L2Flush::ThreadLocal()->Flush(stream); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("l2_cache_flush_cuda", [](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; + cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; + L2Flush::ThreadLocal()->Flush(stream); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc deleted file mode 100644 index 4b22e2649462..000000000000 --- a/src/runtime/debug_compile.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file src/runtime/debug_compile.cc - * \brief File used for debug migration - */ -// #include -#include -#include -#include -#include -#include -#include -#include -#include - -// #include -// #include -// #include -// #include -// #include - -// #include -// #include -// #include - -namespace tvm { -namespace debug { - -using namespace tvm::runtime; - -// TVM_FFI_REGISTER_GLOBAL("tvm.debug.Test").set_body_typed([](PrimExpr value) { -// LOG(INFO) << value; -// return value; -// }); - -} // namespace debug -} // namespace tvm diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 32155408fea4..635b882752d8 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -167,31 +168,31 @@ TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; } void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { } -TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamCreate").set_body_typed([](DLDevice dev) { - return reinterpret_cast(DeviceAPIManager::Get(dev)->CreateStream(dev)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.Device_StreamCreate", + [](DLDevice dev) { + return reinterpret_cast(DeviceAPIManager::Get(dev)->CreateStream(dev)); + }) + .def("runtime.Device_StreamFree", + [](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast(stream)); + }) + .def("runtime.Device_SetStream", + [](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast(stream)); + }) + .def("runtime.Device_StreamSync", + [](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast(stream)); + }) + .def("runtime.Device_StreamSyncFromTo", [](DLDevice dev, int64_t src, int64_t dst) { + DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), + reinterpret_cast(dst)); + }); }); -TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamFree") - .set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast(stream)); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.Device_SetStream") - .set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast(stream)); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamSync") - .set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast(stream)); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamSyncFromTo") - .set_body_typed([](DLDevice dev, int64_t src, int64_t dst) { - DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), - reinterpret_cast(dst)); - }); - // set device api TVM_FFI_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { @@ -202,32 +203,34 @@ TVM_FFI_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) }); // set device api -TVM_FFI_REGISTER_GLOBAL("runtime.GetDeviceAttr") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - - DeviceAttrKind kind = static_cast(args[2].cast()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); - if (api != nullptr) { - api->GetAttr(dev, kind, ret); - } else { - *ret = 0; - } - } else { - 
DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); - } - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.TVMSetStream") - .set_body_typed([](int device_type, int device_id, void* stream) { - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->SetStream(dev, stream); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("runtime.GetDeviceAttr", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + + DeviceAttrKind kind = static_cast(args[2].cast()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); + if (api != nullptr) { + api->GetAttr(dev, kind, ret); + } else { + *ret = 0; + } + } else { + DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); + } + }) + .def("runtime.TVMSetStream", [](int device_type, int device_id, void* stream) { + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + DeviceAPIManager::Get(dev)->SetStream(dev, stream); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 9cd4c5eda4af..7672fee255ca 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -121,55 +122,51 @@ void SyncWorker() { } } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.empty") - .set_body_typed([](ffi::Shape shape, DataType dtype, Optional device, bool worker0_only, - bool in_group) -> Optional { - int worker_id = WorkerId(); - int group_size = - DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; - bool is_worker0 = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); - if (worker0_only && !is_worker0) { - return std::nullopt; - } else { - return DiscoEmptyNDArray(shape, dtype, device); - } - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.allreduce") - .set_body_typed([](NDArray send, ffi::Shape reduce_kind, bool in_group, NDArray recv) { - int kind = IntegerFromShape(reduce_kind); - CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; - AllReduce(send, static_cast(kind), in_group, recv); - }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0") - .set_body_typed(BroadcastFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ffi::Shape { - return ffi::Shape({WorkerId()}); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.load_vm_module", 
LoadVMModule) + .def("runtime.disco.empty", + [](ffi::Shape shape, DataType dtype, Optional device, bool worker0_only, + bool in_group) -> Optional { + int worker_id = WorkerId(); + int group_size = + DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; + bool is_worker0 = + (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + if (worker0_only && !is_worker0) { + return std::nullopt; + } else { + return DiscoEmptyNDArray(shape, dtype, device); + } + }) + .def("runtime.disco.allreduce", + [](NDArray send, ffi::Shape reduce_kind, bool in_group, NDArray recv) { + int kind = IntegerFromShape(reduce_kind); + CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; + AllReduce(send, static_cast(kind), in_group, recv); + }) + .def("runtime.disco.allgather", AllGather) + .def("runtime.disco.broadcast_from_worker0", BroadcastFromWorker0) + .def("runtime.disco.scatter_from_worker0", ScatterFromWorker0) + .def("runtime.disco.gather_to_worker0", GatherToWorker0) + .def("runtime.disco.recv_from_worker0", RecvFromWorker0) + .def("runtime.disco.send_to_next_group", SendToNextGroup) + .def("runtime.disco.recv_from_prev_group", RecvFromPrevGroup) + .def("runtime.disco.send_to_worker", SendToWorker) + .def("runtime.disco.recv_from_worker", RecvFromWorker) + .def("runtime.disco.worker_id", []() -> ffi::Shape { return ffi::Shape({WorkerId()}); }) + .def("runtime.disco.worker_rank", []() -> int64_t { return WorkerId(); }) + .def("runtime.disco.device", + []() -> Device { return DiscoWorker::ThreadLocal()->default_device; }) + .def("runtime.disco.bind_worker_to_cpu_core", [](ffi::Shape cpu_ids) { + int worker_id = WorkerId(); + ICHECK_LT(worker_id, static_cast(cpu_ids.size())); + const auto f_set_thread_affinity = tvm::ffi::Function::GetGlobalRequired( + "tvm.runtime.threading.set_current_thread_affinity"); + f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); + }); }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t { - return WorkerId(); -}); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device { - return DiscoWorker::ThreadLocal()->default_device; -}); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core") - .set_body_typed([](ffi::Shape cpu_ids) { - int worker_id = WorkerId(); - ICHECK_LT(worker_id, static_cast(cpu_ids.size())); - const auto f_set_thread_affinity = tvm::ffi::Function::GetGlobalRequired( - "tvm.runtime.threading.set_current_thread_affinity"); - f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); - }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index 778ecc16e5a2..df6a6ccf01d3 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -212,10 +213,13 @@ memory::Storage IPCAllocStorage(ffi::Shape buffer_shape, DLDataType dtype_hint) return storage; } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.alloc_storage").set_body_typed(IPCAllocStorage); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear") - .set_body_typed([]() { CUDAIPCMemoryAllocator::Global()->Clear(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.cuda_ipc.alloc_storage", IPCAllocStorage) + .def("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear", + []() 
{ CUDAIPCMemoryAllocator::Global()->Clear(); }); +}); /******************** CUDAIPCMemoryObj ********************/ diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index fa7ef040f3ed..ea3b308dbf0a 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -112,7 +113,10 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { ctx->GetDefaultStream()); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.custom_allreduce").set_body_typed(CustomAllReduce); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.disco.cuda_ipc.custom_allreduce", CustomAllReduce); +}); } // namespace cuda_ipc } // namespace nccl diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index 6cd012b64e11..1d88e9aa46b9 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include @@ -294,8 +295,10 @@ void RemoteSocketSessionEntryPoint(const String& server_host, int server_port, proxy.MainLoop(); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession") - .set_body_typed(RemoteSocketSessionEntryPoint); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.disco.RemoteSocketSession", RemoteSocketSessionEntryPoint); +}); Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host, int port) { @@ -303,17 +306,21 @@ Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, c return Session(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers") - .set_body_typed([](int num_nodes, int node_id, int num_groups, int num_workers_per_node) { - LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, " - << num_workers_per_node << " workers per node, and " << num_groups << " groups."; - DiscoWorker* worker = DiscoWorker::ThreadLocal(); - worker->num_groups = num_groups; - worker->worker_id = worker->worker_id + node_id * num_workers_per_node; - worker->num_workers = num_nodes * num_workers_per_node; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.SocketSession", SocketSession) + .def("runtime.disco.socket_session_init_workers", + [](int num_nodes, int node_id, int num_groups, int num_workers_per_node) { + LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, " + << num_workers_per_node << " workers per node, and " << num_groups + << " groups."; + DiscoWorker* worker = DiscoWorker::ThreadLocal(); + worker->num_groups = num_groups; + worker->worker_id = worker->worker_id + node_id * num_workers_per_node; + worker->num_workers = num_nodes * num_workers_per_node; + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index 52e7833601f3..f5fbbd71e7a5 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -22,6 +22,7 @@ #endif #include #include +#include #include #include #include @@ -405,45 +406,46 @@ Array ShardLoaderObj::LoadAllPresharded() const { return params; } 
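For context on the caller side of this registry: globals registered through these blocks are fetched at runtime with tvm::ffi::Function::GetGlobalRequired, as in the thread-affinity lookup in the disco builtin hunk above. The short sketch below illustrates that lookup under stated assumptions; the header path, the global name "demo.add", and the detail that invoking a tvm::ffi::Function yields a tvm::ffi::Any supporting cast<T>() are assumptions for illustration, not taken from this patch.

#include <tvm/ffi/function.h>  // assumed header providing tvm::ffi::Function and tvm::ffi::Any

// Minimal caller-side sketch: fetch a registered global and invoke it.
// "demo.add" is a hypothetical name; GetGlobalRequired aborts if it is absent,
// mirroring the set_current_thread_affinity lookup above.
inline int CallDemoAdd(int a, int b) {
  const auto f = tvm::ffi::Function::GetGlobalRequired("demo.add");
  tvm::ffi::Any result = f(a, b);  // assumed: calling a Function returns an Any
  return result.cast<int>();       // assumed: Any exposes cast<T>()
}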
-TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad") - .set_body_typed([](ObjectRef loader_obj, ffi::Shape weight_index) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->Load(IntegerFromShape(weight_index)); - }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded") - .set_body_typed([](ObjectRef loader_obj, ffi::Shape weight_index) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->LoadPresharded(IntegerFromShape(weight_index)); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll") - .set_body_typed([](ObjectRef loader_obj) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->LoadAll(); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") - .set_body_typed([](ObjectRef loader_obj) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->LoadAllPresharded(); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0") - .set_body_typed([](ObjectRef loader_obj, int param_index) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->LoadParamOnWorker0(param_index); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.ShardLoader", ShardLoaderObj::Create) + .def("runtime.disco.ShardLoaderLoad", + [](ObjectRef loader_obj, ffi::Shape weight_index) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) + << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + return loader->Load(IntegerFromShape(weight_index)); + }) + .def("runtime.disco.ShardLoaderLoadPresharded", + [](ObjectRef loader_obj, ffi::Shape weight_index) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) + << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + return loader->LoadPresharded(IntegerFromShape(weight_index)); + }) + .def("runtime.disco.ShardLoaderLoadAll", + [](ObjectRef loader_obj) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) + << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + return loader->LoadAll(); + }) + .def("runtime.disco.ShardLoaderLoadAllPresharded", + [](ObjectRef loader_obj) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) + << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + return loader->LoadAllPresharded(); + }) + .def("runtime.disco.ShardLoaderLoadParamOnWorker0", + [](ObjectRef loader_obj, int param_index) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) + << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + return loader->LoadParamOnWorker0(param_index); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 2b860b6b63ec..bd41320175f4 100644 --- a/src/runtime/disco/nccl/nccl.cc 
+++ b/src/runtime/disco/nccl/nccl.cc @@ -17,6 +17,8 @@ * under the License. */ +#include + #include #include #include @@ -325,8 +327,10 @@ void SyncWorker() { StreamSynchronize(stream); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String { - return TVM_DISCO_CCL_NAME; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.disco.compiled_ccl", + []() -> String { return TVM_DISCO_CCL_NAME; }); }); TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL); TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 4563079c30b4..fc8dcdfbdb3a 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include #include @@ -196,8 +197,12 @@ void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_f worker.MainLoop(); } -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.SessionProcess", Session::ProcessSession) + .def("runtime.disco.WorkerProcess", WorkerProcess); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index ed2d8575387f..8f9a9fe0a4f0 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -31,27 +32,25 @@ struct SessionObj::FFI { TVM_REGISTER_OBJECT_TYPE(DRefObj); TVM_REGISTER_OBJECT_TYPE(SessionObj); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") - .set_body_method(&DRefObj::DebugGetFromRemote); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom").set_body_method(&DRefObj::DebugCopyFrom); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") - .set_body_method(&SessionObj::GetNumWorkers); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") - .set_body_method(&SessionObj::GetGlobalFunc); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") - .set_body_method(&SessionObj::CopyFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0") - .set_body_method(&SessionObj::CopyToWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker").set_body_method(&SessionObj::SyncWorker); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // - .set_body_method(&SessionObj::InitCCL); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCallPacked") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Session self = args[0].cast(); - *rv = SessionObj::FFI::CallWithPacked(self, args.Slice(1)); - }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionShutdown").set_body_method(&SessionObj::Shutdown); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.disco.SessionThreaded", Session::ThreadedSession) + .def_method("runtime.disco.DRefDebugGetFromRemote", &DRefObj::DebugGetFromRemote) + .def_method("runtime.disco.DRefDebugCopyFrom", &DRefObj::DebugCopyFrom) + 
.def_method("runtime.disco.SessionGetNumWorkers", &SessionObj::GetNumWorkers) + .def_method("runtime.disco.SessionGetGlobalFunc", &SessionObj::GetGlobalFunc) + .def_method("runtime.disco.SessionCopyFromWorker0", &SessionObj::CopyFromWorker0) + .def_method("runtime.disco.SessionCopyToWorker0", &SessionObj::CopyToWorker0) + .def_method("runtime.disco.SessionSyncWorker", &SessionObj::SyncWorker) + .def_method("runtime.disco.SessionInitCCL", &SessionObj::InitCCL) + .def_packed("runtime.disco.SessionCallPacked", + [](ffi::PackedArgs args, ffi::Any* rv) { + Session self = args[0].cast(); + *rv = SessionObj::FFI::CallWithPacked(self, args.Slice(1)); + }) + .def_method("runtime.disco.SessionShutdown", &SessionObj::Shutdown); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 8a8666691300..e5c7aa7de174 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include "library_module.h" @@ -148,10 +149,12 @@ ObjectPtr CreateDSOLibraryObject(std::string library_path) { return n; } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_so") - .set_body_typed([](std::string library_path, std::string) { - ObjectPtr n = CreateDSOLibraryObject(library_path); - return CreateModuleFromLibrary(n); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadfile_so", [](std::string library_path, std::string) { + ObjectPtr n = CreateDSOLibraryObject(library_path); + return CreateModuleFromLibrary(n); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 513efbd9fbed..56a12968bbde 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -250,25 +251,24 @@ std::string SaveParams(const Map& params) { return bytes; } -TVM_FFI_REGISTER_GLOBAL("runtime.SaveParams") - .set_body_typed([](const Map& params) { - std::string s = ::tvm::runtime::SaveParams(params); - return ffi::Bytes(std::move(s)); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.SaveParamsToFile") - .set_body_typed([](const Map& params, const String& path) { - tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); - SaveParams(&strm, params); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const ffi::Bytes& s) { - return ::tvm::runtime::LoadParams(s); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { - tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); - return LoadParams(&strm); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.SaveParams", + [](const Map& params) { + std::string s = ::tvm::runtime::SaveParams(params); + return ffi::Bytes(std::move(s)); + }) + .def("runtime.SaveParamsToFile", + [](const Map& params, const String& path) { + tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); + SaveParams(&strm, params); + }) + .def("runtime.LoadParams", [](const ffi::Bytes& s) { return ::tvm::runtime::LoadParams(s); }) + .def("runtime.LoadParamsFromFile", [](const String& path) { + tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); + return LoadParams(&strm); + }); }); } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index 27bf4ffc0eab..54379564241c 100644 --- a/src/runtime/hexagon/hexagon_common.cc 
+++ b/src/runtime/hexagon/hexagon_common.cc @@ -23,6 +23,7 @@ #include "hexagon_common.h" #include +#include #include #include @@ -56,8 +57,10 @@ class HexagonTimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(HexagonTimerNode); -TVM_FFI_REGISTER_GLOBAL("profiling.timer.hexagon").set_body_typed([](Device dev) { - return Timer(make_object()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.timer.hexagon", + [](Device dev) { return Timer(make_object()); }); }); } // namespace hexagon @@ -89,11 +92,14 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } } // namespace detail -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); - *rv = CreateModuleFromLibrary(n); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "runtime.module.loadfile_hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { + ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); + *rv = CreateModuleFromLibrary(n); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 0bc7e2b80194..2b7788660834 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -190,130 +191,125 @@ void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - auto dst = args[0].cast(); - auto src = args[1].cast(); - int size = args[2].cast(); - ICHECK(size > 0); - bool bypass_cache = args[3].cast(); - - int ret = DMA_RETRY; - do { - ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(SYNC_DMA_QUEUE, dst->data, src->data, - size, bypass_cache); - } while (ret == DMA_RETRY); - CHECK(ret == DMA_SUCCESS); - HexagonDeviceAPI::Global()->UserDMA()->Wait(SYNC_DMA_QUEUE, 0); - - *rv = static_cast(0); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_copy") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - uint32_t queue_id = args[0].cast(); - void* dst = args[1].cast(); - void* src = args[2].cast(); - uint32_t size = args[3].cast(); - ICHECK(size > 0); - bool bypass_cache = args[4].cast(); - - int ret = DMA_RETRY; - do { - ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(queue_id, dst, src, size, bypass_cache); - } while (ret == DMA_RETRY); - CHECK(ret == DMA_SUCCESS); - *rv = static_cast(ret); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_wait") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - uint32_t queue_id = args[0].cast(); - int inflight = args[1].cast(); - ICHECK(inflight >= 0); - HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight); - *rv = static_cast(0); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - uint32_t queue_id = args[0].cast(); - HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); - *rv = static_cast(0); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - uint32_t queue_id = args[0].cast(); - 
HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); - *rv = static_cast(0); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int32_t device_type = args[0].cast(); - int32_t device_id = args[1].cast(); - int32_t dtype_code_hint = args[2].cast(); - int32_t dtype_bits_hint = args[3].cast(); - auto scope = args[4].cast(); - CHECK(scope.find("global.vtcm") != std::string::npos); - int64_t ndim = args[5].cast(); - CHECK((ndim == 1 || ndim == 2) && "Hexagon Device API supports only 1d and 2d allocations"); - int64_t* shape = static_cast(args[6].cast()); - - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - - DLDataType type_hint; - type_hint.code = static_cast(dtype_code_hint); - type_hint.bits = static_cast(dtype_bits_hint); - type_hint.lanes = 1; - - HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); - *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, String(scope)); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.free_nd") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int32_t device_type = args[0].cast(); - int32_t device_id = args[1].cast(); - auto scope = args[2].cast(); - CHECK(scope.find("global.vtcm") != std::string::npos); - void* ptr = args[3].cast(); - - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - - HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); - hexapi->FreeDataSpace(dev, ptr); - *rv = static_cast(0); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.acquire_resources") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); - api->AcquireResources(); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.release_resources") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); - api->ReleaseResources(); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.vtcm_device_bytes") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); - *rv = static_cast(api->VtcmPool()->VtcmDeviceBytes()); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.hexagon") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global(); - *rv = static_cast(ptr); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("device_api.hexagon.dma_copy_dltensor", + [](ffi::PackedArgs args, ffi::Any* rv) { + auto dst = args[0].cast(); + auto src = args[1].cast(); + int size = args[2].cast(); + ICHECK(size > 0); + bool bypass_cache = args[3].cast(); + + int ret = DMA_RETRY; + do { + ret = HexagonDeviceAPI::Global()->UserDMA()->Copy( + SYNC_DMA_QUEUE, dst->data, src->data, size, bypass_cache); + } while (ret == DMA_RETRY); + CHECK(ret == DMA_SUCCESS); + HexagonDeviceAPI::Global()->UserDMA()->Wait(SYNC_DMA_QUEUE, 0); + + *rv = static_cast(0); + }) + .def_packed("device_api.hexagon.dma_copy", + [](ffi::PackedArgs args, ffi::Any* rv) { + uint32_t queue_id = args[0].cast(); + void* dst = args[1].cast(); + void* src = args[2].cast(); + uint32_t size = args[3].cast(); + ICHECK(size > 0); + bool bypass_cache = args[4].cast(); + + int ret = DMA_RETRY; + do { + ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(queue_id, dst, src, size, + bypass_cache); + } while (ret == DMA_RETRY); + CHECK(ret == DMA_SUCCESS); + *rv = static_cast(ret); + }) + 
.def_packed("device_api.hexagon.dma_wait", + [](ffi::PackedArgs args, ffi::Any* rv) { + uint32_t queue_id = args[0].cast(); + int inflight = args[1].cast(); + ICHECK(inflight >= 0); + HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight); + *rv = static_cast(0); + }) + .def_packed("device_api.hexagon.dma_start_group", + [](ffi::PackedArgs args, ffi::Any* rv) { + uint32_t queue_id = args[0].cast(); + HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); + *rv = static_cast(0); + }) + .def_packed("device_api.hexagon.dma_end_group", + [](ffi::PackedArgs args, ffi::Any* rv) { + uint32_t queue_id = args[0].cast(); + HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); + *rv = static_cast(0); + }) + .def_packed("device_api.hexagon.alloc_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + int32_t dtype_code_hint = args[2].cast(); + int32_t dtype_bits_hint = args[3].cast(); + auto scope = args[4].cast(); + CHECK(scope.find("global.vtcm") != std::string::npos); + int64_t ndim = args[5].cast(); + CHECK((ndim == 1 || ndim == 2) && + "Hexagon Device API supports only 1d and 2d allocations"); + int64_t* shape = static_cast(args[6].cast()); + + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); + *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, String(scope)); + }) + .def_packed("device_api.hexagon.free_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + auto scope = args[2].cast(); + CHECK(scope.find("global.vtcm") != std::string::npos); + void* ptr = args[3].cast(); + + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); + hexapi->FreeDataSpace(dev, ptr); + *rv = static_cast(0); + }) + .def_packed("device_api.hexagon.acquire_resources", + [](ffi::PackedArgs args, ffi::Any* rv) { + HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); + api->AcquireResources(); + }) + .def_packed("device_api.hexagon.release_resources", + [](ffi::PackedArgs args, ffi::Any* rv) { + HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); + api->ReleaseResources(); + }) + .def_packed("device_api.hexagon.vtcm_device_bytes", + [](ffi::PackedArgs args, ffi::Any* rv) { + HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); + *rv = static_cast(api->VtcmPool()->VtcmDeviceBytes()); + }) + .def_packed("device_api.hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global(); + *rv = static_cast(ptr); + }); +}); } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 0f71f7265024..1318c2715763 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -22,6 +22,7 @@ */ #include +#include extern "C" { #include @@ -109,22 +110,25 @@ class HexagonTransportChannel : public RPCChannel { remote_handle64 _handle = AEE_EUNKNOWN; }; -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(args.size() >= 4) << args.size() << " is less than 4"; - - auto session_name = 
args[0].cast(); - int remote_stack_size_bytes = args[1].cast(); - // For simulator, the third parameter is sim_args, ignore it. - int hexagon_rpc_receive_buf_size_bytes = args[3].cast(); - HexagonTransportChannel* hexagon_channel = - new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes, - static_cast(hexagon_rpc_receive_buf_size_bytes)); - std::unique_ptr channel(hexagon_channel); - auto ep = RPCEndpoint::Create(std::move(channel), session_name, "", nullptr); - auto sess = CreateClientSession(ep); - *rv = CreateRPCSessionModule(sess); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + + auto session_name = args[0].cast(); + int remote_stack_size_bytes = args[1].cast(); + // For simulator, the third parameter is sim_args, ignore it. + int hexagon_rpc_receive_buf_size_bytes = args[3].cast(); + HexagonTransportChannel* hexagon_channel = + new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes, + static_cast(hexagon_rpc_receive_buf_size_bytes)); + std::unique_ptr channel(hexagon_channel); + auto ep = RPCEndpoint::Create(std::move(channel), session_name, "", nullptr); + auto sess = CreateClientSession(ep); + *rv = CreateRPCSessionModule(sess); + }); +}); } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index 7880018ff8e8..8a5ea01be579 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -28,6 +28,7 @@ extern "C" { #include #include #include +#include #include #include @@ -328,24 +329,28 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.load_module") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - auto soname = args[0].cast(); - tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - auto profiling_mode = args[0].cast(); - auto out_file = args[1].cast(); - if (profiling_mode.compare("lwp") == 0) { - *rv = WriteLWPOutput(out_file); - } else { - HEXAGON_PRINT(ERROR, "ERROR: Unsupported profiling mode: %s", profiling_mode.c_str()); - *rv = false; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.hexagon.load_module", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + auto soname = args[0].cast(); + tvm::ObjectPtr n = + tvm::runtime::CreateDSOLibraryObject(soname); + *rv = CreateModuleFromLibrary(n); + }) + .def_packed( + "tvm.hexagon.get_profile_output", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + auto profiling_mode = args[0].cast(); + auto out_file = args[1].cast(); + if (profiling_mode.compare("lwp") == 0) { + *rv = WriteLWPOutput(out_file); + } else { + HEXAGON_PRINT(ERROR, "ERROR: Unsupported profiling mode: %s", profiling_mode.c_str()); + *rv = false; + } + }); +}); void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); @@ -353,9 +358,12 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& 
data) { fs.write(&data[0], data.length()); } -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - auto file_name = args[0].cast(); - auto data = args[1].cast(); - SaveBinaryToFile(file_name, data); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.rpc.server.upload", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + auto file_name = args[0].cast(); + auto data = args[1].cast(); + SaveBinaryToFile(file_name, data); + }); +}); diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index 2301ffc13d17..b9a0504e98b1 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -332,24 +333,28 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.load_module") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - auto soname = args[0].cast(); - tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - auto profiling_mode = args[0].cast(); - auto out_file = args[1].cast(); - if (profiling_mode.compare("lwp") == 0) { - *rv = WriteLWPOutput(out_file); - } else { - HEXAGON_PRINT(ERROR, "ERROR: Unsupported profiling mode: %s", profiling_mode.c_str()); - *rv = false; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.hexagon.load_module", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + auto soname = args[0].cast(); + tvm::ObjectPtr n = + tvm::runtime::CreateDSOLibraryObject(soname); + *rv = CreateModuleFromLibrary(n); + }) + .def_packed( + "tvm.hexagon.get_profile_output", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + auto profiling_mode = args[0].cast(); + auto out_file = args[1].cast(); + if (profiling_mode.compare("lwp") == 0) { + *rv = WriteLWPOutput(out_file); + } else { + HEXAGON_PRINT(ERROR, "ERROR: Unsupported profiling mode: %s", profiling_mode.c_str()); + *rv = false; + } + }); +}); void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); @@ -357,9 +362,12 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - auto file_name = args[0].cast(); - auto data = args[1].cast(); - SaveBinaryToFile(file_name, data); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.rpc.server.upload", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + auto file_name = args[0].cast(); + auto data = args[1].cast(); + SaveBinaryToFile(file_name, data); + }); +}); diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index 5eb7beab0f57..96a554e22110 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -19,6 +19,7 @@ #include #include +#include // POSIX includes #include #include @@ -1369,19 +1370,22 @@ std::optional 
SimulatorRPCChannel::to_nullptr(const detail::Mayb .Default(std::nullopt); } -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(args.size() >= 4) << args.size() << " is less than 4"; - - auto session_name = args[0].cast(); - int stack_size = args[1].cast(); - auto sim_args = args[2].cast(); - auto channel = std::make_unique(stack_size, sim_args); - std::shared_ptr endpoint = - RPCEndpoint::Create(std::move(channel), session_name, "", nullptr); - std::shared_ptr session = CreateClientSession(endpoint); - *rv = CreateRPCSessionModule(session); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + + auto session_name = args[0].cast(); + int stack_size = args[1].cast(); + auto sim_args = args[2].cast(); + auto channel = std::make_unique(stack_size, sim_args); + std::shared_ptr endpoint = + RPCEndpoint::Create(std::move(channel), session_name, "", nullptr); + std::shared_ptr session = CreateClientSession(endpoint); + *rv = CreateRPCSessionModule(session); + }); +}); } // namespace hexagon } // namespace runtime diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index b6c2a098d474..c817d66f69d4 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -22,6 +22,7 @@ * \brief Allocate and manage memory for the runtime. */ #include +#include #include #include @@ -264,7 +265,10 @@ void Allocator::Clear() { // Pooled allocator will override this method. } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.memory_manager.clear", MemoryManager::Clear); +}); } // namespace memory } // namespace runtime diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 8722dcfeb60a..46c5fb9f2972 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -22,6 +22,7 @@ */ #include #include +#include #include #include "metal_common.h" @@ -362,13 +363,16 @@ int GetWarpSize(id dev) { MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_FFI_REGISTER_GLOBAL("device_api.metal").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = MetalWorkspace::Global(); - *rv = static_cast(ptr); -}); - -TVM_FFI_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { - MetalWorkspace::Global()->ReinitializeDefaultStreams(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("device_api.metal", + [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = MetalWorkspace::Global(); + *rv = static_cast(ptr); + }) + .def("metal.ResetGlobalState", + []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); }); class MetalTimerNode : public TimerNode { @@ -403,8 +407,10 @@ virtual void Stop() { TVM_REGISTER_OBJECT_TYPE(MetalTimerNode); -TVM_FFI_REGISTER_GLOBAL("profiling.timer.metal").set_body_typed([](Device dev) { - return Timer(make_object(dev)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.timer.metal", + [](Device dev) { return Timer(make_object(dev)); 
}); }); } // namespace metal diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index c1a5ccfdd1a3..f8b2aca1d1bd 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -23,6 +23,7 @@ #include "metal_module.h" #include #include +#include #include #include #include @@ -287,18 +288,21 @@ Module MetalModuleCreate(std::unordered_map smap, return Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.create_metal_module") - .set_body_typed([](Map smap, std::string fmap_json, std::string fmt, - std::string source) { - std::istringstream stream(fmap_json); - std::unordered_map fmap; - dmlc::JSONReader reader(&stream); - reader.Read(&fmap); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "runtime.module.create_metal_module", + [](Map smap, std::string fmap_json, std::string fmt, std::string source) { + std::istringstream stream(fmap_json); + std::unordered_map fmap; + dmlc::JSONReader reader(&stream); + reader.Read(&fmap); - return MetalModuleCreate( - std::unordered_map(smap.begin(), smap.end()), fmap, fmt, - source); - }); + return MetalModuleCreate( + std::unordered_map(smap.begin(), smap.end()), fmap, fmt, + source); + }); +}); Module MetalModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); @@ -317,6 +321,9 @@ Module MetalModuleLoadBinary(void* strm) { return MetalModuleCreate(smap, fmap, fmt, ""); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadbinary_metal", MetalModuleLoadBinary); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/module.cc b/src/runtime/module.cc index d2bc4b2c297a..d71b2c217f82 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -22,6 +22,7 @@ * \brief TVM module system */ #include +#include #include #include @@ -165,53 +166,32 @@ bool RuntimeEnabled(const String& target_str) { return tvm::ffi::Function::GetGlobal(f_name).has_value(); } -TVM_FFI_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { - return mod->GetSource(fmt); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { - return static_cast(mod->imports().size()); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { - return mod->imports().at(index); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module mod) { - mod->ClearImports(); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { - return std::string(mod->type_key()); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) { - return mod->GetFormat(); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleSaveToFile") - .set_body_typed([](Module mod, String name, String fmt) { mod->SaveToFile(name, fmt); }); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { - return mod->GetPropertyMask(); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") - .set_body_typed([](Module mod, String name, bool query_imports) { - return mod->ImplementsFunction(std::move(name), query_imports); - 
}); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetFunction") - .set_body_typed([](Module mod, String name, bool query_imports) { - return mod->GetFunction(name, query_imports); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImport").set_body_typed([](Module mod, Module other) { - mod->Import(other); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.RuntimeEnabled", RuntimeEnabled) + .def("runtime.ModuleGetSource", + [](Module mod, std::string fmt) { return mod->GetSource(fmt); }) + .def("runtime.ModuleImportsSize", + [](Module mod) { return static_cast(mod->imports().size()); }) + .def("runtime.ModuleGetImport", + [](Module mod, int index) { return mod->imports().at(index); }) + .def("runtime.ModuleClearImports", [](Module mod) { mod->ClearImports(); }) + .def("runtime.ModuleGetTypeKey", [](Module mod) { return std::string(mod->type_key()); }) + .def("runtime.ModuleGetFormat", [](Module mod) { return mod->GetFormat(); }) + .def("runtime.ModuleLoadFromFile", Module::LoadFromFile) + .def("runtime.ModuleSaveToFile", + [](Module mod, String name, String fmt) { mod->SaveToFile(name, fmt); }) + .def("runtime.ModuleGetPropertyMask", [](Module mod) { return mod->GetPropertyMask(); }) + .def("runtime.ModuleImplementsFunction", + [](Module mod, String name, bool query_imports) { + return mod->ImplementsFunction(std::move(name), query_imports); + }) + .def("runtime.ModuleGetFunction", + [](Module mod, String name, bool query_imports) { + return mod->GetFunction(name, query_imports); + }) + .def("runtime.ModuleImport", [](Module mod, Module other) { mod->Import(other); }); }); } // namespace runtime diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index f03a83a929ec..8261359e8b8f 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -22,6 +22,7 @@ * \brief NDArray container infratructure. 
*/ #include +#include #include #include #include @@ -215,19 +216,16 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str using namespace tvm::runtime; -TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body_typed(NDArray::Empty); - -TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_method(&NDArray::CreateView); - -TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyFromBytes") - .set_body_typed([](DLTensor* arr, void* data, size_t nbytes) { - ArrayCopyFromBytes(arr, data, nbytes); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyToBytes") - .set_body_typed([](DLTensor* arr, void* data, size_t nbytes) { - NDArray::CopyToBytes(arr, data, nbytes); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyFromTo") - .set_body_typed([](DLTensor* from, DLTensor* to) { NDArray::CopyFromTo(from, to); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.TVMArrayAllocWithScope", NDArray::Empty) + .def_method("runtime.TVMArrayCreateView", &NDArray::CreateView) + .def("runtime.TVMArrayCopyFromBytes", + [](DLTensor* arr, void* data, size_t nbytes) { ArrayCopyFromBytes(arr, data, nbytes); }) + .def( + "runtime.TVMArrayCopyToBytes", + [](DLTensor* arr, void* data, size_t nbytes) { NDArray::CopyToBytes(arr, data, nbytes); }) + .def("runtime.TVMArrayCopyFromTo", + [](DLTensor* from, DLTensor* to) { NDArray::CopyFromTo(from, to); }); +}); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 000e9a94599e..1999625a3aec 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -760,59 +761,62 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; } -TVM_FFI_REGISTER_GLOBAL("device_api.opencl.alloc_nd") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int32_t device_type = args[0].cast(); - int32_t device_id = args[1].cast(); - int32_t dtype_code_hint = args[2].cast(); - int32_t dtype_bits_hint = args[3].cast(); - auto scope = args[4].cast(); - CHECK(scope.find("texture") != std::string::npos); - int64_t ndim = args[5].cast(); - CHECK_EQ(ndim, 2); - int64_t* shape = static_cast(args[6].cast()); - int64_t width = shape[0]; - int64_t height = shape[1]; - - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - - DLDataType type_hint; - type_hint.code = static_cast(dtype_code_hint); - type_hint.bits = static_cast(dtype_bits_hint); - type_hint.lanes = 1; - - *rv = OpenCLWorkspace::Global()->AllocDataSpace(dev, static_cast(width), - static_cast(height), type_hint, - String("global.texture")); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.opencl.free_nd") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int32_t device_type = args[0].cast(); - int32_t device_id = args[1].cast(); - auto scope = args[2].cast(); - CHECK(scope.find("texture") != std::string::npos); - void* data = args[3].cast(); - OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - ptr->FreeDataSpace(dev, data); - *rv = static_cast(0); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.opencl") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global(); - *rv = static_cast(ptr); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + 
refl::GlobalDef() + .def_packed("device_api.opencl.alloc_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + int32_t dtype_code_hint = args[2].cast(); + int32_t dtype_bits_hint = args[3].cast(); + auto scope = args[4].cast(); + CHECK(scope.find("texture") != std::string::npos); + int64_t ndim = args[5].cast(); + CHECK_EQ(ndim, 2); + int64_t* shape = static_cast(args[6].cast()); + int64_t width = shape[0]; + int64_t height = shape[1]; + + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + *rv = OpenCLWorkspace::Global()->AllocDataSpace( + dev, static_cast(width), static_cast(height), type_hint, + String("global.texture")); + }) + .def_packed("device_api.opencl.free_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + auto scope = args[2].cast(); + CHECK(scope.find("texture") != std::string::npos); + void* data = args[3].cast(); + OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + ptr->FreeDataSpace(dev, data); + *rv = static_cast(0); + }) + .def_packed("device_api.opencl", [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = OpenCLWorkspace::Global(); + *rv = static_cast(ptr); + }); +}); TVM_REGISTER_OBJECT_TYPE(OpenCLTimerNode); -TVM_FFI_REGISTER_GLOBAL("profiling.timer.opencl").set_body_typed([](Device dev) { - return Timer(make_object(dev)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.timer.opencl", + [](Device dev) { return Timer(make_object(dev)); }); }); class OpenCLPooledAllocator final : public memory::PooledAllocator { @@ -894,11 +898,13 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { } }; -TVM_FFI_REGISTER_GLOBAL("DeviceAllocator.opencl") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Allocator* alloc = new OpenCLPooledAllocator(); - *rv = static_cast(alloc); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("DeviceAllocator.opencl", [](ffi::PackedArgs args, ffi::Any* rv) { + Allocator* alloc = new OpenCLPooledAllocator(); + *rv = static_cast(alloc); + }); +}); } // namespace cl size_t OpenCLTimerNode::count_timer_execs = 0; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 8e8ee5a43b78..a19a4e26cf7d 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -389,10 +390,12 @@ Module OpenCLModuleLoadBinary(void* strm) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.module.loadfile_cl", OpenCLModuleLoadFile) + .def("runtime.module.loadfile_clbin", OpenCLModuleLoadFile) + .def("runtime.module.loadbinary_opencl", OpenCLModuleLoadBinary); +}); } // namespace 
runtime } // namespace tvm diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 2e1bfba0263a..a62b1adf6f3d 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -84,8 +85,10 @@ class CPUTimerNode : public TimerNode { }; TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); -TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { - return Timer(make_object()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.timer.cpu", + [](Device dev) { return Timer(make_object()); }); }); // keep track of which timers are not defined but we have already warned about @@ -115,7 +118,10 @@ Timer Timer::Start(Device dev) { } } -TVM_FFI_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.start_timer", Timer::Start); +}); namespace profiling { @@ -788,16 +794,14 @@ TVM_REGISTER_OBJECT_TYPE(ReportNode); TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { - return n->AsCSV(); -}); -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { - return n->AsJSON(); -}); -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON); -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { - return DeviceWrapper(dev); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("runtime.profiling.AsTable", &ReportNode::AsTable) + .def("runtime.profiling.AsCSV", [](Report n) { return n->AsCSV(); }) + .def("runtime.profiling.AsJSON", [](Report n) { return n->AsJSON(); }) + .def("runtime.profiling.FromJSON", Report::FromJSON) + .def("runtime.profiling.DeviceWrapper", [](Device dev) { return DeviceWrapper(dev); }); }); ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, @@ -846,21 +850,22 @@ ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type }); } -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") - .set_body_typed)>([](Module mod, String func_name, - int device_type, int device_id, - int warmup_iters, - Array collectors) { - if (mod->type_key() == std::string("rpc")) { - LOG(FATAL) - << "Profiling a module over RPC is not yet supported"; // because we can't send - // MetricCollectors over rpc. - throw; - } else { - return ProfileFunction(mod, func_name, device_type, device_id, warmup_iters, collectors); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "runtime.profiling.ProfileFunction", + [](Module mod, String func_name, int device_type, int device_id, int warmup_iters, + Array collectors) { + if (mod->type_key() == std::string("rpc")) { + LOG(FATAL) + << "Profiling a module over RPC is not yet supported"; // because we can't send + // MetricCollectors over rpc. 
+ throw; + } else { + return ProfileFunction(mod, func_name, device_type, device_id, warmup_iters, collectors); + } + }); +}); ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, @@ -927,27 +932,22 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re return ffi::Function::FromPacked(ftimer); } -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Report") - .set_body_typed([](Array> calls, - Map> device_metrics, - Map configuration) { - return Report(calls, device_metrics, configuration); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) { - return ObjectRef(make_object(count)); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Percent").set_body_typed([](double percent) { - return ObjectRef(make_object(percent)); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Duration").set_body_typed([](double duration) { - return ObjectRef(make_object(duration)); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Ratio").set_body_typed([](double ratio) { - return ObjectRef(make_object(ratio)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.profiling.Report", + [](Array> calls, Map> device_metrics, + Map configuration) { + return Report(calls, device_metrics, configuration); + }) + .def("runtime.profiling.Count", + [](int64_t count) { return ObjectRef(make_object(count)); }) + .def("runtime.profiling.Percent", + [](double percent) { return ObjectRef(make_object(percent)); }) + .def("runtime.profiling.Duration", + [](double duration) { return ObjectRef(make_object(duration)); }) + .def("runtime.profiling.Ratio", + [](double ratio) { return ObjectRef(make_object(ratio)); }); }); } // namespace profiling diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index d0da510389f8..f347cb562c11 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -251,17 +252,20 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_FFI_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global(); - *rv = static_cast(ptr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("device_api.rocm", + [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global(); + *rv = static_cast(ptr); + }) + .def_packed("device_api.rocm_host", [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global(); + *rv = static_cast(ptr); + }); }); -TVM_FFI_REGISTER_GLOBAL("device_api.rocm_host") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global(); - *rv = static_cast(ptr); - }); - class ROCMTimerNode : public TimerNode { public: virtual void Start() { @@ -293,12 +297,12 @@ class ROCMTimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(ROCMTimerNode); -TVM_FFI_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) { - return Timer(make_object()); -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() { - return static_cast(ROCMThreadEntry::ThreadLocal()->stream); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef() + .def("profiling.timer.rocm", [](Device dev) { return Timer(make_object()); }) + .def("runtime.get_rocm_stream", + []() { return static_cast(ROCMThreadEntry::ThreadLocal()->stream); }); }); } // namespace runtime diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 2d3ba16de247..462f3b543ec6 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -231,12 +232,13 @@ Module ROCMModuleLoadBinary(void* strm) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.module.loadbinary_hsaco", ROCMModuleLoadBinary) + .def("runtime.module.loadbinary_hip", ROCMModuleLoadBinary) + .def("runtime.module.loadfile_hsaco", ROCMModuleLoadFile) + .def("runtime.module.loadfile_hip", ROCMModuleLoadFile); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index ffe031fadfb4..fba34eee5c1b 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -21,6 +21,7 @@ * \file rpc_device_api.cc */ #include +#include #include #include @@ -150,10 +151,13 @@ class RPCDeviceAPI final : public DeviceAPI { } }; -TVM_FFI_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - static RPCDeviceAPI inst; - DeviceAPI* ptr = &inst; - *rv = static_cast(ptr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("device_api.rpc", [](ffi::PackedArgs args, ffi::Any* rv) { + static RPCDeviceAPI inst; + DeviceAPI* ptr = &inst; + *rv = static_cast(ptr); + }); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index c178db59a230..e893b975a08c 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -22,6 +22,7 @@ * \brief Event driven RPC server implementation. 
*/ #include +#include #include @@ -44,6 +45,9 @@ ffi::Function CreateEventDrivenServer(ffi::Function fsend, std::string name, }); } -TVM_FFI_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("rpc.CreateEventDrivenServer", CreateEventDrivenServer); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index a64bbb713250..b858c1832c61 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -24,6 +24,7 @@ #include "rpc_local_session.h" #include +#include #include #include @@ -147,8 +148,10 @@ DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) { return DeviceAPI::Get(dev, allow_missing); } -TVM_FFI_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { - return CreateRPCSessionModule(std::make_shared()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("rpc.LocalSession", + []() { return CreateRPCSessionModule(std::make_shared()); }); }); } // namespace runtime diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 67faa3329be5..9ef5abaf5690 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -22,6 +22,7 @@ * \brief RPC runtime module. */ #include +#include #include #include #include @@ -389,91 +390,100 @@ inline void CPUCacheFlush(int begin_index, const ffi::PackedArgs& args) { } } -TVM_FFI_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") - .set_body_typed([](Optional opt_mod, std::string name, int device_type, int device_id, - int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, - int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, - std::string f_preproc_name) { - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); - if (tkey == "rpc") { - return static_cast(m.operator->()) - ->GetTimeEvaluator(name, dev, number, repeat, min_repeat_ms, - limit_zero_time_iterations, cooldown_interval_ms, - repeats_to_cooldown, cache_flush_bytes, f_preproc_name); - } else { - ffi::Function f_preproc; - if (!f_preproc_name.empty()) { - auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); - ICHECK(pf_preproc.has_value()) - << "Cannot find " << f_preproc_name << " in the global function"; - f_preproc = *pf_preproc; - } - ffi::Function pf = m.GetFunction(name, true); - CHECK(pf != nullptr) << "Cannot find " << name << "` in the global registry"; - return profiling::WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, - limit_zero_time_iterations, cooldown_interval_ms, - repeats_to_cooldown, cache_flush_bytes, f_preproc); - } - } else { - auto pf = tvm::ffi::Function::GetGlobal(name); - ICHECK(pf.has_value()) << "Cannot find " << name << " in the global function"; - ffi::Function f_preproc; - if (!f_preproc_name.empty()) { - auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); - ICHECK(pf_preproc.has_value()) - << "Cannot find " << f_preproc_name << " in the global function"; - f_preproc = *pf_preproc; - } - return profiling::WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, - limit_zero_time_iterations, cooldown_interval_ms, - repeats_to_cooldown, cache_flush_bytes, f_preproc); - } - }); - 
-TVM_FFI_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CPUCacheFlush(1, args); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.RPCTimeEvaluator", + [](Optional opt_mod, std::string name, int device_type, int device_id, + int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, + int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, + std::string f_preproc_name) { + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + if (tkey == "rpc") { + return static_cast(m.operator->()) + ->GetTimeEvaluator(name, dev, number, repeat, min_repeat_ms, + limit_zero_time_iterations, cooldown_interval_ms, + repeats_to_cooldown, cache_flush_bytes, f_preproc_name); + } else { + ffi::Function f_preproc; + if (!f_preproc_name.empty()) { + auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); + ICHECK(pf_preproc.has_value()) + << "Cannot find " << f_preproc_name << " in the global function"; + f_preproc = *pf_preproc; + } + ffi::Function pf = m.GetFunction(name, true); + CHECK(pf != nullptr) << "Cannot find " << name << "` in the global registry"; + return profiling::WrapTimeEvaluator( + pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, + cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc); + } + } else { + auto pf = tvm::ffi::Function::GetGlobal(name); + ICHECK(pf.has_value()) << "Cannot find " << name << " in the global function"; + ffi::Function f_preproc; + if (!f_preproc_name.empty()) { + auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); + ICHECK(pf_preproc.has_value()) + << "Cannot find " << f_preproc_name << " in the global function"; + f_preproc = *pf_preproc; + } + return profiling::WrapTimeEvaluator( + *pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, + cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc); + } + }) + .def_packed("cache_flush_cpu_non_first_arg", + [](ffi::PackedArgs args, ffi::Any* rv) { CPUCacheFlush(1, args); }); +}); // server function registration. -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") - .set_body_typed([](Module parent, Module child) { parent->Import(child); }); - -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") - .set_body_typed([](Module parent, std::string name, bool query_imports) { - return parent->GetFunction(name, query_imports); - }); - -// functions to access an RPC module. -TVM_FFI_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { - std::string tkey = sess->type_key(); - ICHECK_EQ(tkey, "rpc"); - return static_cast(sess.operator->())->LoadModule(name); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tvm.rpc.server.ImportModule", + [](Module parent, Module child) { parent->Import(child); }) + .def("tvm.rpc.server.ModuleGetFunction", + [](Module parent, std::string name, bool query_imports) { + return parent->GetFunction(name, query_imports); + }); }); -TVM_FFI_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { - std::string tkey = parent->type_key(); - ICHECK_EQ(tkey, "rpc"); - static_cast(parent.operator->())->ImportModule(child); +// functions to access an RPC module. 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("rpc.LoadRemoteModule", + [](Module sess, std::string name) { + std::string tkey = sess->type_key(); + ICHECK_EQ(tkey, "rpc"); + return static_cast(sess.operator->())->LoadModule(name); + }) + .def("rpc.ImportRemoteModule", + [](Module parent, Module child) { + std::string tkey = parent->type_key(); + ICHECK_EQ(tkey, "rpc"); + static_cast(parent.operator->())->ImportModule(child); + }) + .def_packed("rpc.SessTableIndex", + [](ffi::PackedArgs args, ffi::Any* rv) { + Module m = args[0].cast(); + std::string tkey = m->type_key(); + ICHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); + }) + .def("tvm.rpc.NDArrayFromRemoteOpaqueHandle", + [](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, + void* ndarray_handle) -> NDArray { + return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, + template_tensor, dev, ndarray_handle); + }); }); -TVM_FFI_REGISTER_GLOBAL("rpc.SessTableIndex") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Module m = args[0].cast(); - std::string tkey = m->type_key(); - ICHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle") - .set_body_typed([](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, - void* ndarray_handle) -> NDArray { - return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor, - dev, ndarray_handle); - }); - } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index b9121968137b..970f07744430 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -112,14 +113,16 @@ Module CreatePipeClient(std::vector cmd) { return CreateRPCSessionModule(CreateClientSession(endpt)); } -TVM_FFI_REGISTER_GLOBAL("rpc.CreatePipeClient") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::vector cmd; - for (int i = 0; i < args.size(); ++i) { - cmd.push_back(args[i].cast()); - } - *rv = CreatePipeClient(cmd); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("rpc.CreatePipeClient", [](ffi::PackedArgs args, ffi::Any* rv) { + std::vector cmd; + for (int i = 0; i < args.size(); ++i) { + cmd.push_back(args[i].cast()); + } + *rv = CreatePipeClient(cmd); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index eeb76c2b1512..64f49530db01 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -22,6 +22,7 @@ * \brief Server environment of the RPC. */ #include +#include #include "../file_utils.h" @@ -35,27 +36,28 @@ std::string RPCGetPath(const std::string& name) { return (*f)(name).cast(); } -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::string file_name = RPCGetPath(args[0].cast()); - auto data = args[1].cast(); - SaveBinaryToFile(file_name, data); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.download") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::string file_name = RPCGetPath(args[0].cast()); - std::string data; - LoadBinaryFromFile(file_name, &data); - LOG(INFO) << "Download " << file_name << "... 
nbytes=" << data.size(); - *rv = ffi::Bytes(data); - }); - -TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.remove") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::string file_name = RPCGetPath(args[0].cast()); - RemoveFile(file_name); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvm.rpc.server.upload", + [](ffi::PackedArgs args, ffi::Any* rv) { + std::string file_name = RPCGetPath(args[0].cast()); + auto data = args[1].cast(); + SaveBinaryToFile(file_name, data); + }) + .def_packed("tvm.rpc.server.download", + [](ffi::PackedArgs args, ffi::Any* rv) { + std::string file_name = RPCGetPath(args[0].cast()); + std::string data; + LoadBinaryFromFile(file_name, &data); + LOG(INFO) << "Download " << file_name << "... nbytes=" << data.size(); + *rv = ffi::Bytes(data); + }) + .def_packed("tvm.rpc.server.remove", [](ffi::PackedArgs args, ffi::Any* rv) { + std::string file_name = RPCGetPath(args[0].cast()); + RemoveFile(file_name); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 2564242bdf0f..7bfd0ef54abb 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -22,6 +22,7 @@ * \brief Socket based RPC implementation. */ #include +#include #include @@ -121,20 +122,24 @@ void RPCServerLoop(ffi::Function fsend, ffi::Function frecv) { ->ServerLoop(); } -TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - auto url = args[0].cast(); - int port = args[1].cast(); - auto key = args[2].cast(); - bool enable_logging = args[3].cast(); - *rv = RPCClientConnect(url, port, key, enable_logging, args.Slice(4)); -}); - -TVM_FFI_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (auto opt_int = args[0].as()) { - RPCServerLoop(opt_int.value()); - } else { - RPCServerLoop(args[0].cast(), args[1].cast()); - } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("rpc.Connect", + [](ffi::PackedArgs args, ffi::Any* rv) { + auto url = args[0].cast(); + int port = args[1].cast(); + auto key = args[2].cast(); + bool enable_logging = args[3].cast(); + *rv = RPCClientConnect(url, port, key, enable_logging, args.Slice(4)); + }) + .def_packed("rpc.ServerLoop", [](ffi::PackedArgs args, ffi::Any* rv) { + if (auto opt_int = args[0].as()) { + RPCServerLoop(opt_int.value()); + } else { + RPCServerLoop(args[0].cast(), args[1].cast()); + } + }); }); class SimpleSockHandler : public dmlc::Stream { @@ -162,10 +167,13 @@ class SimpleSockHandler : public dmlc::Stream { support::TCPSocket sock_; }; -TVM_FFI_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) { - auto handler = SimpleSockHandler(sockfd); - RPCReference::ReturnException(msg.c_str(), &handler); - return; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("rpc.ReturnException", [](int sockfd, String msg) { + auto handler = SimpleSockHandler(sockfd); + RPCReference::ReturnException(msg.c_str(), &handler); + return; + }); }); } // namespace runtime diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 08beb8cbc530..dc63251fdff4 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -126,9 +127,12 @@ Module LoadStaticLibrary(const std::string& filename, 
Array func_names) return Module(node); } -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary); -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_static_library") - .set_body_typed(StaticLibraryNode::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.ModuleLoadStaticLibrary", LoadStaticLibrary) + .def("runtime.module.loadbinary_static_library", StaticLibraryNode::LoadFromBinary); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index 46c08e4afd9a..ab714b636390 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -112,14 +113,16 @@ class SystemLibModuleRegistry { std::unordered_map lib_map_; }; -TVM_FFI_REGISTER_GLOBAL("runtime.SystemLib") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::string symbol_prefix = ""; - if (args.size() != 0) { - symbol_prefix = args[0].cast(); - } - *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("runtime.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { + std::string symbol_prefix = ""; + if (args.size() != 0) { + symbol_prefix = args[0].cast(); + } + *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); + }); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 8d769fbe63ec..a0f64fdf1c0f 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -378,24 +379,26 @@ class ThreadPool { * \brief args[0] is the AffinityMode, args[1] is the number of threads. * args2 is a list of CPUs which is used to set the CPU affinity. 
*/ -TVM_FFI_REGISTER_GLOBAL("runtime.config_threadpool") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - threading::ThreadGroup::AffinityMode mode = - static_cast(args[0].cast()); - int nthreads = args[1].cast(); - std::vector cpus; - if (args.size() >= 3) { - auto cpu_array = args[2].cast>(); - for (auto cpu : cpu_array) { - ICHECK(IsNumber(cpu)) << "The CPU core information '" << cpu << "' is not a number."; - cpus.push_back(std::stoi(cpu)); - } - } - threading::Configure(mode, nthreads, cpus); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.NumThreads").set_body_typed([]() -> int32_t { - return threading::NumThreads(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("runtime.config_threadpool", + [](ffi::PackedArgs args, ffi::Any* rv) { + threading::ThreadGroup::AffinityMode mode = + static_cast(args[0].cast()); + int nthreads = args[1].cast(); + std::vector cpus; + if (args.size() >= 3) { + auto cpu_array = args[2].cast>(); + for (auto cpu : cpu_array) { + ICHECK(IsNumber(cpu)) + << "The CPU core information '" << cpu << "' is not a number."; + cpus.push_back(std::stoi(cpu)); + } + } + threading::Configure(mode, nthreads, cpus); + }) + .def("runtime.NumThreads", []() -> int32_t { return threading::NumThreads(); }); }); namespace threading { diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index ef835f20d171..71341721e6f2 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -437,11 +438,14 @@ int MaxConcurrency() { // This global function can be used by disco runtime to bind processes // to CPUs. -TVM_FFI_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity") - .set_body_typed([](ffi::Shape cpu_ids) { - SetThreadAffinity(CURRENT_THREAD_HANDLE, - std::vector{cpu_ids.begin(), cpu_ids.end()}); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tvm.runtime.threading.set_current_thread_affinity", [](ffi::Shape cpu_ids) { + SetThreadAffinity(CURRENT_THREAD_HANDLE, + std::vector{cpu_ids.begin(), cpu_ids.end()}); + }); +}); } // namespace threading } // namespace runtime diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index a23e196e5c15..1eb52d73371a 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -63,7 +64,10 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.alloc_shape_heap", AllocShapeHeap); +}); /*! * \brief Builtin match R.Prim function. @@ -103,7 +107,10 @@ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.match_prim_value", MatchPrimValue); +}); /*! * \brief Builtin match shape function. 
@@ -154,7 +161,10 @@ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.match_shape").set_body_packed(MatchShape); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("vm.builtin.match_shape", MatchShape); +}); /*! * \brief Builtin make prim value function. @@ -178,7 +188,10 @@ int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.make_prim_value", MakePrimValue); +}); /*! * \brief Builtin make shape function. @@ -209,7 +222,10 @@ void MakeShape(ffi::PackedArgs args, ffi::Any* rv) { *rv = ffi::Shape(std::move(shape)); } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_shape").set_body_packed(MakeShape); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("vm.builtin.make_shape", MakeShape); +}); /*! * \brief Builtin function to check if arg is Tensor(dtype, ndim) @@ -249,7 +265,10 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body_packed(CheckTensorInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("vm.builtin.check_tensor_info", CheckTensorInfo); +}); /*! * \brief Builtin function to check if arg is Shape(ndim) @@ -269,7 +288,10 @@ void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.check_shape_info", CheckShapeInfo); +}); /*! * \brief Builtin function to check if arg is PrimValue(dtype) @@ -296,7 +318,10 @@ void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, Optional err_c } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.check_prim_value_info", CheckPrimValueInfo); +}); /*! * \brief Builtin function to check if arg is Tuple with size elements. @@ -314,7 +339,10 @@ void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { << " but get a Tuple with " << ptr->size() << " elements."; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.check_tuple_info", CheckTupleInfo); +}); /*! * \brief Builtin function to check if arg is a callable function. @@ -328,7 +356,10 @@ void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { << arg->GetTypeKey(); } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.check_func_info", CheckFuncInfo); +}); //------------------------------------------------- // Storage management. 
@@ -353,70 +384,71 @@ Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_inde return Storage(buffer, alloc); } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("vm.builtin.alloc_storage", VMAllocStorage) + .def_method("vm.builtin.alloc_tensor", &StorageObj::AllocNDArray); +}); //------------------------------------------------- // Closure function handling, calling convention //------------------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_closure") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - VMClosure clo = args[0].cast(); - std::vector saved_args; - saved_args.resize(args.size() - 1); - for (size_t i = 0; i < saved_args.size(); ++i) { - saved_args[i] = args[i + 1]; - } - auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); - *rv = VMClosure(clo->func_name, impl); - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.invoke_closure") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments - VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); - ObjectRef vm_closure = args[1].cast(); - vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.call_tir_dyn") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ffi::Function func = args[0].cast(); - ffi::Shape to_unpack = args[args.size() - 1].cast(); - size_t num_tensor_args = args.size() - 2; - - std::vector packed_args(num_tensor_args + to_unpack.size()); - std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); - - for (size_t i = 0; i < to_unpack.size(); ++i) { - packed_args[i + num_tensor_args] = to_unpack[i]; - } - func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("vm.builtin.make_closure", + [](ffi::PackedArgs args, ffi::Any* rv) { + VMClosure clo = args[0].cast(); + std::vector saved_args; + saved_args.resize(args.size() - 1); + for (size_t i = 0; i < saved_args.size(); ++i) { + saved_args[i] = args[i + 1]; + } + auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); + *rv = VMClosure(clo->func_name, impl); + }) + .def_packed("vm.builtin.invoke_closure", + [](ffi::PackedArgs args, ffi::Any* rv) { + // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + ObjectRef vm_closure = args[1].cast(); + vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); + }) + .def_packed("vm.builtin.call_tir_dyn", [](ffi::PackedArgs args, ffi::Any* rv) { + ffi::Function func = args[0].cast(); + ffi::Shape to_unpack = args[args.size() - 1].cast(); + size_t num_tensor_args = args.size() - 2; + + std::vector packed_args(num_tensor_args + to_unpack.size()); + std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); + + for (size_t i = 0; i < to_unpack.size(); ++i) { + packed_args[i + num_tensor_args] = to_unpack[i]; + } + func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); + }); +}); //------------------------------------- // Builtin runtime operators. 
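The bodies above also show the call side of the FFI: call_tir_dyn assembles a vector of AnyView arguments and invokes the target through CallPacked, and the debug helper further below resolves its callee with tvm::ffi::Function::GetGlobal. A minimal caller-side sketch under the same assumptions (demo.add_one is the placeholder registered in the earlier sketch, and the header paths are assumed):

    #include <tvm/ffi/function.h>
    #include <tvm/runtime/logging.h>

    #include <vector>

    // Look up a registered global and invoke it through the packed calling
    // convention, mirroring the call_tir_dyn body above.
    int64_t CallDemoAddOne(int64_t x) {
      const auto func = tvm::ffi::Function::GetGlobal("demo.add_one");
      ICHECK(func.has_value()) << "demo.add_one has not been registered";
      std::vector<tvm::ffi::AnyView> packed_args(1);
      packed_args[0] = x;
      tvm::ffi::Any rv;
      func->CallPacked(tvm::ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv);
      return rv.cast<int64_t>();
    }

Note also the def_method entry for vm.builtin.alloc_tensor above: it binds a member function pointer (&StorageObj::AllocNDArray) directly, so the registered global takes the object as its first argument followed by the method arguments; the KV cache and RNN state registrations later in this diff rely on the same form.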
//------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.copy").set_body_typed([](ffi::Any a) -> ffi::Any { return a; }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.reshape") - .set_body_typed([](NDArray data, ffi::Shape new_shape) { - return data.CreateView(new_shape, data->dtype); - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.null_value").set_body_typed([]() -> std::nullptr_t { - return nullptr; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("vm.builtin.shape_of", &NDArray::Shape) + .def("vm.builtin.copy", [](ffi::Any a) -> ffi::Any { return a; }) + .def("vm.builtin.reshape", + [](NDArray data, ffi::Shape new_shape) { + return data.CreateView(new_shape, data->dtype); + }) + .def("vm.builtin.null_value", []() -> std::nullptr_t { return nullptr; }) + .def("vm.builtin.to_device", [](NDArray data, int dev_type, int dev_id) { + Device dst_device = {(DLDeviceType)dev_type, dev_id}; + return data.CopyTo(dst_device); + }); }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.to_device") - .set_body_typed([](NDArray data, int dev_type, int dev_id) { - Device dst_device = {(DLDeviceType)dev_type, dev_id}; - return data.CopyTo(dst_device); - }); - /*! * \brief Load the scalar value in cond and return the result value. * \param cond The condition @@ -460,106 +492,113 @@ bool ReadIfCond(ffi::AnyView cond) { return result != 0; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.read_if_cond", ReadIfCond); +}); //------------------------------------- // Debugging API //------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) -> void { - ICHECK_GE(args.size(), 3); - int num_args = args.size() - 3; - ObjectRef io_effect = args[0].cast(); - ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be lowered to None."; - String debug_func_name = args[1].cast(); - const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); - CHECK(debug_func.has_value()) << "ValueError: " << debug_func_name << " is not found. " - << "Use the decorator `@tvm.register_func(\"" << debug_func_name - << "\")` to register it."; - String line_info = args[2].cast(); - std::vector call_args(num_args + 1); - { - call_args[0] = line_info; - for (int i = 0; i < num_args; ++i) { - call_args[i + 1] = args[i + 3]; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "vm.builtin.invoke_debug_func", [](ffi::PackedArgs args, ffi::Any* rv) -> void { + ICHECK_GE(args.size(), 3); + int num_args = args.size() - 3; + ObjectRef io_effect = args[0].cast(); + ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be lowered to None."; + String debug_func_name = args[1].cast(); + const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); + CHECK(debug_func.has_value()) << "ValueError: " << debug_func_name << " is not found. 
" + << "Use the decorator `@tvm.register_func(\"" + << debug_func_name << "\")` to register it."; + String line_info = args[2].cast(); + std::vector call_args(num_args + 1); + { + call_args[0] = line_info; + for (int i = 0; i < num_args; ++i) { + call_args[i + 1] = args[i + 3]; + } } - } - debug_func->CallPacked(ffi::PackedArgs(call_args.data(), call_args.size()), rv); - *rv = io_effect; - }); + debug_func->CallPacked(ffi::PackedArgs(call_args.data(), call_args.size()), rv); + *rv = io_effect; + }); +}); //------------------------------------- // Data structure API //------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.tuple_getitem") - .set_body_typed([](Array arr, int64_t index) { return arr[index]; }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.tuple_reset_item") - .set_body_typed([](const ffi::ArrayObj* arr, int64_t index) { - const_cast(arr)->SetItem(index, nullptr); - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_tuple") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Array arr; - for (int i = 0; i < args.size(); ++i) { - arr.push_back(args[i]); - } - *rv = arr; - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { - NDArray arr = data; - if (data->device.device_type != kDLCPU) { - arr = data.CopyTo(DLDevice{kDLCPU, 0}); - } - - ICHECK_EQ(arr->ndim, 1); - ICHECK_EQ(arr->dtype.code, kDLInt); - - std::vector out_shape; - for (int i = 0; i < arr.Shape()[0]; ++i) { - int64_t result; - switch (arr->dtype.bits) { - case 16: { - result = reinterpret_cast(arr->data)[i]; - break; - } - case 32: { - result = reinterpret_cast(arr->data)[i]; - break; - } - case 64: { - result = reinterpret_cast(arr->data)[i]; - break; - } - default: - LOG(FATAL) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); - throw; - } - out_shape.push_back(result); - } - return ffi::Shape(out_shape); -}); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) { - if (data->byte_offset == 0) { - return data; - } - auto* device_api = DeviceAPI::Get(data->device); - if (device_api->SupportsDevicePointerArithmeticsOnHost() && - data->byte_offset % tvm::runtime::kAllocAlignment == 0) { - DLManagedTensor* dl_tensor = data.ToDLPack(); - dl_tensor->dl_tensor.data = - reinterpret_cast(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; - dl_tensor->dl_tensor.byte_offset = 0; - return NDArray::FromDLPack(dl_tensor); - } else { - auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device); - new_array.CopyFrom(data); - return new_array; - } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("vm.builtin.tuple_getitem", + [](Array arr, int64_t index) { return arr[index]; }) + .def("vm.builtin.tuple_reset_item", + [](const ffi::ArrayObj* arr, int64_t index) { + const_cast(arr)->SetItem(index, nullptr); + }) + .def_packed("vm.builtin.make_tuple", + [](ffi::PackedArgs args, ffi::Any* rv) { + Array arr; + for (int i = 0; i < args.size(); ++i) { + arr.push_back(args[i]); + } + *rv = arr; + }) + .def("vm.builtin.tensor_to_shape", + [](NDArray data) { + NDArray arr = data; + if (data->device.device_type != kDLCPU) { + arr = data.CopyTo(DLDevice{kDLCPU, 0}); + } + + ICHECK_EQ(arr->ndim, 1); + ICHECK_EQ(arr->dtype.code, kDLInt); + + std::vector out_shape; + for (int i = 0; i < arr.Shape()[0]; ++i) { + int64_t result; + switch (arr->dtype.bits) { + case 16: { + result = reinterpret_cast(arr->data)[i]; + break; + } + case 32: { + result = 
reinterpret_cast(arr->data)[i]; + break; + } + case 64: { + result = reinterpret_cast(arr->data)[i]; + break; + } + default: + LOG(FATAL) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); + throw; + } + out_shape.push_back(result); + } + return ffi::Shape(out_shape); + }) + .def("vm.builtin.ensure_zero_offset", [](NDArray data) { + if (data->byte_offset == 0) { + return data; + } + auto* device_api = DeviceAPI::Get(data->device); + if (device_api->SupportsDevicePointerArithmeticsOnHost() && + data->byte_offset % tvm::runtime::kAllocAlignment == 0) { + DLManagedTensor* dl_tensor = data.ToDLPack(); + dl_tensor->dl_tensor.data = + reinterpret_cast(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; + dl_tensor->dl_tensor.byte_offset = 0; + return NDArray::FromDLPack(dl_tensor); + } else { + auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device); + new_array.CopyFrom(data); + return new_array; + } + }); }); } // namespace vm diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index fb0534eec5af..88ac79d67c02 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include "../../../support/utils.h" @@ -240,30 +241,33 @@ class CUDAGraphExtension : public VMExtension { } }; -TVM_FFI_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(args.size() == 5 || args.size() == 4); - VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); - auto extension = vm->GetOrCreateExtension(); - auto capture_func = args[1].cast(); - auto func_args = args[2].cast(); - int64_t entry_index = args[3].cast(); - Optional shape_expr = std::nullopt; - if (args.size() == 5) { - shape_expr = args[4].cast(); - } - *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr); - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_EQ(args.size(), 3); - VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); - auto extension = vm->GetOrCreateExtension(); - auto alloc_func = args[1].cast(); - int64_t entry_index = args[2].cast(); - *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("vm.builtin.cuda_graph.run_or_capture", + [](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK(args.size() == 5 || args.size() == 4); + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + auto extension = vm->GetOrCreateExtension(); + auto capture_func = args[1].cast(); + auto func_args = args[2].cast(); + int64_t entry_index = args[3].cast(); + Optional shape_expr = std::nullopt; + if (args.size() == 5) { + shape_expr = args[4].cast(); + } + *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, + shape_expr); + }) + .def_packed("vm.builtin.cuda_graph.get_cached_alloc", [](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK_EQ(args.size(), 3); + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + auto extension = vm->GetOrCreateExtension(); + auto alloc_func = args[1].cast(); + int64_t entry_index = args[2].cast(); + *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index); + }); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 
f33ce6bc0edb..77bda63f8d35 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -22,6 +22,7 @@ */ #include +#include #include #include #include @@ -210,8 +211,11 @@ Module VMExecutable::LoadFromBinary(void* stream) { return Module(exec); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable") - .set_body_typed(VMExecutable::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadbinary_relax.VMExecutable", + VMExecutable::LoadFromBinary); +}); Module VMExecutable::LoadFromFile(const String& file_name) { std::string data; @@ -221,8 +225,10 @@ Module VMExecutable::LoadFromFile(const String& file_name) { return VMExecutable::LoadFromBinary(reinterpret_cast(strm)); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable") - .set_body_typed(VMExecutable::LoadFromFile); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadfile_relax.VMExecutable", VMExecutable::LoadFromFile); +}); void VMFuncInfo::Save(dmlc::Stream* strm) const { int32_t temp_kind = static_cast(kind); @@ -557,7 +563,10 @@ String VMExecutable::AsPython() const { return String(os.str()); } -TVM_FFI_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ExecutableLoadFromFile", VMExecutable::LoadFromFile); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/hexagon/builtin.cc b/src/runtime/vm/hexagon/builtin.cc index c48445130cf0..79909199e6b2 100644 --- a/src/runtime/vm/hexagon/builtin.cc +++ b/src/runtime/vm/hexagon/builtin.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -31,41 +32,44 @@ namespace tvm { namespace runtime { namespace vm { -TVM_FFI_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") - .set_body_typed([](ffi::AnyView vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, - bool bypass_cache) { - const DLTensor* dptr = dst_arr.operator->(); - const DLTensor* sptr = src_arr.operator->(); - void* dst = dptr->data; - void* src = sptr->data; - int ret = DMA_RETRY; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("vm.builtin.hexagon.dma_copy", + [](ffi::AnyView vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, + bool bypass_cache) { + const DLTensor* dptr = dst_arr.operator->(); + const DLTensor* sptr = src_arr.operator->(); + void* dst = dptr->data; + void* src = sptr->data; + int ret = DMA_RETRY; - CHECK_EQ(GetDataSize(*dptr), GetDataSize(*sptr)); - auto size = GetDataSize(*dptr); - ICHECK(size > 0); - if (bypass_cache) - qurt_mem_cache_clean(reinterpret_cast(src), size, QURT_MEM_CACHE_INVALIDATE, - QURT_MEM_DCACHE); - do { - ret = tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Copy( - queue_id, dst, src, size, bypass_cache); - } while (ret == DMA_RETRY); - CHECK(ret == DMA_SUCCESS); - }); - -TVM_FFI_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") - .set_body_typed([](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, - [[maybe_unused]] NDArray src_arr, [[maybe_unused]] NDArray dst_arr) { - ICHECK(inflight_dma >= 0); - tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight_dma); - if (bypass_cache) { - const DLTensor* dptr = dst_arr.operator->(); - void* dst = dptr->data; - auto size = GetDataSize(*dptr); - 
qurt_mem_cache_clean(reinterpret_cast(dst), size, QURT_MEM_CACHE_FLUSH, - QURT_MEM_DCACHE); - } - }); + CHECK_EQ(GetDataSize(*dptr), GetDataSize(*sptr)); + auto size = GetDataSize(*dptr); + ICHECK(size > 0); + if (bypass_cache) + qurt_mem_cache_clean(reinterpret_cast(src), size, + QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); + do { + ret = tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Copy( + queue_id, dst, src, size, bypass_cache); + } while (ret == DMA_RETRY); + CHECK(ret == DMA_SUCCESS); + }) + .def("vm.builtin.hexagon.dma_wait", [](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, + bool bypass_cache, [[maybe_unused]] NDArray src_arr, + [[maybe_unused]] NDArray dst_arr) { + ICHECK(inflight_dma >= 0); + tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight_dma); + if (bypass_cache) { + const DLTensor* dptr = dst_arr.operator->(); + void* dst = dptr->data; + auto size = GetDataSize(*dptr); + qurt_mem_cache_clean(reinterpret_cast(dst), size, QURT_MEM_CACHE_FLUSH, + QURT_MEM_DCACHE); + } + }); +}); } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/kv_state.cc b/src/runtime/vm/kv_state.cc index 466d41e3d31e..7ca88f4ab80b 100644 --- a/src/runtime/vm/kv_state.cc +++ b/src/runtime/vm/kv_state.cc @@ -19,6 +19,8 @@ #include "kv_state.h" +#include + #include namespace tvm { @@ -31,88 +33,96 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj); TVM_REGISTER_OBJECT_TYPE(RNNStateObj); // KV State base methods -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method(&KVStateObj::Clear); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence") - .set_body_method(&KVStateObj::AddSequence); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence") - .set_body_method(&KVStateObj::RemoveSequence); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence") - .set_body_method(&KVStateObj::ForkSequence); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - CHECK(args.size() == 3 || args.size() == 4) - << "KVState BeginForward only accepts 3 or 4 arguments"; - KVState kv_state = args[0].cast(); - ffi::Shape seq_ids = args[1].cast(); - ffi::Shape append_lengths = args[2].cast(); - Optional token_tree_parent_ptr; - if (args.size() == 4) { - token_tree_parent_ptr = args[3].cast>(); - } - kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward").set_body_method(&KVStateObj::EndForward); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("vm.builtin.kv_state_clear", &KVStateObj::Clear) + .def_method("vm.builtin.kv_state_add_sequence", &KVStateObj::AddSequence) + .def_method("vm.builtin.kv_state_remove_sequence", &KVStateObj::RemoveSequence) + .def_method("vm.builtin.kv_state_fork_sequence", &KVStateObj::ForkSequence) + .def_method("vm.builtin.kv_state_popn", &KVStateObj::PopN) + .def_packed("vm.builtin.kv_state_begin_forward", + [](ffi::PackedArgs args, ffi::Any* rv) { + CHECK(args.size() == 3 || args.size() == 4) + << "KVState BeginForward only accepts 3 or 4 arguments"; + KVState kv_state = args[0].cast(); + ffi::Shape seq_ids = args[1].cast(); + ffi::Shape append_lengths = args[2].cast(); + Optional token_tree_parent_ptr; + if (args.size() == 4) { + token_tree_parent_ptr = args[3].cast>(); + } + 
kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); + }) + .def_method("vm.builtin.kv_state_end_forward", &KVStateObj::EndForward); +}); // Attention KV Cache methods -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv") - .set_body_method(&AttentionKVCacheObj::DisaggPrepareRecv); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send") - .set_body_method(&AttentionKVCacheObj::DisaggMarkSend); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") - .set_body_method(&AttentionKVCacheObj::EnableSlidingWindowForSeq); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") - .set_body_method(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") - .set_body_method(&AttentionKVCacheObj::Empty); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") - .set_body_method(&AttentionKVCacheObj::GetNumAvailablePages); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length") - .set_body_method(&AttentionKVCacheObj::GetTotalSequenceLength); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions") - .set_body_method(&AttentionKVCacheObj::GetQueryPositions); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") - .set_body_method(&AttentionKVCacheObj::DebugGetKV); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") - .set_body_method(&AttentionKVCacheObj::DebugGetKVMLA); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, - NDArray qkv_data, NDArray o_data) { - kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), std::nullopt, - std::move(o_data), sm_scale); - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_self_attention") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, - NDArray k_data, NDArray v_data, NDArray o_data, NDArray lse_data) { - kv_cache->SelfAttention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), - std::move(o_data), std::move(lse_data), sm_scale); - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_cross_attention") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, - NDArray o_data, NDArray lse_data) { - kv_cache->CrossAttention(layer_id, std::move(q_data), std::move(o_data), std::move(lse_data), - sm_scale); - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append_mla_kv") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, NDArray kv_data) { - kv_cache->AppendMLAKV(layer_id, std::move(kv_data)); - return kv_cache; - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") - .set_body_typed([](AttentionKVCache kv_cache, NDArray o_self_attn, NDArray lse_self_attn, - NDArray o_cross_attn, NDArray lse_cross_attn) { - return kv_cache->MergeAttnOutputInplace(std::move(o_self_attn), std::move(lse_self_attn), - std::move(o_cross_attn), std::move(lse_cross_attn)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("vm.builtin.kv_cache_disagg_prepare_recv", + &AttentionKVCacheObj::DisaggPrepareRecv) + .def_method("vm.builtin.kv_cache_disagg_mark_send", &AttentionKVCacheObj::DisaggMarkSend) + 
.def_method("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq", + &AttentionKVCacheObj::EnableSlidingWindowForSeq) + .def_method("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes", + &AttentionKVCacheObj::CommitAcceptedTokenTreeNodes) + .def_method("vm.builtin.attention_kv_cache_empty", &AttentionKVCacheObj::Empty) + .def_method("vm.builtin.attention_kv_cache_get_num_available_pages", + &AttentionKVCacheObj::GetNumAvailablePages) + .def_method("vm.builtin.attention_kv_cache_get_total_sequence_length", + &AttentionKVCacheObj::GetTotalSequenceLength) + .def_method("vm.builtin.attention_kv_cache_get_query_positions", + &AttentionKVCacheObj::GetQueryPositions) + .def_method("vm.builtin.attention_kv_cache_debug_get_kv", &AttentionKVCacheObj::DebugGetKV) + .def_method("vm.builtin.attention_kv_cache_debug_get_kv_mla", + &AttentionKVCacheObj::DebugGetKVMLA) + .def("vm.builtin.attention_kv_cache_attention_with_fused_qkv", + [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray qkv_data, + NDArray o_data) { + kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), std::nullopt, + std::move(o_data), sm_scale); + }) + .def("vm.builtin.attention_kv_cache_self_attention", + [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, + NDArray k_data, NDArray v_data, NDArray o_data, NDArray lse_data) { + kv_cache->SelfAttention(layer_id, std::move(q_data), std::move(k_data), + std::move(v_data), std::move(o_data), std::move(lse_data), + sm_scale); + }) + .def("vm.builtin.attention_kv_cache_cross_attention", + [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, + NDArray o_data, NDArray lse_data) { + kv_cache->CrossAttention(layer_id, std::move(q_data), std::move(o_data), + std::move(lse_data), sm_scale); + }) + .def("vm.builtin.attention_kv_cache_append_mla_kv", + [](AttentionKVCache kv_cache, int64_t layer_id, NDArray kv_data) { + kv_cache->AppendMLAKV(layer_id, std::move(kv_data)); + return kv_cache; + }) + .def("vm.builtin.attention_kv_cache_merge_attn_output_inplace", + [](AttentionKVCache kv_cache, NDArray o_self_attn, NDArray lse_self_attn, + NDArray o_cross_attn, NDArray lse_cross_attn) { + return kv_cache->MergeAttnOutputInplace( + std::move(o_self_attn), std::move(lse_self_attn), std::move(o_cross_attn), + std::move(lse_cross_attn)); + }); +}); // RNN State methods -TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method(&RNNStateObj::Get); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_set") - .set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) { - state->Set(layer_id, state_id, data); - return state; - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get").set_body_method(&RNNStateObj::DebugGet); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("vm.builtin.rnn_state_get", &RNNStateObj::Get) + .def("vm.builtin.rnn_state_set", + [](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) { + state->Set(layer_id, state_id, data); + return state; + }) + .def_method("vm.builtin.rnn_state_debug_get", &RNNStateObj::DebugGet); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index baf5bf480e37..d510bed0c119 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -259,24 +260,30 @@ 
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj); //------------------------------------------------- // Register runtime functions //------------------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create") - .set_body_typed(AttentionKVCacheLegacy::Create); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_create", AttentionKVCacheLegacy::Create); +}); AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, NDArray value) { cache->Update(value); return cache; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update") - .set_body_typed(AttentionKVCacheUpdate); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_update", AttentionKVCacheUpdate); +}); AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, NDArray value) { cache->Append(value); return cache; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append") - .set_body_typed(AttentionKVCacheAppend); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_append", AttentionKVCacheAppend); +}); AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cache, NDArray value, int64_t max_cache_size) { @@ -284,8 +291,11 @@ AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cac return cache; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override") - .set_body_typed(AttentionKVCacheWindowOverride); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_window_override", + AttentionKVCacheWindowOverride); +}); AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache, NDArray value, @@ -295,31 +305,37 @@ AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheL return cache; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks") - .set_body_typed(AttentionKVCacheWindowOverrideWithSinks); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_window_override_with_sinks", + AttentionKVCacheWindowOverrideWithSinks); +}); NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ffi::Shape shape) { return cache->View(shape); } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - CHECK(args.size() == 1 || args.size() == 2) - << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " - << args.size() << "."; - AttentionKVCacheLegacy cache = args[0].cast(); - if (args.size() == 2) { - ffi::Shape shape = args[1].cast(); - *rv = cache->View(shape); - } else { - std::vector shape; - shape.push_back(cache->fill_count); - for (int i = 1; i < cache->data->ndim; ++i) { - shape.push_back(cache->data->shape[i]); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "vm.builtin.attention_kv_cache_view", [](ffi::PackedArgs args, ffi::Any* rv) { + CHECK(args.size() == 1 || args.size() == 2) + << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " + << args.size() << "."; + AttentionKVCacheLegacy cache = args[0].cast(); + if (args.size() == 2) { + ffi::Shape shape = 
args[1].cast(); + *rv = cache->View(shape); + } else { + std::vector shape; + shape.push_back(cache->fill_count); + for (int i = 1; i < cache->data->ndim; ++i) { + shape.push_back(cache->data->shape[i]); + } + *rv = cache->View(ffi::Shape(shape)); } - *rv = cache->View(ffi::Shape(shape)); - } - }); + }); +}); void AttentionKVCacheArrayPopN(Array caches, int64_t n) { for (AttentionKVCacheLegacy cache : caches) { @@ -327,8 +343,10 @@ void AttentionKVCacheArrayPopN(Array caches, int64_t n) } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn") - .set_body_typed(AttentionKVCacheArrayPopN); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_popn", AttentionKVCacheArrayPopN); +}); void AttentionKVCacheArrayClear(Array caches) { for (AttentionKVCacheLegacy cache : caches) { @@ -336,8 +354,10 @@ void AttentionKVCacheArrayClear(Array caches) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear") - .set_body_typed(AttentionKVCacheArrayClear); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_clear", AttentionKVCacheArrayClear); +}); // NOTE this is a built-in highly related to LM so we put it here. int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, double uniform_sample) { @@ -401,7 +421,10 @@ int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, doubl return data[0].second; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.sample_top_p_from_logits", SampleTopPFromLogits); +}); int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { ICHECK(prob.IsContiguous()); @@ -496,7 +519,10 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { return sampled_index; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.sample_top_p_from_prob", SampleTopPFromProb); +}); NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { ICHECK(prob.IsContiguous()); @@ -533,8 +559,10 @@ NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { return new_array; } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform") - .set_body_typed(MultinomialFromUniform); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.multinomial_from_uniform", MultinomialFromUniform); +}); // This is an inplace operation. void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { @@ -557,8 +585,10 @@ void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty") - .set_body_typed(ApplyRepetitionPenalty); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.apply_repetition_penalty", ApplyRepetitionPenalty); +}); /*! * \brief Apply presence and frequency penalty. This is an inplace operation. 
@@ -593,8 +623,11 @@ void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty") - .set_body_typed(ApplyPresenceAndFrequencyPenalty); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.apply_presence_and_frequency_penalty", + ApplyPresenceAndFrequencyPenalty); +}); // This is an inplace operation. void ApplySoftmaxWithTemperature(NDArray logits, double temperature) { @@ -618,8 +651,10 @@ void ApplySoftmaxWithTemperature(NDArray logits, double temperature) { } } -TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature") - .set_body_typed(ApplySoftmaxWithTemperature); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.apply_softmax_with_temperature", ApplySoftmaxWithTemperature); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/ndarray_cache_support.cc b/src/runtime/vm/ndarray_cache_support.cc index 17f07df70c01..279091dec519 100644 --- a/src/runtime/vm/ndarray_cache_support.cc +++ b/src/runtime/vm/ndarray_cache_support.cc @@ -40,6 +40,7 @@ #endif #include #include +#include #include #include @@ -266,33 +267,38 @@ class NDArrayCache { Map pool_; }; -TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - CHECK(args.size() == 2 || args.size() == 3); - String name = args[0].cast(); - bool is_override = args.size() == 2 ? false : args[2].cast(); - - NDArray arr; - if (auto opt_nd = args[1].as()) { - arr = opt_nd.value(); - } else { - // We support converting DLTensors to NDArrays as RPC references are always DLTensors - auto tensor = args[1].cast(); - std::vector shape; - for (int64_t i = 0; i < tensor->ndim; i++) { - shape.push_back(tensor->shape[i]); - } - arr = NDArray::Empty(shape, tensor->dtype, tensor->device); - arr.CopyFrom(tensor); - DeviceAPI::Get(arr->device)->StreamSync(arr->device, nullptr); - } - - NDArrayCache::Update(name, arr, is_override); - }); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("vm.builtin.ndarray_cache.get", NDArrayCache::Get) + .def_packed("vm.builtin.ndarray_cache.update", + [](ffi::PackedArgs args, ffi::Any* rv) { + CHECK(args.size() == 2 || args.size() == 3); + String name = args[0].cast(); + bool is_override = args.size() == 2 ? 
false : args[2].cast(); + + NDArray arr; + if (auto opt_nd = args[1].as()) { + arr = opt_nd.value(); + } else { + // We support converting DLTensors to NDArrays as RPC references are always + // DLTensors + auto tensor = args[1].cast(); + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + arr = NDArray::Empty(shape, tensor->dtype, tensor->device); + arr.CopyFrom(tensor); + DeviceAPI::Get(arr->device)->StreamSync(arr->device, nullptr); + } + + NDArrayCache::Update(name, arr, is_override); + }) + .def("vm.builtin.ndarray_cache.remove", NDArrayCache::Remove) + .def("vm.builtin.ndarray_cache.clear", NDArrayCache::Clear) + .def("vm.builtin.ndarray_cache.load", NDArrayCache::Load); +}); // This param module node can be useful to get param dict in RPC mode // when the remote already have loaded parameters from file. @@ -353,27 +359,27 @@ class ParamModuleNode : public runtime::ModuleNode { Array params_; }; -TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_module_from_cache") - .set_body_typed(ParamModuleNode::Create); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name") - .set_body_typed(ParamModuleNode::CreateByName); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache") - .set_body_typed(ParamModuleNode::GetParams); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") - .set_body_typed(ParamModuleNode::GetParamByName); -TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Array names; - names.reserve(args.size()); - for (int i = 0; i < args.size(); ++i) { - if (!args[i].try_cast()) { - LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].GetTypeKey() - << " at " << i; - } - names.push_back(args[i].cast()); - } - *rv = ParamModuleNode::GetParamByName(names); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("vm.builtin.param_module_from_cache", ParamModuleNode::Create) + .def("vm.builtin.param_module_from_cache_by_name", ParamModuleNode::CreateByName) + .def("vm.builtin.param_array_from_cache", ParamModuleNode::GetParams) + .def("vm.builtin.param_array_from_cache_by_name", ParamModuleNode::GetParamByName) + .def_packed("vm.builtin.param_array_from_cache_by_name_unpacked", + [](ffi::PackedArgs args, ffi::Any* rv) { + Array names; + names.reserve(args.size()); + for (int i = 0; i < args.size(); ++i) { + if (!args[i].try_cast()) { + LOG(FATAL) << "ValueError: Expect string as input, but get " + << args[i].GetTypeKey() << " at " << i; + } + names.push_back(args[i].cast()); + } + *rv = ParamModuleNode::GetParamByName(names); + }); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 2af4b19b06b1..a0bc2cb1a7d4 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -21,6 +21,7 @@ * \brief Runtime paged KV cache object for language models. 
*/ #include +#include #include #include #include @@ -2435,108 +2436,111 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); // Register runtime functions //------------------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - // Todo: cuda graph arg - CHECK(args.size() == 28 || args.size() == 29) - << "Invalid number of KV cache constructor args: " << args.size(); - ffi::Shape cache_config = args[0].cast(); - ffi::Shape layer_indptr_tuple = args[1].cast(); - int num_groups = 1; - int group_id = 0; - if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { - // In the Disco worker thread - num_groups = disco_worker->num_groups; - group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); - } - CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); - int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; - int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; - int64_t layer_id_end_offset = layer_indptr_tuple[group_id + 1]; - int64_t num_qo_heads = args[2].cast(); - int64_t num_kv_heads = args[3].cast(); - int64_t qk_head_dim = args[4].cast(); - int64_t v_head_dim = args[5].cast(); - ffi::Shape attn_kinds = args[6].cast(); - bool enable_kv_transfer = args[7].cast(); - int rope_mode = args[8].cast(); - double rotary_scale = args[9].cast(); - double rotary_theta = args[10].cast(); - Optional rope_ext_factors = std::nullopt; // args[11] - NDArray init = args[12].cast(); - Optional f_transpose_append_mha = std::nullopt; // args[13] - Optional f_transpose_append_mla = std::nullopt; // args[14] - std::unique_ptr f_attention_prefill_ragged = - ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); - std::unique_ptr f_attention_prefill = - ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); - std::unique_ptr f_attention_decode = - ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); - std::unique_ptr f_attention_prefill_sliding_window = - ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); - std::unique_ptr f_attention_decode_sliding_window = - ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); - std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv = - ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); - std::unique_ptr f_attention_prefill_with_tree_mask = - ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); - std::unique_ptr f_mla_prefill = - ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); - Array f_merge_inplace = args[23].cast>(); - ffi::Function f_split_rotary = args[24].cast(); - ffi::Function f_copy_single_page = args[25].cast(); - ffi::Function f_debug_get_kv = args[26].cast(); - ffi::Function f_compact_copy = args[27].cast(); - - if (auto opt_nd = args[11].as()) { - rope_ext_factors = opt_nd.value(); - } - auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { - if (auto opt_func = args[arg_idx].as()) { - return opt_func.value(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "vm.builtin.paged_attention_kv_cache_create", [](ffi::PackedArgs args, ffi::Any* rv) { + // Todo: cuda graph arg + CHECK(args.size() == 28 || args.size() == 29) + << "Invalid number of KV cache constructor args: " << args.size(); + ffi::Shape cache_config = args[0].cast(); + ffi::Shape layer_indptr_tuple = args[1].cast(); + int num_groups = 1; + int group_id = 0; + if 
(DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { + // In the Disco worker thread + num_groups = disco_worker->num_groups; + group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); + } + CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); + int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; + int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; + int64_t layer_id_end_offset = layer_indptr_tuple[group_id + 1]; + int64_t num_qo_heads = args[2].cast(); + int64_t num_kv_heads = args[3].cast(); + int64_t qk_head_dim = args[4].cast(); + int64_t v_head_dim = args[5].cast(); + ffi::Shape attn_kinds = args[6].cast(); + bool enable_kv_transfer = args[7].cast(); + int rope_mode = args[8].cast(); + double rotary_scale = args[9].cast(); + double rotary_theta = args[10].cast(); + Optional rope_ext_factors = std::nullopt; // args[11] + NDArray init = args[12].cast(); + Optional f_transpose_append_mha = std::nullopt; // args[13] + Optional f_transpose_append_mla = std::nullopt; // args[14] + std::unique_ptr f_attention_prefill_ragged = + ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); + std::unique_ptr f_attention_prefill = + ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); + std::unique_ptr f_attention_decode = + ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); + std::unique_ptr f_attention_prefill_sliding_window = + ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); + std::unique_ptr f_attention_decode_sliding_window = + ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); + std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv = + ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); + std::unique_ptr f_attention_prefill_with_tree_mask = + ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); + std::unique_ptr f_mla_prefill = + ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); + Array f_merge_inplace = args[23].cast>(); + ffi::Function f_split_rotary = args[24].cast(); + ffi::Function f_copy_single_page = args[25].cast(); + ffi::Function f_debug_get_kv = args[26].cast(); + ffi::Function f_compact_copy = args[27].cast(); + + if (auto opt_nd = args[11].as()) { + rope_ext_factors = opt_nd.value(); + } + auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { + if (auto opt_func = args[arg_idx].as()) { + return opt_func.value(); + } + return std::nullopt; + }; + f_transpose_append_mha = f_convert_optional_packed_func(13); + f_transpose_append_mla = f_convert_optional_packed_func(14); + CHECK(!f_merge_inplace.empty()) << "Merge inplace function is not defined."; + + std::vector attn_kinds_vec; + attn_kinds_vec.reserve(attn_kinds.size()); + for (int64_t attn_kind : attn_kinds) { + attn_kinds_vec.push_back(static_cast(attn_kind)); } - return std::nullopt; - }; - f_transpose_append_mha = f_convert_optional_packed_func(13); - f_transpose_append_mla = f_convert_optional_packed_func(14); - CHECK(!f_merge_inplace.empty()) << "Merge inplace function is not defined."; - - std::vector attn_kinds_vec; - attn_kinds_vec.reserve(attn_kinds.size()); - for (int64_t attn_kind : attn_kinds) { - attn_kinds_vec.push_back(static_cast(attn_kind)); - } - CHECK_EQ(cache_config.size(), 5); - int64_t reserved_num_seqs = cache_config[0]; - int64_t total_token_capacity = cache_config[1]; - int64_t prefill_chunk_size = cache_config[2]; - int64_t page_size = cache_config[3]; - bool support_sliding_window = cache_config[4]; - int64_t num_total_pages 
= (total_token_capacity + page_size - 1) / page_size + 1; - if (support_sliding_window) { - // When sliding window is enabled, each sequence may use two more pages at most. - num_total_pages += reserved_num_seqs * 2; - } - // NOTE: We will remove this legacy construction after finishing the transition phase. - // Some `ffi::Function()` here are placeholders that will be filled. - ObjectPtr n = make_object( - page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads, - num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, - prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, - rotary_theta, std::move(rope_ext_factors), enable_kv_transfer, // - init->dtype, init->device, // - std::move(f_transpose_append_mha), std::move(f_transpose_append_mla), - std::move(f_compact_copy), std::move(f_attention_prefill_ragged), - std::move(f_attention_prefill), std::move(f_attention_decode), - std::move(f_attention_prefill_sliding_window), - std::move(f_attention_decode_sliding_window), - std::move(f_attention_prefill_with_tree_mask_paged_kv), // - std::move(f_attention_prefill_with_tree_mask), // - std::move(f_mla_prefill), std::move(f_merge_inplace), std::move(f_split_rotary), - std::move(f_copy_single_page), std::move(f_debug_get_kv)); - *rv = AttentionKVCache(std::move(n)); - }); + CHECK_EQ(cache_config.size(), 5); + int64_t reserved_num_seqs = cache_config[0]; + int64_t total_token_capacity = cache_config[1]; + int64_t prefill_chunk_size = cache_config[2]; + int64_t page_size = cache_config[3]; + bool support_sliding_window = cache_config[4]; + int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1; + if (support_sliding_window) { + // When sliding window is enabled, each sequence may use two more pages at most. + num_total_pages += reserved_num_seqs * 2; + } + // NOTE: We will remove this legacy construction after finishing the transition phase. + // Some `ffi::Function()` here are placeholders that will be filled. + ObjectPtr n = make_object( + page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads, + num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs, + num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), + rotary_scale, rotary_theta, std::move(rope_ext_factors), enable_kv_transfer, // + init->dtype, init->device, // + std::move(f_transpose_append_mha), std::move(f_transpose_append_mla), + std::move(f_compact_copy), std::move(f_attention_prefill_ragged), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), + std::move(f_attention_decode_sliding_window), + std::move(f_attention_prefill_with_tree_mask_paged_kv), // + std::move(f_attention_prefill_with_tree_mask), // + std::move(f_mla_prefill), std::move(f_merge_inplace), std::move(f_split_rotary), + std::move(f_copy_single_page), std::move(f_debug_get_kv)); + *rv = AttentionKVCache(std::move(n)); + }); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index e7a655d392f8..40d0f6318439 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -21,6 +21,8 @@ * \brief Runtime RNN state object for space state models. 
*/ +#include + #include #include @@ -464,36 +466,37 @@ TVM_REGISTER_OBJECT_TYPE(RNNStateImpObj); // Register runtime functions //------------------------------------------------- -TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_create") - .set_body_typed([](int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - Array f_gets, // - Array f_sets, // - Array init_layer_value) { - CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; - CHECK_GT(reserved_num_seqs, 0) - << "The number of reserved sequences should be greater than 0."; - CHECK_GE(max_history, 0) << "The maximum history length should be greater or equal than 0."; - CHECK_GT(init_layer_value.size(), 0) - << "The number of states per layer should be greater than 0."; - Device device = init_layer_value[0]->device; - for (const NDArray& state : init_layer_value) { - CHECK(state->device.device_type == device.device_type && - state->device.device_id == device.device_id) - << "The device type of all states should be the same."; - } - CHECK_EQ(f_gets.size(), init_layer_value.size()) - << "The number of state getters should be the same as the number of states per layer, " - << "but got " << f_gets.size() << " and " << init_layer_value.size() << " respectively."; - CHECK_EQ(f_sets.size(), init_layer_value.size()) - << "The number of state setters should be the same as the number of states per layer, " - << "but got " << f_sets.size() << " and " << init_layer_value.size() << " respectively."; - ObjectPtr n = - make_object(num_layers, reserved_num_seqs, max_history, device, - std::move(f_gets), std::move(f_sets), init_layer_value); - return RNNState(std::move(n)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("vm.builtin.rnn_state_create", [](int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + Array f_gets, // + Array f_sets, // + Array init_layer_value) { + CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; + CHECK_GT(reserved_num_seqs, 0) << "The number of reserved sequences should be greater than 0."; + CHECK_GE(max_history, 0) << "The maximum history length should be greater or equal than 0."; + CHECK_GT(init_layer_value.size(), 0) + << "The number of states per layer should be greater than 0."; + Device device = init_layer_value[0]->device; + for (const NDArray& state : init_layer_value) { + CHECK(state->device.device_type == device.device_type && + state->device.device_id == device.device_id) + << "The device type of all states should be the same."; + } + CHECK_EQ(f_gets.size(), init_layer_value.size()) + << "The number of state getters should be the same as the number of states per layer, " + << "but got " << f_gets.size() << " and " << init_layer_value.size() << " respectively."; + CHECK_EQ(f_sets.size(), init_layer_value.size()) + << "The number of state setters should be the same as the number of states per layer, " + << "but got " << f_sets.size() << " and " << init_layer_value.size() << " respectively."; + ObjectPtr n = + make_object(num_layers, reserved_num_seqs, max_history, device, + std::move(f_gets), std::move(f_sets), init_layer_value); + return RNNState(std::move(n)); + }); +}); } // namespace vm } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 12181f8c159d..322a7f467f2d 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -19,6 +19,8 @@ #include 
"vulkan_device_api.h" +#include + #include #include #include @@ -455,18 +457,20 @@ VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { return const_cast(const_cast(this)->device(device_id)); } -TVM_FFI_REGISTER_GLOBAL("device_api.vulkan") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = VulkanDeviceAPI::Global(); - *rv = static_cast(ptr); - }); - -TVM_FFI_REGISTER_GLOBAL("device_api.vulkan.get_target_property") - .set_body_typed([](Device dev, const std::string& property) { - ffi::Any rv; - VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); - return rv; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("device_api.vulkan", + [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = VulkanDeviceAPI::Global(); + *rv = static_cast(ptr); + }) + .def("device_api.vulkan.get_target_property", [](Device dev, const std::string& property) { + ffi::Any rv; + VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); + return rv; + }); +}); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 063dc5bde009..55982a3c05db 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -21,6 +21,7 @@ #include #include +#include #include "../file_utils.h" #include "vulkan_wrapped_func.h" @@ -64,9 +65,12 @@ Module VulkanModuleLoadBinary(void* strm) { return VulkanModuleCreate(smap, fmap, ""); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.module.loadfile_vulkan", VulkanModuleLoadFile) + .def("runtime.module.loadbinary_vulkan", VulkanModuleLoadBinary); +}); } // namespace vulkan } // namespace runtime diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 20808635ed0c..8f69f33e6d39 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include #include @@ -106,24 +107,20 @@ void Namer::Name(ObjectRef node, String name) { TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode); TVM_REGISTER_NODE_TYPE(IRBuilderNode); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") - .set_body_method(&IRBuilderFrameNode::EnterWithScope); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") - .set_body_method(&IRBuilderFrameNode::ExitWithScope); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") - .set_body_method(&IRBuilderFrameNode::AddCallback); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter") - .set_body_method(&IRBuilder::EnterWithScope); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit") - .set_body_method(&IRBuilder::ExitWithScope); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope") - .set_body_typed(IRBuilder::IsInScope); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") - .set_body_method(&IRBuilderNode::Get); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderName") - .set_body_typed(IRBuilder::Name); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("script.ir_builder.IRBuilderFrameEnter", &IRBuilderFrameNode::EnterWithScope) + .def_method("script.ir_builder.IRBuilderFrameExit", &IRBuilderFrameNode::ExitWithScope) + .def_method("script.ir_builder.IRBuilderFrameAddCallback", &IRBuilderFrameNode::AddCallback) + .def("script.ir_builder.IRBuilder", []() { return IRBuilder(); }) + .def_method("script.ir_builder.IRBuilderEnter", &IRBuilder::EnterWithScope) + .def_method("script.ir_builder.IRBuilderExit", &IRBuilder::ExitWithScope) + .def("script.ir_builder.IRBuilderCurrent", IRBuilder::Current) + .def("script.ir_builder.IRBuilderIsInScope", IRBuilder::IsInScope) + .def_method("script.ir_builder.IRBuilderGet", &IRBuilderNode::Get) + .def("script.ir_builder.IRBuilderName", IRBuilder::Name); +}); } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 111f2beae328..30ee9e987f2e 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include #include #include @@ -158,14 +159,18 @@ VDevice LookupVDevice(String target_kind, int device_index) { return VDevice(); } -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetAttr").set_body_typed(ModuleGetAttr); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleSetAttr").set_body_typed(ModuleSetAttr); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.ir.IRModule", IRModule) + .def("script.ir_builder.ir.DeclFunction", DeclFunction) + .def("script.ir_builder.ir.DefFunction", DefFunction) + .def("script.ir_builder.ir.ModuleAttrs", ModuleAttrs) + .def("script.ir_builder.ir.ModuleGetAttr", ModuleGetAttr) + .def("script.ir_builder.ir.ModuleSetAttr", ModuleSetAttr) + .def("script.ir_builder.ir.ModuleGlobalInfos", ModuleGlobalInfos) + .def("script.ir_builder.ir.LookupVDevice", LookupVDevice); +}); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc index fcf9e0eb2c5b..dd8bb41e41f6 100644 --- a/src/script/ir_builder/relax/distributed.cc +++ b/src/script/ir_builder/relax/distributed.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include #include #include #include @@ -54,8 +55,10 @@ Expr MakeCallTIRDist(Expr func, Tuple args, Array #include #include #include @@ -144,13 +145,16 @@ void FuncRetValue(const tvm::relax::Expr& value) { frame->output = std::move(normalized_value); } -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo") - .set_body_typed(FuncRetStructInfo); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.relax.Function", Function) + .def("script.ir_builder.relax.Arg", Arg) + .def("script.ir_builder.relax.FuncName", FuncName) + .def("script.ir_builder.relax.FuncAttrs", FuncAttrs) + .def("script.ir_builder.relax.FuncRetStructInfo", FuncRetStructInfo) + .def("script.ir_builder.relax.FuncRetValue", FuncRetValue); +}); ///////////////////////////// BindingBlock ////////////////////////////// @@ -192,10 +196,13 @@ void DataflowBlockOutput(const Array& vars) { } } -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") - .set_body_typed(DataflowBlockOutput); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.relax.Dataflow", Dataflow) + .def("script.ir_builder.relax.BindingBlock", BindingBlock) + .def("script.ir_builder.relax.DataflowBlockOutput", DataflowBlockOutput); +}); /////////////////////////////// Bindings /////////////////////////////// @@ -237,9 +244,13 @@ tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { return binding->var; } -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.relax.Emit", Emit) + .def("script.ir_builder.relax.EmitMatchCast", EmitMatchCast) + .def("script.ir_builder.relax.EmitVarBinding", EmitVarBinding); +}); /////////////////////////////// SeqExpr /////////////////////////////// @@ -248,7 +259,10 @@ SeqExprFrame SeqExpr() { return SeqExprFrame(n); } -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.SeqExpr").set_body_typed(SeqExpr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.relax.SeqExpr", SeqExpr); +}); ///////////////////////////// If Then Else ///////////////////////////// @@ -270,9 +284,13 @@ ElseFrame Else() { return ElseFrame(n); } -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + 
.def("script.ir_builder.relax.If", If) + .def("script.ir_builder.relax.Then", Then) + .def("script.ir_builder.relax.Else", Else); +}); } // namespace relax } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 7ef970fa0971..66852a20adda 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include "./utils.h" @@ -688,76 +689,74 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Arg") - .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { - using namespace tvm::tir; - if (auto var = obj.as()) { - return Arg(name, var.value()); - } - if (auto buffer = obj.as()) { - return Arg(name, buffer.value()); - } - LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); - throw; - }); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); 
-TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") - .set_body_typed([](ffi::Variant thread_tag_or_var, PrimExpr extent) { - if (auto var = thread_tag_or_var.as()) { - return LaunchThread(var.value(), extent); - } else if (auto str = thread_tag_or_var.as()) { - return LaunchThread(str.value(), extent); - } else { - LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " - << thread_tag_or_var.GetTypeKey(); - throw; - } - }); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.tir.Buffer", BufferDecl) + .def("script.ir_builder.tir.PrimFunc", PrimFunc) + .def("script.ir_builder.tir.Arg", + [](String name, ObjectRef obj) -> ObjectRef { + using namespace tvm::tir; + if (auto var = obj.as()) { + return Arg(name, var.value()); + } + if (auto buffer = obj.as()) { + return Arg(name, buffer.value()); + } + LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); + throw; + }) + .def("script.ir_builder.tir.FuncName", FuncName) + .def("script.ir_builder.tir.FuncAttrs", FuncAttrs) + .def("script.ir_builder.tir.FuncRet", FuncRet) + .def("script.ir_builder.tir.MatchBuffer", MatchBuffer) + .def("script.ir_builder.tir.Block", Block) + .def("script.ir_builder.tir.Init", Init) + .def("script.ir_builder.tir.Where", Where) + .def("script.ir_builder.tir.Reads", Reads) + .def("script.ir_builder.tir.Writes", Writes) + .def("script.ir_builder.tir.BlockAttrs", BlockAttrs) + .def("script.ir_builder.tir.AllocBuffer", AllocBuffer) + .def("script.ir_builder.tir.AxisSpatial", axis::Spatial) + .def("script.ir_builder.tir.AxisReduce", axis::Reduce) + .def("script.ir_builder.tir.AxisScan", axis::Scan) + .def("script.ir_builder.tir.AxisOpaque", axis::Opaque) + .def("script.ir_builder.tir.AxisRemap", axis::Remap) + .def("script.ir_builder.tir.Serial", Serial) + .def("script.ir_builder.tir.Parallel", Parallel) + .def("script.ir_builder.tir.Vectorized", Vectorized) + .def("script.ir_builder.tir.Unroll", Unroll) + .def("script.ir_builder.tir.ThreadBinding", ThreadBinding) + .def("script.ir_builder.tir.Grid", Grid) + .def("script.ir_builder.tir.Assert", Assert) + .def("script.ir_builder.tir.LetStmt", LetStmt) + .def("script.ir_builder.tir.LegacyLetStmt", LegacyLetStmt) + .def("script.ir_builder.tir.Allocate", Allocate) + .def("script.ir_builder.tir.AllocateConst", AllocateConst) + .def("script.ir_builder.tir.Realize", Realize) + .def("script.ir_builder.tir.Attr", Attr) + .def("script.ir_builder.tir.While", While) + .def("script.ir_builder.tir.If", If) + .def("script.ir_builder.tir.Then", Then) + .def("script.ir_builder.tir.Else", Else) + .def("script.ir_builder.tir.DeclBuffer", DeclBuffer) + 
.def("script.ir_builder.tir.LaunchThread", + [](ffi::Variant thread_tag_or_var, PrimExpr extent) { + if (auto var = thread_tag_or_var.as()) { + return LaunchThread(var.value(), extent); + } else if (auto str = thread_tag_or_var.as()) { + return LaunchThread(str.value(), extent); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " + << thread_tag_or_var.GetTypeKey(); + throw; + } + }) + .def("script.ir_builder.tir.EnvThread", EnvThread) + .def("script.ir_builder.tir.BufferStore", BufferStore) + .def("script.ir_builder.tir.Evaluate", Evaluate) + .def("script.ir_builder.tir.Ptr", Ptr); +}); #define TVM_TMP_STR(x) #x @@ -788,55 +787,93 @@ TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.BFloat16", BFloat16); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); // Float8 variants -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E3M4").set_body_typed(Float8E3M4); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E3M4", Float8E3M4); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3").set_body_typed(Float8E4M3); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3", Float8E4M3); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3B11FNUZ") - .set_body_typed(Float8E4M3B11FNUZ); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FNUZ").set_body_typed(Float8E4M3FNUZ); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E5M2", Float8E5M2); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2FNUZ").set_body_typed(Float8E5M2FNUZ); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); 
-TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E8M0FNU").set_body_typed(Float8E8M0FNU); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); // Float6 variants -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float6E2M3FN").set_body_typed(Float6E2M3FN); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float6E3M2FN").set_body_typed(Float6E3M2FN); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); // Float4 variant -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); +}); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.TensormapHandle").set_body_typed(TensormapHandle); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); - -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.min") - .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.max") - .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.tir.Boolean", Boolean) + .def("script.ir_builder.tir.Handle", Handle) + .def("script.ir_builder.tir.TensormapHandle", TensormapHandle) + .def("script.ir_builder.tir.Void", Void) + .def("script.ir_builder.tir.min", + [](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }) + .def("script.ir_builder.tir.max", + [](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); +}); } // namespace tir } // namespace ir_builder } // namespace script diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 1665b9d88b12..16eb81d74eaa 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include @@ -264,156 +265,221 @@ DocStringDoc::DocStringDoc(String docs) { } TVM_REGISTER_NODE_TYPE(DocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") - .set_body_typed([](Doc doc, Array source_paths) { - doc->source_paths = source_paths; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "script.printer.DocSetSourcePaths", + [](Doc doc, Array source_paths) { doc->source_paths = source_paths; }); +}); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocAttr") - .set_body_method(&ExprDocNode::Attr); -TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::operator[]); -TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocCall") - 
.set_body_method, Array, Array>( - &ExprDocNode::Call); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("script.printer.ExprDocAttr", &ExprDocNode::Attr) + .def_method("script.printer.ExprDocIndex", &ExprDocNode::operator[]) + .def_method( + "script.printer.ExprDocCall", + [](ExprDoc doc, Array args, Array kwargs_keys, + Array kwargs_values) { return doc->Call(args, kwargs_keys, kwargs_values); }); +}); TVM_REGISTER_NODE_TYPE(StmtDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.StmtDocSetComment") - .set_body_typed([](StmtDoc doc, Optional comment) { doc->comment = comment; }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.StmtDocSetComment", + [](StmtDoc doc, Optional comment) { doc->comment = comment; }); +}); TVM_REGISTER_NODE_TYPE(StmtBlockDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array stmts) { - return StmtBlockDoc(stmts); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.StmtBlockDoc", + [](Array stmts) { return StmtBlockDoc(stmts); }); }); TVM_REGISTER_NODE_TYPE(LiteralDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); -TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); -TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); -TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); -TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.printer.LiteralDocNone", LiteralDoc::None) + .def("script.printer.LiteralDocInt", LiteralDoc::Int) + .def("script.printer.LiteralDocBoolean", LiteralDoc::Boolean) + .def("script.printer.LiteralDocFloat", LiteralDoc::Float) + .def("script.printer.LiteralDocStr", LiteralDoc::Str); +}); TVM_REGISTER_NODE_TYPE(IdDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { - return IdDoc(name); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.IdDoc", [](String name) { return IdDoc(name); }); }); TVM_REGISTER_NODE_TYPE(AttrAccessDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.AttrAccessDoc") - .set_body_typed([](ExprDoc value, String attr) { return AttrAccessDoc(value, attr); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.AttrAccessDoc", + [](ExprDoc value, String attr) { return AttrAccessDoc(value, attr); }); +}); TVM_REGISTER_NODE_TYPE(IndexDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.IndexDoc") - .set_body_typed([](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.IndexDoc", + [](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); +}); TVM_REGISTER_NODE_TYPE(CallDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.CallDoc") - .set_body_typed([](ExprDoc callee, // - Array args, // - Array kwargs_keys, // - Array kwargs_values) { - return CallDoc(callee, args, kwargs_keys, kwargs_values); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.CallDoc", [](ExprDoc 
callee, // + Array args, // + Array kwargs_keys, // + Array kwargs_values) { + return CallDoc(callee, args, kwargs_keys, kwargs_values); + }); +}); TVM_REGISTER_NODE_TYPE(OperationDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.OperationDoc") - .set_body_typed([](int32_t kind, Array operands) { - return OperationDoc(OperationDocNode::Kind(kind), operands); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.OperationDoc", [](int32_t kind, Array operands) { + return OperationDoc(OperationDocNode::Kind(kind), operands); + }); +}); TVM_REGISTER_NODE_TYPE(LambdaDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.LambdaDoc") - .set_body_typed([](Array args, ExprDoc body) { return LambdaDoc(args, body); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.LambdaDoc", + [](Array args, ExprDoc body) { return LambdaDoc(args, body); }); +}); TVM_REGISTER_NODE_TYPE(TupleDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { - return TupleDoc(elements); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.TupleDoc", + [](Array elements) { return TupleDoc(elements); }); }); TVM_REGISTER_NODE_TYPE(ListDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { - return ListDoc(elements); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ListDoc", + [](Array elements) { return ListDoc(elements); }); }); TVM_REGISTER_NODE_TYPE(DictDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.DictDoc") - .set_body_typed([](Array keys, Array values) { - return DictDoc(keys, values); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.DictDoc", [](Array keys, Array values) { + return DictDoc(keys, values); + }); +}); TVM_REGISTER_NODE_TYPE(SliceDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.SliceDoc") - .set_body_typed([](Optional start, Optional stop, Optional step) { - return SliceDoc(start, stop, step); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.SliceDoc", + [](Optional start, Optional stop, + Optional step) { return SliceDoc(start, stop, step); }); +}); TVM_REGISTER_NODE_TYPE(AssignDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.AssignDoc") - .set_body_typed([](ExprDoc lhs, Optional rhs, Optional annotation) { - return AssignDoc(lhs, rhs, annotation); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.AssignDoc", + [](ExprDoc lhs, Optional rhs, Optional annotation) { + return AssignDoc(lhs, rhs, annotation); + }); +}); TVM_REGISTER_NODE_TYPE(IfDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.IfDoc") - .set_body_typed([](ExprDoc predicate, Array then_branch, Array else_branch) { - return IfDoc(predicate, then_branch, else_branch); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.IfDoc", [](ExprDoc predicate, Array then_branch, + Array else_branch) { + return IfDoc(predicate, then_branch, else_branch); + }); +}); TVM_REGISTER_NODE_TYPE(WhileDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.WhileDoc") - .set_body_typed([](ExprDoc predicate, Array body) { - return WhileDoc(predicate, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + 
namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.WhileDoc", [](ExprDoc predicate, Array body) { + return WhileDoc(predicate, body); + }); +}); TVM_REGISTER_NODE_TYPE(ForDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ForDoc") - .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array body) { - return ForDoc(lhs, rhs, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ForDoc", [](ExprDoc lhs, ExprDoc rhs, Array body) { + return ForDoc(lhs, rhs, body); + }); +}); TVM_REGISTER_NODE_TYPE(ScopeDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ScopeDoc") - .set_body_typed([](Optional lhs, ExprDoc rhs, Array body) { - return ScopeDoc(lhs, rhs, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ScopeDoc", + [](Optional lhs, ExprDoc rhs, Array body) { + return ScopeDoc(lhs, rhs, body); + }); +}); TVM_REGISTER_NODE_TYPE(ExprStmtDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) { - return ExprStmtDoc(expr); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ExprStmtDoc", + [](ExprDoc expr) { return ExprStmtDoc(expr); }); }); TVM_REGISTER_NODE_TYPE(AssertDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.AssertDoc") - .set_body_typed([](ExprDoc test, Optional msg = std::nullopt) { - return AssertDoc(test, msg); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "script.printer.AssertDoc", + [](ExprDoc test, Optional msg = std::nullopt) { return AssertDoc(test, msg); }); +}); TVM_REGISTER_NODE_TYPE(ReturnDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) { - return ReturnDoc(value); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ReturnDoc", [](ExprDoc value) { return ReturnDoc(value); }); }); TVM_REGISTER_NODE_TYPE(FunctionDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.FunctionDoc") - .set_body_typed([](IdDoc name, Array args, Array decorators, - Optional return_type, Array body) { - return FunctionDoc(name, args, decorators, return_type, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.FunctionDoc", + [](IdDoc name, Array args, Array decorators, + Optional return_type, Array body) { + return FunctionDoc(name, args, decorators, return_type, body); + }); +}); TVM_REGISTER_NODE_TYPE(ClassDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.ClassDoc") - .set_body_typed([](IdDoc name, Array decorators, Array body) { - return ClassDoc(name, decorators, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ClassDoc", + [](IdDoc name, Array decorators, Array body) { + return ClassDoc(name, decorators, body); + }); +}); TVM_REGISTER_NODE_TYPE(CommentDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) { - return CommentDoc(comment); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.CommentDoc", + [](String comment) { return CommentDoc(comment); }); }); TVM_REGISTER_NODE_TYPE(DocStringDocNode); -TVM_FFI_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) { - return DocStringDoc(docs); 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.DocStringDoc", + [](String docs) { return DocStringDoc(docs); }); }); } // namespace printer diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 85b5b755d253..e4ca2ad510fc 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -727,7 +728,10 @@ String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { return result.substr(0, last_space); } -TVM_FFI_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.DocToPythonScript", DocToPythonScript); +}); } // namespace printer } // namespace script diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc index 3d7abe821745..ea1ed7e698be 100644 --- a/src/script/printer/relax/type.cc +++ b/src/script/printer/relax/type.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -82,7 +84,10 @@ TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::TensorTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); -TVM_FFI_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.printer.ReprPrintRelax", ReprPrintRelax); +}); } // namespace printer } // namespace script diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 106e7f985b65..7162bffc0173 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -60,57 +60,55 @@ TVM_FFI_STATIC_INIT_BLOCK({ TestAttrs::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_FFI_REGISTER_GLOBAL("testing.GetShapeSize").set_body_typed([](ffi::Shape shape) { - return static_cast(shape.size()); -}); - -TVM_FFI_REGISTER_GLOBAL("testing.GetShapeElem").set_body_typed([](ffi::Shape shape, int idx) { - ICHECK_LT(idx, shape.size()); - return shape[idx]; -}); - -TVM_FFI_REGISTER_GLOBAL("testing.test_wrap_callback") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ffi::Function pf = args[0].cast(); - *ret = ffi::TypedFunction([pf]() { pf(); }); - }); - -TVM_FFI_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ffi::Function pf = args[0].cast(); - auto result = ffi::TypedFunction([pf]() { - try { - pf(); - } catch (std::exception& err) { - } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("testing.GetShapeSize", + [](ffi::Shape shape) { return static_cast(shape.size()); }) + .def("testing.GetShapeElem", + [](ffi::Shape shape, int idx) { + ICHECK_LT(idx, shape.size()); + return shape[idx]; + }) + .def_packed("testing.test_wrap_callback", + [](ffi::PackedArgs args, ffi::Any* ret) { + ffi::Function pf = args[0].cast(); + *ret = ffi::TypedFunction([pf]() { pf(); }); + }) + .def_packed("testing.test_wrap_callback_suppress_err", + [](ffi::PackedArgs args, ffi::Any* ret) { + ffi::Function pf = args[0].cast(); + auto result = 
ffi::TypedFunction([pf]() { + try { + pf(); + } catch (std::exception& err) { + } + }); + *ret = result; + }) + .def_packed("testing.test_check_eq_callback", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto msg = args[0].cast(); + *ret = ffi::TypedFunction( + [msg](int x, int y) { CHECK_EQ(x, y) << msg; }); + }) + .def_packed("testing.device_test", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto dev = args[0].cast(); + int dtype = args[1].cast(); + int did = args[2].cast(); + CHECK_EQ(static_cast(dev.device_type), dtype); + CHECK_EQ(static_cast(dev.device_id), did); + *ret = dev; + }) + .def_packed("testing.identity_cpp", [](ffi::PackedArgs args, ffi::Any* ret) { + const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); + ICHECK(identity_func.has_value()) + << "AttributeError: \"testing.identity_py\" is not registered. Please check " + "if the python module is properly loaded"; + *ret = (*identity_func)(args[0]); }); - *ret = result; - }); - -TVM_FFI_REGISTER_GLOBAL("testing.test_check_eq_callback") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto msg = args[0].cast(); - *ret = ffi::TypedFunction([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); - }); - -TVM_FFI_REGISTER_GLOBAL("testing.device_test") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto dev = args[0].cast(); - int dtype = args[1].cast(); - int did = args[2].cast(); - CHECK_EQ(static_cast(dev.device_type), dtype); - CHECK_EQ(static_cast(dev.device_id), did); - *ret = dev; - }); - -TVM_FFI_REGISTER_GLOBAL("testing.identity_cpp") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); - ICHECK(identity_func.has_value()) - << "AttributeError: \"testing.identity_py\" is not registered. 
Please check " - "if the python module is properly loaded"; - *ret = (*identity_func)(args[0]); - }); +}); // in src/api_test.cc void ErrorTest(int x, int y) { @@ -122,7 +120,10 @@ void ErrorTest(int x, int y) { } } -TVM_FFI_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("testing.ErrorTest", ErrorTest); +}); class FrontendTestModuleNode : public runtime::ModuleNode { public: @@ -162,77 +163,65 @@ runtime::Module NewFrontendTestModule() { return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule); - -TVM_FFI_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { - std::chrono::duration duration(static_cast(timeout * 1e9)); - std::this_thread::sleep_for(duration); -}); - -TVM_FFI_REGISTER_GLOBAL("testing.ReturnsVariant") - .set_body_typed([](int x) -> Variant { - if (x % 2 == 0) { - return IntImm(DataType::Int(64), x / 2); - } else { - return String("argument was odd"); - } - }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsVariant") - .set_body_typed([](Variant arg) -> String { - if (auto opt_str = arg.as()) { - return opt_str.value()->GetTypeKey(); - } else { - return arg.get()->GetTypeKey(); - } - }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsObjectRefArray").set_body_typed([](Array arg) -> Any { - return arg[0]; -}); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") - .set_body_typed([](Map map, Any key) -> Any { return map[key]; }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") - .set_body_typed([](Map map) -> ObjectRef { return map; }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { - return expr; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("testing.FrontendTestModule", NewFrontendTestModule) + .def( + "testing.sleep_in_ffi", + [](double timeout) { + std::chrono::duration duration(static_cast(timeout * 1e9)); + std::this_thread::sleep_for(duration); + }) + .def("testing.ReturnsVariant", + [](int x) -> Variant { + if (x % 2 == 0) { + return IntImm(DataType::Int(64), x / 2); + } else { + return String("argument was odd"); + } + }) + .def("testing.AcceptsVariant", + [](Variant arg) -> String { + if (auto opt_str = arg.as()) { + return opt_str.value()->GetTypeKey(); + } else { + return arg.get()->GetTypeKey(); + } + }) + .def("testing.AcceptsBool", [](bool arg) -> bool { return arg; }) + .def("testing.AcceptsInt", [](int arg) -> int { return arg; }) + .def("testing.AcceptsObjectRefArray", [](Array arg) -> Any { return arg[0]; }) + .def("testing.AcceptsMapReturnsValue", + [](Map map, Any key) -> Any { return map[key]; }) + .def("testing.AcceptsMapReturnsMap", [](Map map) -> ObjectRef { return map; }) + .def("testing.AcceptsPrimExpr", [](PrimExpr expr) -> ObjectRef { return expr; }) + .def("testing.AcceptsArrayOfPrimExpr", + [](Array arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance()) << "Array contained " << item->GetTypeKey() + << " when it should contain PrimExpr"; + } + return arr; + }) + .def("testing.AcceptsArrayOfVariant", + [](Array> arr) -> ObjectRef { + for (auto item : arr) { + CHECK(item.as() || item.as()) + << "Array should contain 
either PrimExpr or ffi::Function"; + } + return arr; + }) + .def("testing.AcceptsMapOfPrimExpr", [](Map map) -> ObjectRef { + for (const auto& kv : map) { + ObjectRef value = kv.second; + CHECK(value->IsInstance()) + << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; + } + return map; + }); }); -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") - .set_body_typed([](Array arr) -> ObjectRef { - for (ObjectRef item : arr) { - CHECK(item->IsInstance()) - << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; - } - return arr; - }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") - .set_body_typed([](Array> arr) -> ObjectRef { - for (auto item : arr) { - CHECK(item.as() || item.as()) - << "Array should contain either PrimExpr or ffi::Function"; - } - return arr; - }); - -TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") - .set_body_typed([](Map map) -> ObjectRef { - for (const auto& kv : map) { - ObjectRef value = kv.second; - CHECK(value->IsInstance()) - << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; - } - return map; - }); - /** * Simple event logger that can be used for testing purposes */ @@ -272,21 +261,20 @@ class TestingEventLogger { std::vector entries_; }; -TVM_FFI_REGISTER_GLOBAL("testing.record_event") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() != 0 && args[0].try_cast()) { - TestingEventLogger::ThreadLocal()->Record(args[0].cast()); - } else { - TestingEventLogger::ThreadLocal()->Record("X"); - } - }); - -TVM_FFI_REGISTER_GLOBAL("testing.reset_events") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - TestingEventLogger::ThreadLocal()->Reset(); - }); - -TVM_FFI_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() { - TestingEventLogger::ThreadLocal()->Dump(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("testing.record_event", + [](ffi::PackedArgs args, ffi::Any* rv) { + if (args.size() != 0 && args[0].try_cast()) { + TestingEventLogger::ThreadLocal()->Record(args[0].cast()); + } else { + TestingEventLogger::ThreadLocal()->Record("X"); + } + }) + .def_packed( + "testing.reset_events", + [](ffi::PackedArgs args, ffi::Any* rv) { TestingEventLogger::ThreadLocal()->Reset(); }) + .def("testing.dump_events", []() { TestingEventLogger::ThreadLocal()->Dump(); }); }); } // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index e31fddbcb058..1f274ee894e0 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include @@ -365,6 +366,9 @@ TVM_DLL ffi::Map GetLibInfo() { return result; } -TVM_FFI_REGISTER_GLOBAL("support.GetLibInfo").set_body_typed(GetLibInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("support.GetLibInfo", GetLibInfo); +}); } // namespace tvm diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 9d3f1529c81c..d2d69abb7801 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -361,31 +362,35 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, .cast(); } -TVM_FFI_REGISTER_GLOBAL("target.Build").set_body_typed(Build); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.Build", Build); +}); // Export a few auxiliary function to the 
runtime namespace. -TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string { - return runtime::symbol::tvm_ffi_library_bin; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.ModuleImportsBlobName", + []() -> std::string { return runtime::symbol::tvm_ffi_library_bin; }) + .def("runtime.ModulePackImportsToNDArray", + [](const runtime::Module& mod) { + std::string buffer = PackImportsToBytes(mod); + ffi::Shape::index_type size = buffer.size(); + DLDataType uchar; + uchar.code = kDLUInt; + uchar.bits = 8; + uchar.lanes = 1; + DLDevice dev; + dev.device_type = kDLCPU; + dev.device_id = 0; + auto array = runtime::NDArray::Empty({size}, uchar, dev); + array.CopyFromBytes(buffer.data(), size); + return array; + }) + .def("runtime.ModulePackImportsToC", PackImportsToC) + .def("runtime.ModulePackImportsToLLVM", PackImportsToLLVM); }); -TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") - .set_body_typed([](const runtime::Module& mod) { - std::string buffer = PackImportsToBytes(mod); - ffi::Shape::index_type size = buffer.size(); - DLDataType uchar; - uchar.code = kDLUInt; - uchar.bits = 8; - uchar.lanes = 1; - DLDevice dev; - dev.device_type = kDLCPU; - dev.device_id = 0; - auto array = runtime::NDArray::Empty({size}, uchar, dev); - array.CopyFromBytes(buffer.data(), size); - return array; - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); -TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); - } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 88f96b6a707b..8e1844fd5165 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -19,6 +19,7 @@ #include "registry.h" #include +#include #include namespace tvm { @@ -27,26 +28,26 @@ namespace datatype { using ffi::Any; using ffi::PackedArgs; -TVM_FFI_REGISTER_GLOBAL("dtype.register_custom_type") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - datatype::Registry::Global()->Register(args[0].cast(), - static_cast(args[1].cast())); - }); - -TVM_FFI_REGISTER_GLOBAL("dtype.get_custom_type_code") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = datatype::Registry::Global()->GetTypeCode(args[0].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("dtype.get_custom_type_name") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = Registry::Global()->GetTypeName(args[0].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("runtime._datatype_get_type_registered") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("dtype.register_custom_type", + [](ffi::PackedArgs args, ffi::Any* ret) { + datatype::Registry::Global()->Register( + args[0].cast(), static_cast(args[1].cast())); + }) + .def_packed("dtype.get_custom_type_code", + [](ffi::PackedArgs args, ffi::Any* ret) { + *ret = datatype::Registry::Global()->GetTypeCode(args[0].cast()); + }) + .def_packed("dtype.get_custom_type_name", + [](ffi::PackedArgs args, ffi::Any* ret) { + *ret = Registry::Global()->GetTypeName(args[0].cast()); + }) + .def_packed("runtime._datatype_get_type_registered", [](ffi::PackedArgs args, ffi::Any* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); + }); +}); Registry* 
Registry::Global() { static Registry inst; diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 9d968cdb6478..4c7554cf0024 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "../../arith/scalable_expression.h" #include "codegen_cpu.h" @@ -106,10 +107,13 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { this->VisitStmt(op->body); } -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenAArch64()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.codegen.llvm.target_aarch64", + [](const ffi::PackedArgs& targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenAArch64()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 048c4160b118..7a50396081a2 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -30,6 +30,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -356,12 +357,14 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_FFI_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); - -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenAMDGPU()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.build.rocm", BuildAMDGPU) + .def_packed("tvm.codegen.llvm.target_rocm", [](const ffi::PackedArgs& targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenAMDGPU()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 03ef982d1308..93879bf1ec4f 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -25,6 +25,7 @@ #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -132,10 +133,13 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); } -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenARM()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.codegen.llvm.target_arm", + [](const ffi::PackedArgs& targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenARM()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index b16617e3d6bc..7f87b0c7fb72 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -43,6 +43,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -1161,10 +1162,13 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { } } -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenCPU()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.codegen.llvm.target_cpu", + [](const ffi::PackedArgs& 
targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenCPU()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index baf7497bc0d1..b4362bebac5d 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -29,6 +29,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -589,12 +590,15 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str); } -TVM_FFI_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); - -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenHexagon()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.build.hexagon", BuildHexagon) + .def_packed("tvm.codegen.llvm.target_hexagon", + [](const ffi::PackedArgs& targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenHexagon()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 8c8a9f2a4a89..02000353559f 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -27,6 +27,7 @@ #include #include #include +#include #if LLVM_VERSION_MAJOR >= 17 #include #else @@ -2350,28 +2351,25 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) return nullptr; } -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetDefaultTargetTriple") - .set_body_typed([]() -> std::string { return llvm::sys::getDefaultTargetTriple(); }); - -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetProcessTriple").set_body_typed([]() -> std::string { - return llvm::sys::getProcessTriple(); -}); - -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> std::string { - return llvm::sys::getHostCPUName().str(); -}); - -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") - .set_body_typed([]() -> Map { +static void CodegenLLVMRegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tvm.codegen.llvm.GetDefaultTargetTriple", + []() -> std::string { return llvm::sys::getDefaultTargetTriple(); }) + .def("tvm.codegen.llvm.GetProcessTriple", + []() -> std::string { return llvm::sys::getProcessTriple(); }) + .def("tvm.codegen.llvm.GetHostCPUName", + []() -> std::string { return llvm::sys::getHostCPUName().str(); }) + .def("tvm.codegen.llvm.GetHostCPUFeatures", []() -> Map { #if TVM_LLVM_VERSION >= 190 - Map ret; - auto features = llvm::sys::getHostCPUFeatures(); - for (auto it = features.begin(); it != features.end(); ++it) { - std::string name = it->getKey().str(); - bool value = it->getValue(); - ret.Set(name, IntImm(DataType::Bool(), value)); - } - return ret; + Map ret; + auto features = llvm::sys::getHostCPUFeatures(); + for (auto it = features.begin(); it != features.end(); ++it) { + std::string name = it->getKey().str(); + bool value = it->getValue(); + ret.Set(name, IntImm(DataType::Bool(), value)); + } + return ret; #else llvm::StringMap features; if (llvm::sys::getHostCPUFeatures(features)) { @@ -2384,9 +2382,12 @@ TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") return ret; } #endif - LOG(WARNING) << "Current version of LLVM does not support feature detection on your CPU"; - return {}; - }); + LOG(WARNING) << "Current 
version of LLVM does not support feature detection on your CPU"; + return {}; + }); +} + +TVM_FFI_STATIC_INIT_BLOCK({ CodegenLLVMRegisterReflection(); }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a0ffb5a1ce10..f988457c1253 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -30,6 +30,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -368,12 +369,14 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_FFI_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); - -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenNVPTX()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.build.nvptx", BuildNVPTX) + .def_packed("tvm.codegen.llvm.target_nvptx", [](const ffi::PackedArgs& targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenNVPTX()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 435b453d49ba..a8cb5b39c538 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -26,6 +26,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -132,10 +133,13 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems); } -TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") - .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { - *rv = static_cast(new CodeGenX86_64()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvm.codegen.llvm.target_x86-64", + [](const ffi::PackedArgs& targs, ffi::Any* rv) { + *rv = static_cast(new CodeGenX86_64()); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 552c2b74c64e..42881016baa0 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -31,6 +31,7 @@ #include #include #include +#include #if _WIN32 #include #include @@ -620,40 +621,41 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, return nullptr; } -TVM_FFI_REGISTER_GLOBAL("target.build.llvm") - .set_body_typed([](IRModule mod, Target target) -> runtime::Module { - auto n = make_object(); - n->Init(mod, target); - return runtime::Module(n); - }); - -TVM_FFI_REGISTER_GLOBAL("codegen.LLVMModuleCreate") - .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { - auto llvm_instance = std::make_unique(); - With llvm_target(*llvm_instance, target_str); - auto n = make_object(); - // Generate a LLVM module from an input target string - auto module = std::make_unique(module_name, *llvm_target->GetContext()); - llvm_target->SetTargetMetadata(module.get()); - module->setTargetTriple(llvm_target->GetTargetTriple()); - module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); - n->Init(std::move(module), std::move(llvm_instance)); - n->SetJITEngine(llvm_target->GetJITEngine()); - return runtime::Module(n); - }); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") - .set_body_typed([](std::string name) -> 
int64_t { +static void LLVMReflectionRegister() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.build.llvm", + [](IRModule mod, Target target) -> runtime::Module { + auto n = make_object(); + n->Init(mod, target); + return runtime::Module(n); + }) + .def("codegen.LLVMModuleCreate", + [](std::string target_str, std::string module_name) -> runtime::Module { + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target_str); + auto n = make_object(); + // Generate a LLVM module from an input target string + auto module = std::make_unique(module_name, *llvm_target->GetContext()); + llvm_target->SetTargetMetadata(module.get()); + module->setTargetTriple(llvm_target->GetTargetTriple()); + module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); + n->Init(std::move(module), std::move(llvm_instance)); + n->SetJITEngine(llvm_target->GetJITEngine()); + return runtime::Module(n); + }) + .def("target.llvm_lookup_intrinsic_id", + [](std::string name) -> int64_t { #if TVM_LLVM_VERSION >= 200 - return static_cast(llvm::Intrinsic::lookupIntrinsicID(name)); + return static_cast(llvm::Intrinsic::lookupIntrinsicID(name)); #else return static_cast(llvm::Function::lookupIntrinsicID(name)); #endif - }); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String { + }) + .def("target.llvm_get_intrinsic_name", + [](int64_t id) -> String { #if TVM_LLVM_VERSION >= 130 - return std::string(llvm::Intrinsic::getBaseName(static_cast(id))); + return std::string(llvm::Intrinsic::getBaseName(static_cast(id))); #elif TVM_LLVM_VERSION >= 40 // This is the version of Intrinsic::getName that works for overloaded // intrinsics. Helpfully, if we provide no types to this function, it @@ -664,142 +666,130 @@ TVM_FFI_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int6 // Nothing to do, just return the intrinsic id number return std::to_string(id); #endif -}); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> String { + }) + .def("target.llvm_get_system_x86_vendor", + []() -> String { #if TVM_LLVM_VERSION >= 120 #if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64) - using namespace llvm::sys::detail::x86; - const auto x86_sign = getVendorSignature(); - if (x86_sign == VendorSignatures::GENUINE_INTEL) - return "intel"; - else if (x86_sign == VendorSignatures::AUTHENTIC_AMD) - return "amd"; - else if (x86_sign == VendorSignatures::UNKNOWN) - return "unknown"; + using namespace llvm::sys::detail::x86; + const auto x86_sign = getVendorSignature(); + if (x86_sign == VendorSignatures::GENUINE_INTEL) + return "intel"; + else if (x86_sign == VendorSignatures::AUTHENTIC_AMD) + return "amd"; + else if (x86_sign == VendorSignatures::UNKNOWN) + return "unknown"; #endif #endif - return "unimplemented"; -}); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_vector_width") - .set_body_typed([](const Target& target) -> int { - auto use_target = target.defined() ? 
target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return -1; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, use_target); - return llvm_backend.GetVectorWidth(); - }); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String { - return llvm::sys::getDefaultTargetTriple(); -}); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_cpu").set_body_typed([]() -> String { - return llvm::sys::getHostCPUName().str(); -}); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); - return llvm_backend.GetAllLLVMTargets(); -}); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") - .set_body_typed([](const Target& target) -> Array { - auto use_target = target.defined() ? target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return Array{}; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, use_target); - return llvm_backend.GetAllLLVMTargetArches(); - }); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_get_cpu_features") - .set_body_typed([](const Target& target) -> Map { - auto use_target = target.defined() ? target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return {}; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, use_target); - return llvm_backend.GetAllLLVMCpuFeatures(); - }); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_cpu_has_feature") - .set_body_typed([](const String feature, const Target& target) -> bool { - auto use_target = target.defined() ? target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return false; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, use_target); - auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures(); - bool has_feature = cpu_features.find(feature) != cpu_features.end(); - return has_feature; - }); - -TVM_FFI_REGISTER_GLOBAL("target.target_has_feature") - .set_body_typed([](const String feature, const Target& target) -> bool { - auto use_target = target.defined() ? 
target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return false; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_target(*llvm_instance, use_target); - return llvm_target.TargetHasCPUFeature(feature); - }); - -TVM_FFI_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { - return TVM_LLVM_VERSION / 10; -}); - -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_ll") - .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module { - auto n = make_object(); - n->SetJITEngine("orcjit"); - n->LoadIR(filename); - return runtime::Module(n); - }); - -TVM_FFI_REGISTER_GLOBAL("codegen.llvm_target_enabled") - .set_body_typed([](std::string target_str) -> bool { - LLVMInstance llvm_instance; - auto* tm = With(llvm_instance, target_str) - ->GetOrCreateTargetMachine(/*allow_missing=*/true); - return tm != nullptr; - }); + return "unimplemented"; + }) + .def("target.llvm_get_vector_width", + [](const Target& target) -> int { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return -1; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetVectorWidth(); + }) + .def("target.llvm_get_system_triple", + []() -> String { return llvm::sys::getDefaultTargetTriple(); }) + .def("target.llvm_get_system_cpu", + []() -> String { return llvm::sys::getHostCPUName().str(); }) + .def("target.llvm_get_targets", + []() -> Array { + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); + return llvm_backend.GetAllLLVMTargets(); + }) + .def("target.llvm_get_cpu_archlist", + [](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMTargetArches(); + }) + .def("target.llvm_get_cpu_features", + [](const Target& target) -> Map { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return {}; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMCpuFeatures(); + }) + .def("target.llvm_cpu_has_feature", + [](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures(); + bool has_feature = cpu_features.find(feature) != cpu_features.end(); + return has_feature; + }) + .def("target.target_has_feature", + [](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? 
target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_target(*llvm_instance, use_target); + return llvm_target.TargetHasCPUFeature(feature); + }) + .def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; }) + .def("runtime.module.loadfile_ll", + [](std::string filename, std::string fmt) -> runtime::Module { + auto n = make_object(); + n->SetJITEngine("orcjit"); + n->LoadIR(filename); + return runtime::Module(n); + }) + .def("codegen.llvm_target_enabled", + [](std::string target_str) -> bool { + LLVMInstance llvm_instance; + auto* tm = With(llvm_instance, target_str) + ->GetOrCreateTargetMachine(/*allow_missing=*/true); + return tm != nullptr; + }) + .def("codegen.codegen_blob", + [](std::string data, bool system_lib, std::string llvm_target_string, + std::string c_symbol_prefix) -> runtime::Module { + auto n = make_object(); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, llvm_target_string); + std::unique_ptr blob = + CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix); + n->Init(std::move(blob), std::move(llvm_instance)); + n->SetJITEngine(llvm_target->GetJITEngine()); + return runtime::Module(n); + }); +} -TVM_FFI_REGISTER_GLOBAL("codegen.codegen_blob") - .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string, - std::string c_symbol_prefix) -> runtime::Module { - auto n = make_object(); - auto llvm_instance = std::make_unique(); - With llvm_target(*llvm_instance, llvm_target_string); - std::unique_ptr blob = - CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix); - n->Init(std::move(blob), std::move(llvm_instance)); - n->SetJITEngine(llvm_target->GetJITEngine()); - return runtime::Module(n); - }); +TVM_FFI_STATIC_INIT_BLOCK({ LLVMReflectionRegister(); }); } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index bdac0b8fb72d..9bdbca0a5c62 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -25,6 +25,7 @@ */ #if defined(__linux__) #include +#include #endif #include #include @@ -172,7 +173,10 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_FFI_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.cuda", BuildCUDA); +}); TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", String); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 776edbb724ab..46cca620c7e1 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,6 +22,7 @@ */ #include "codegen_c_host.h" +#include #include #include @@ -404,6 +405,9 @@ runtime::Module BuildCHost(IRModule mod, Target target) { return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } -TVM_FFI_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.c", BuildCHost); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 0f87a16c449b..6704180de161 100644 --- 
a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -22,6 +22,7 @@ */ #include "codegen_metal.h" +#include #include #include @@ -466,6 +467,9 @@ runtime::Module BuildMetal(IRModule mod, Target target) { return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); } -TVM_FFI_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.metal", BuildMetal); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index b94dc17bff33..8ba0dad6e405 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -22,6 +22,8 @@ */ #include "codegen_opencl.h" +#include + #include #include #include @@ -672,7 +674,10 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str()); } -TVM_FFI_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.opencl", BuildOpenCL); +}); String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { auto prototype_keys = target->GetKeys(); @@ -684,8 +689,10 @@ String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { return memory_scope; } -TVM_FFI_REGISTER_GLOBAL("DeviceScopeCompatibility.opencl") - .set_body_typed(DeviceScopeCompatibilityFromTarget); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("DeviceScopeCompatibility.opencl", DeviceScopeCompatibilityFromTarget); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 995eddee027e..b579663e8c3e 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -23,6 +23,7 @@ #include "codegen_webgpu.h" #include +#include #include #include @@ -779,8 +780,10 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { - return BuildWebGPU(mod, target); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.webgpu", + [](IRModule mod, Target target) { return BuildWebGPU(mod, target); }); }); } // namespace codegen diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 5e1f132fb5a5..30263a908ffc 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -174,8 +175,10 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_c") - .set_body_typed(CSourceModuleNode::LoadFromBinary); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("runtime.module.loadbinary_c", CSourceModuleNode::LoadFromBinary); +}); /*! * \brief A concrete class to get access to base methods of CodegenSourceBase. 
@@ -248,13 +251,16 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } -TVM_FFI_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); - -TVM_FFI_REGISTER_GLOBAL("runtime.CSourceModuleCreate") - .set_body_typed([](String code, String fmt, Optional> func_names, - Optional> const_vars) { - return CSourceModuleCreate(code, fmt, func_names.value_or({}), const_vars.value_or({})); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.SourceModuleCreate", SourceModuleCreate) + .def("runtime.CSourceModuleCreate", [](String code, String fmt, + Optional> func_names, + Optional> const_vars) { + return CSourceModuleCreate(code, fmt, func_names.value_or({}), const_vars.value_or({})); + }); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index f3dbd624ec00..fb3c592cca57 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -22,6 +22,8 @@ * \brief Build SPIRV block */ +#include + #include "../../runtime/spirv/spirv_shader.h" #include "../../runtime/vulkan/vulkan_module.h" #include "../build_common.h" @@ -35,8 +37,10 @@ runtime::Module BuildSPIRV(IRModule mod, Target target) { return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } -TVM_FFI_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { - return BuildSPIRV(mod, target); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.vulkan", + [](IRModule mod, Target target) { return BuildSPIRV(mod, target); }); }); } // namespace codegen diff --git a/src/target/tag.cc b/src/target/tag.cc index 04f0a146034f..108fed9cd0c4 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -35,8 +36,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ TargetTagNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(TargetTagNode); -TVM_FFI_REGISTER_GLOBAL("target.TargetTagListTags").set_body_typed(TargetTag::ListTags); -TVM_FFI_REGISTER_GLOBAL("target.TargetTagAddTag").set_body_typed(TargetTag::AddTag); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.TargetTagListTags", TargetTag::ListTags) + .def("target.TargetTagAddTag", TargetTag::AddTag); +}); /********** Registry-related code **********/ diff --git a/src/target/target.cc b/src/target/target.cc index c73918b9d125..34f2f3118340 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -1010,23 +1011,25 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, /********** Registry **********/ -TVM_FFI_REGISTER_GLOBAL("target.Target").set_body_packed(TargetInternal::ConstructorDispatcher); -TVM_FFI_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); -TVM_FFI_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); -TVM_FFI_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); -TVM_FFI_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); -TVM_FFI_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); -TVM_FFI_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) { - return target->GetTargetDeviceType(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace 
refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("target.Target", TargetInternal::ConstructorDispatcher) + .def("target.TargetEnterScope", TargetInternal::EnterScope) + .def("target.TargetExitScope", TargetInternal::ExitScope) + .def("target.TargetCurrent", Target::Current) + .def("target.TargetExport", TargetInternal::Export) + .def("target.WithHost", TargetInternal::WithHost) + .def("target.TargetGetDeviceType", + [](const Target& target) { return target->GetTargetDeviceType(); }) + .def("target.TargetGetFeature", [](const Target& target, const String& feature_key) -> Any { + if (auto opt_any = target->GetFeature(feature_key)) { + return opt_any.value(); + } else { + return Any(); + } + }); }); -TVM_FFI_REGISTER_GLOBAL("target.TargetGetFeature") - .set_body_typed([](const Target& target, const String& feature_key) -> Any { - if (auto opt_any = target->GetFeature(feature_key)) { - return opt_any.value(); - } else { - return Any(); - } - }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index ae35c7d97f12..012d1894623d 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -22,6 +22,7 @@ * \brief Target kind registry */ #include +#include #include #include #include @@ -448,23 +449,24 @@ TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break /********** Registry **********/ -TVM_FFI_REGISTER_GLOBAL("target.TargetKindGetAttr") - .set_body_typed([](TargetKind kind, String attr_name) -> ffi::Any { - auto target_attr_map = TargetKind::GetAttrMap(attr_name); - ffi::Any rv; - if (target_attr_map.count(kind)) { - rv = target_attr_map[kind]; - } - return rv; - }); -TVM_FFI_REGISTER_GLOBAL("target.ListTargetKinds") - .set_body_typed(TargetKindRegEntry::ListTargetKinds); -TVM_FFI_REGISTER_GLOBAL("target.ListTargetKindOptions") - .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); -TVM_FFI_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName") - .set_body_typed([](String target_kind_name) { - TargetKind kind = TargetKind::Get(target_kind_name).value(); - return TargetKindRegEntry::ListTargetKindOptions(kind); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("target.TargetKindGetAttr", + [](TargetKind kind, String attr_name) -> ffi::Any { + auto target_attr_map = TargetKind::GetAttrMap(attr_name); + ffi::Any rv; + if (target_attr_map.count(kind)) { + rv = target_attr_map[kind]; + } + return rv; + }) + .def("target.ListTargetKinds", TargetKindRegEntry::ListTargetKinds) + .def("target.ListTargetKindOptions", TargetKindRegEntry::ListTargetKindOptions) + .def("target.ListTargetKindOptionsFromName", [](String target_kind_name) { + TargetKind kind = TargetKind::Get(target_kind_name).value(); + return TargetKindRegEntry::ListTargetKindOptions(kind); + }); +}); } // namespace tvm diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index da1ed4461baf..9a03f8012e71 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -22,6 +22,7 @@ * \brief A compile time representation for where data is to be stored at runtime, and how to * compile code to compute it. 
*/ +#include #include #include #include @@ -193,7 +194,10 @@ VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { virtual_device->target, virtual_device->memory_scope); } -TVM_FFI_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") - .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.VirtualDevice_ForDeviceTargetAndMemoryScope", + VirtualDevice::ForDeviceTargetAndMemoryScope); +}); } // namespace tvm diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index c626fe6aa17b..00902de7ab62 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -154,11 +155,14 @@ ComputeOp::ComputeOp(std::string name, std::string tag, Map at data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("te.ComputeOp") - .set_body_typed([](std::string name, std::string tag, Optional> attrs, - Array axis, Array body) { - return ComputeOp(name, tag, attrs.value_or({}), axis, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("te.ComputeOp", + [](std::string name, std::string tag, Optional> attrs, + Array axis, Array body) { + return ComputeOp(name, tag, attrs.value_or({}), axis, body); + }); +}); // The schedule related logics Array ComputeOpNode::InputTensors() const { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 0e90984e28b7..0d30ca17443d 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -784,16 +785,18 @@ PrimFunc CreatePrimFunc(const Array& arg_list, return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } -TVM_FFI_REGISTER_GLOBAL("te.CreatePrimFunc") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - Array arg_list = args[0].cast>(); - std::optional index_dtype_override{std::nullopt}; - // Add conversion to make std::optional compatible with FFI. - if (args[1] != nullptr) { - index_dtype_override = args[1].cast(); - } - *ret = CreatePrimFunc(arg_list, index_dtype_override); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("te.CreatePrimFunc", [](ffi::PackedArgs args, ffi::Any* ret) { + Array arg_list = args[0].cast>(); + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. 
+ if (args[1] != nullptr) { + index_dtype_override = args[1].cast(); + } + *ret = CreatePrimFunc(arg_list, index_dtype_override); + }); +}); // Relax version impl PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index c7283d847691..c35e79ac2d3e 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -73,13 +74,16 @@ ExternOp::ExternOp(std::string name, std::string tag, Map attr data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("te.ExternOp") - .set_body_typed([](std::string name, std::string tag, Optional> attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { - return ExternOp(name, tag, attrs.value_or({}), inputs, input_placeholders, - output_placeholders, body); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("te.ExternOp", + [](std::string name, std::string tag, Optional> attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { + return ExternOp(name, tag, attrs.value_or({}), inputs, input_placeholders, + output_placeholders, body); + }); +}); Array ExternOpNode::InputTensors() const { return inputs; } diff --git a/src/te/operation/graph.cc b/src/te/operation/graph.cc index e2bbced85f89..6622f35d70ea 100644 --- a/src/te/operation/graph.cc +++ b/src/te/operation/graph.cc @@ -24,6 +24,7 @@ #include "graph.h" #include +#include #include #include #include @@ -80,12 +81,14 @@ Array PostDFSOrder(const Array& roots, const ReadGraph& g) return post_order; } -TVM_FFI_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); - -TVM_FFI_REGISTER_GLOBAL("schedule.PostDFSOrder") - .set_body_typed([](const Array& roots, const ReadGraph& g) { - return PostDFSOrder(roots, g); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("schedule.CreateReadGraph", CreateReadGraph) + .def("schedule.PostDFSOrder", [](const Array& roots, const ReadGraph& g) { + return PostDFSOrder(roots, g); + }); +}); } // namespace te } // namespace tvm diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 2e826c836ed6..05acaac0cb6d 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include namespace tvm { @@ -63,20 +64,22 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { return PlaceholderOp(name, shape, dtype).output(0); } -TVM_FFI_REGISTER_GLOBAL("te.Placeholder") - .set_body_typed([](Variant> shape_arg, DataType dtype, - std::string name) { - auto shape = [&]() -> Array { - if (auto arg_expr = shape_arg.as()) { - return {arg_expr.value()}; - } else if (auto arg_array = shape_arg.as>()) { - return arg_array.value(); - } else { - LOG(FATAL) << "Variant did not contain either allowed type"; - } - }(); - return placeholder(shape, dtype, name); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("te.Placeholder", [](Variant> shape_arg, + DataType dtype, std::string name) { + auto shape = [&]() -> Array { + if (auto arg_expr = shape_arg.as()) { + return {arg_expr.value()}; + } else if (auto arg_array = shape_arg.as>()) { + return arg_array.value(); + } else { + LOG(FATAL) << "Variant did not contain either allowed type"; + } + }(); + return placeholder(shape, dtype, 
name); + }); +}); Array PlaceholderOpNode::InputTensors() const { return {}; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index c4d56a7f15c2..a5adfcc26939 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -22,6 +22,7 @@ * \file scan_op.cc */ #include +#include #include #include @@ -99,12 +100,15 @@ ScanOp::ScanOp(std::string name, std::string tag, Optional data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("te.ScanOp") - .set_body_typed([](std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { - return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "te.ScanOp", [](std::string name, std::string tag, Optional> attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array inputs) { + return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); + }); +}); Array scan(Array init, Array update, Array state_placeholder, Array inputs, std::string name, std::string tag, diff --git a/src/te/tensor.cc b/src/te/tensor.cc index cb3cb593d751..df204483671c 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -21,6 +21,7 @@ * \file tensor.cc */ #include +#include #include #include @@ -109,10 +110,13 @@ Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_in data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("te.Tensor") - .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { - return Tensor(shape, dtype, op, value_index); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("te.Tensor", + [](Array shape, DataType dtype, Operation op, int value_index) { + return Tensor(shape, dtype, op, value_index); + }); +}); TVM_REGISTER_NODE_TYPE(TensorNode); @@ -123,19 +127,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Other tensor ops. 
-TVM_FFI_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); - -TVM_FFI_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { - return static_cast(std::hash()(tensor)); -}); - -TVM_FFI_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { - return op.output(static_cast(output)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("te.TensorEqual", &Tensor::operator==) + .def("te.TensorHash", + [](Tensor tensor) -> int64_t { + return static_cast(std::hash()(tensor)); + }) + .def("te.OpGetOutput", + [](Operation op, int64_t output) { return op.output(static_cast(output)); }) + .def_method("te.OpNumOutputs", &OperationNode::num_outputs) + .def_method("te.OpInputTensors", &OperationNode::InputTensors); }); -TVM_FFI_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); - -TVM_FFI_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); - } // namespace te } // namespace tvm diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 8b3598e3563d..be3d1e7f29a6 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -410,9 +411,12 @@ Array> GetBlockReadWriteRegion(const Block& block, return {reads, writes}; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); -TVM_FFI_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion") - .set_body_typed(GetBlockReadWriteRegion); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.analysis.GetBlockAccessRegion", GetBlockAccessRegion) + .def("tir.analysis.GetBlockReadWriteRegion", GetBlockReadWriteRegion); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 7ac84ce894a4..b90974798bcd 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -22,6 +22,7 @@ * \brief Detect the lowest common ancestor(LCA) of buffer access */ +#include #include #include @@ -346,7 +347,9 @@ Map> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca") - .set_body_typed(DetectBufferAccessLCA); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.detect_buffer_access_lca", DetectBufferAccessLCA); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index de208ce9c1e0..b1427fc99ccf 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -96,18 +97,22 @@ tvm::Map > CalculateAllocatedBytes(const IRMod return results; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") - .set_body_typed([](ObjectRef obj) -> tvm::Map > { - if (auto func = obj.as()) { - return CalculateAllocatedBytes(func.value()); - } else if (auto mod = obj.as()) { - return CalculateAllocatedBytes(mod.value()); - } else { - LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: " - << 
obj->GetTypeKey(); - throw; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.analysis.calculate_allocated_bytes", + [](ObjectRef obj) -> tvm::Map > { + if (auto func = obj.as()) { + return CalculateAllocatedBytes(func.value()); + } else if (auto mod = obj.as()) { + return CalculateAllocatedBytes(mod.value()); + } else { + LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: " + << obj->GetTypeKey(); + throw; + } + }); +}); bool VerifyVTCMLimit(const IRModule& mod, Integer limit) { auto all_sizes = CalculateAllocatedBytes(mod); @@ -155,8 +160,10 @@ Array GetVTCMCompactionPasses() { return pass_list; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { - return GetVTCMCompactionPasses(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.get_vtcm_compaction_passes", + []() { return GetVTCMCompactionPasses(); }); }); namespace transform { @@ -191,7 +198,10 @@ Pass VerifyVTCMLimit(Optional default_target) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.VerifyVTCMLimit", VerifyVTCMLimit); +}); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 07d6500570f8..52c1adea4662 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -22,6 +22,7 @@ * \brief Deep equality checking. */ #include +#include #include #include #include @@ -68,10 +69,12 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt); } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") - .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { - return ExprDeepEqual()(lhs, rhs); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.analysis.expr_deep_equal", + [](const PrimExpr& lhs, const PrimExpr& rhs) { return ExprDeepEqual()(lhs, rhs); }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index df41e7da1807..8ae05ad53b8e 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include #include #include @@ -246,18 +247,20 @@ double EstimateTIRFlops(const IRModule& mod) { return PostprocessResults(result) + cached_result; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops") - .set_body_typed([](ObjectRef obj) -> double { - if (auto mod = obj.as()) { - return EstimateTIRFlops(mod.value()); - } else if (auto stmt = obj.as()) { - return EstimateTIRFlops(stmt.value()); - } else { - LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " - << obj->GetTypeKey(); - throw; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.EstimateTIRFlops", [](ObjectRef obj) -> double { + if (auto mod = obj.as()) { + return EstimateTIRFlops(mod.value()); + } else if (auto stmt = obj.as()) { + return EstimateTIRFlops(stmt.value()); + } else { + LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " + << obj->GetTypeKey(); + throw; + } + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index dcffe1c1d6b8..d27a4d829656 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -282,34 +283,37 @@ std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* an } // Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing. -TVM_FFI_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { - Array output; - - struct Visitor : arith::IRVisitorWithAnalyzer { - explicit Visitor(Array* output) : output(output) {} - Array* output; - - private: - using IRVisitorWithAnalyzer::VisitStmt_; - void VisitStmt_(const ForNode* op) override { - For loop = GetRef(op); - auto result = IdentifyMemCpyImpl(loop, &(Visitor::analyzer_)); - if (auto* ptr = std::get_if(&result)) { - output->push_back(Array{ptr->source, ptr->dest}); - } else if (auto* ptr = std::get_if(&result)) { - output->push_back(StringImm(*ptr)); - } else { - LOG(FATAL) << "Internal error, unhandled std::variant type"; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis._identify_memcpy", [](const Stmt& stmt) { + Array output; + + struct Visitor : arith::IRVisitorWithAnalyzer { + explicit Visitor(Array* output) : output(output) {} + Array* output; + + private: + using IRVisitorWithAnalyzer::VisitStmt_; + void VisitStmt_(const ForNode* op) override { + For loop = GetRef(op); + auto result = IdentifyMemCpyImpl(loop, &(Visitor::analyzer_)); + if (auto* ptr = std::get_if(&result)) { + output->push_back(Array{ptr->source, ptr->dest}); + } else if (auto* ptr = std::get_if(&result)) { + output->push_back(StringImm(*ptr)); + } else { + LOG(FATAL) << "Internal error, unhandled std::variant type"; + } + + IRVisitorWithAnalyzer::VisitStmt_(op); } + }; - IRVisitorWithAnalyzer::VisitStmt_(op); - } - }; - - Visitor visitor(&output); - visitor(stmt); + Visitor visitor(&output); + visitor(stmt); - return output; + return output; + }); }); } // namespace tir diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index 4af823604971..0047902adf51 100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -21,6 +21,7 @@ * \file is_pure_function.cc * \brief PrimFunc purity analysis */ +#include #include #include #include @@ -91,7 +92,10 @@ bool 
IsPureFunction(const PrimFunc& func, bool assert_on_error) { return PurityChecker::Check(func, assert_on_error); } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.is_pure_function", IsPureFunction); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index 898a92adc7db..16eca9921090 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -21,6 +21,7 @@ * Out of bounds array access static analyzer. */ +#include #include #include "../../arith/ir_visitor_with_analyzer.h" @@ -123,7 +124,10 @@ transform::Pass OOBChecker() { return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.OOBChecker", OOBChecker); +}); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index b5a23e35d276..97088000d4ca 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -139,12 +140,15 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) { return nullptr; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { - auto ret = FindAnchorBlock(mod); - if (ret) { - return Optional(GetRef(ret)); - } - return Optional(std::nullopt); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.find_anchor_block", [](const IRModule& mod) { + auto ret = FindAnchorBlock(mod); + if (ret) { + return Optional(GetRef(ret)); + } + return Optional(std::nullopt); + }); }); } // namespace tir diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 0d75cebac798..56f5856f2ee9 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -22,6 +22,8 @@ * \brief Classes and functions to analyze var defition and usage. 
*/ #include "var_use_def_analysis.h" + +#include namespace tvm { namespace tir { @@ -199,15 +201,18 @@ Array UndefinedVars(const PrimExpr& expr, const Array& args) { return m.undefined_; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.UndefinedVars") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (auto opt_stmt = args[0].as()) { - *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); - } else if (auto opt_expr = args[0].as()) { - *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); - } else { - LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected"; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tir.analysis.UndefinedVars", [](ffi::PackedArgs args, ffi::Any* rv) { + if (auto opt_stmt = args[0].as()) { + *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); + } else if (auto opt_expr = args[0].as()) { + *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); + } else { + LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected"; + } + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index ef46a41687ad..f126dc0fe061 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -25,6 +25,7 @@ */ #include +#include #include #include #include @@ -321,7 +322,10 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { return errs.size() == 0; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.verify_gpu_code", VerifyGPUCode); +}); namespace transform { @@ -346,7 +350,10 @@ Pass VerifyGPUCode(Map constraints) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.VerifyGPUCode", VerifyGPUCode); +}); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index bc567879c22b..35caebf40f11 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -22,6 +22,7 @@ * \brief Pass to check if memory accesses are legal. 
*/ #include +#include #include #include #include @@ -186,7 +187,10 @@ std::vector VerifyMemory_(const PrimFunc& func) { bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.verify_memory", VerifyMemory); +}); namespace transform { @@ -211,7 +215,10 @@ Pass VerifyMemory() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.VerifyMemory", VerifyMemory); +}); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 33abb39c367f..f0580b622a7e 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -24,6 +24,7 @@ * \file verify_ssa.cc */ #include +#include #include #include #include @@ -139,7 +140,10 @@ bool VerifySSA(const PrimFunc& func) { return visitor.is_ssa_; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.verify_ssa", VerifySSA); +}); namespace transform { @@ -155,7 +159,10 @@ Pass VerifySSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.VerifySSA", VerifySSA); +}); } // namespace transform diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index a0c5f4829bf8..921d12cde714 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -368,17 +369,20 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { return true; } -TVM_FFI_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") - .set_body_typed([](const ObjectRef& obj, bool assert_mode) { - if (auto opt = obj.as()) { - return VerifyWellFormed(opt.value(), assert_mode); - } else if (auto opt = obj.as()) { - return VerifyWellFormed(opt.value(), assert_mode); - } else { - LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found " - << obj->GetTypeKey(); - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.analysis.VerifyWellFormed", [](const ObjectRef& obj, + bool assert_mode) { + if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else { + LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found " + << obj->GetTypeKey(); + } + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index 9267ecec75a3..c4c71a12c524 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -17,6 +17,7 @@ * under the License. 
*/ +#include #include #include @@ -87,15 +88,18 @@ BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { } TVM_REGISTER_NODE_TYPE(BlockDependenceInfoNode); -TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfo") - .set_body_typed([](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }); -TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfoGetBlockScope") - .set_body_method(&BlockDependenceInfoNode::GetBlockScope); -TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfoGetSRef") - .set_body_typed([](BlockDependenceInfo self, Stmt stmt) -> Optional { - auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.BlockDependenceInfo", + [](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }) + .def_method("tir.BlockDependenceInfoGetBlockScope", &BlockDependenceInfoNode::GetBlockScope) + .def("tir.BlockDependenceInfoGetSRef", + [](BlockDependenceInfo self, Stmt stmt) -> Optional { + auto it = self->stmt2ref.find(stmt.get()); + return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index 70d35aa9e259..4cddd54d428e 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -197,21 +197,20 @@ TVM_REGISTER_NODE_TYPE(StmtSRefNode); TVM_REGISTER_NODE_TYPE(DependencyNode); TVM_REGISTER_NODE_TYPE(BlockScopeNode); -TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> Optional { - return GetRef>(sref->stmt); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.StmtSRefStmt", + [](StmtSRef sref) -> Optional { return GetRef>(sref->stmt); }) + .def("tir.StmtSRefParent", + [](StmtSRef sref) -> Optional { + return GetRef>(sref->parent); + }) + .def("tir.StmtSRefRootMark", StmtSRef::RootMark) + .def("tir.StmtSRefInlineMark", StmtSRef::InlineMark) + .def_method("tir.BlockScopeGetDepsBySrc", &BlockScopeNode::GetDepsBySrc) + .def_method("tir.BlockScopeGetDepsByDst", &BlockScopeNode::GetDepsByDst); }); -TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefParent") - .set_body_typed([](StmtSRef sref) -> Optional { - return GetRef>(sref->parent); - }); -TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefRootMark") // - .set_body_typed(StmtSRef::RootMark); -TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefInlineMark") // - .set_body_typed(StmtSRef::InlineMark); -TVM_FFI_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc") - .set_body_method(&BlockScopeNode::GetDepsBySrc); -TVM_FFI_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst") - .set_body_method(&BlockScopeNode::GetDepsByDst); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 8fcd909bb2fc..48e69929647a 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -642,36 +643,34 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std TVM_REGISTER_NODE_TYPE(BufferNode); -TVM_FFI_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK_EQ(args.size(), 11); - auto buffer_type = args[8].cast(); - BufferType type = (buffer_type == "auto_broadcast") ? 
kAutoBroadcast : kDefault; - auto data = args[0].cast(); - auto dtype = args[1].cast(); - auto shape = args[2].cast>(); - auto strides = args[3].cast>(); - auto elem_offset = args[4].cast(); - auto name = args[5].cast(); - auto data_alignment = args[6].cast(); - auto offset_factor = args[7].cast(); - auto axis_separators = args[9].cast>(); - auto span = args[10].cast(); - *ret = Buffer(data, dtype, shape, strides, elem_offset, name, data_alignment, offset_factor, type, - axis_separators, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tir.Buffer", + [](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK_EQ(args.size(), 11); + auto buffer_type = args[8].cast(); + BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; + auto data = args[0].cast(); + auto dtype = args[1].cast(); + auto shape = args[2].cast>(); + auto strides = args[3].cast>(); + auto elem_offset = args[4].cast(); + auto name = args[5].cast(); + auto data_alignment = args[6].cast(); + auto offset_factor = args[7].cast(); + auto axis_separators = args[9].cast>(); + auto span = args[10].cast(); + *ret = Buffer(data, dtype, shape, strides, elem_offset, name, data_alignment, + offset_factor, type, axis_separators, span); + }) + .def_method("tir.BufferAccessPtr", &Buffer::access_ptr) + .def_method("tir.BufferGetFlattenedBuffer", &Buffer::GetFlattenedBuffer) + .def_method("tir.BufferOffsetOf", &Buffer::OffsetOf) + .def_method("tir.BufferVLoad", &Buffer::vload) + .def_method("tir.BufferVStore", &Buffer::vstore) + .def_method("tir.BufferStorageScope", &Buffer::scope); }); -TVM_FFI_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); - -TVM_FFI_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer") - .set_body_method(&Buffer::GetFlattenedBuffer); - -TVM_FFI_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf); - -TVM_FFI_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); - -TVM_FFI_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); - -TVM_FFI_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); - } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index fbb1901f087e..4844dab26284 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -432,45 +433,32 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_FFI_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) { - return Layout(name, dtype); -}); - -TVM_FFI_REGISTER_GLOBAL("tir.LayoutIndexOf") - .set_body_typed([](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::Get(axis)); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.LayoutFactorOf") - .set_body_typed([](Layout layout, std::string axis) -> int { - return layout.FactorOf(LayoutAxis::Get(axis)); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { - return layout.ndim(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.Layout", [](std::string name, DataType dtype) { return Layout(name, dtype); }) + .def("tir.LayoutIndexOf", + [](Layout layout, std::string axis) -> int { + return layout.IndexOf(LayoutAxis::Get(axis)); + }) + .def("tir.LayoutFactorOf", + [](Layout layout, std::string axis) -> int { + return layout.FactorOf(LayoutAxis::Get(axis)); + }) + .def("tir.LayoutNdim", 
[](Layout layout) -> int { return layout.ndim(); }) + .def("tir.LayoutGetItem", + [](Layout layout, int idx) -> std::string { + const LayoutAxis& axis = layout[idx]; + return axis.name(); + }) + .def("tir.BijectiveLayout", + [](Layout src_layout, Layout dst_layout) -> BijectiveLayout { + return BijectiveLayout(src_layout, dst_layout); + }) + .def_method("tir.BijectiveLayoutForwardIndex", &BijectiveLayout::ForwardIndex) + .def_method("tir.BijectiveLayoutBackwardIndex", &BijectiveLayout::BackwardIndex) + .def_method("tir.BijectiveLayoutForwardShape", &BijectiveLayout::ForwardShape) + .def_method("tir.BijectiveLayoutBackwardShape", &BijectiveLayout::BackwardShape); }); - -TVM_FFI_REGISTER_GLOBAL("tir.LayoutGetItem") - .set_body_typed([](Layout layout, int idx) -> std::string { - const LayoutAxis& axis = layout[idx]; - return axis.name(); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayout") - .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { - return BijectiveLayout(src_layout, dst_layout); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") - .set_body_method(&BijectiveLayout::ForwardIndex); - -TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") - .set_body_method(&BijectiveLayout::BackwardIndex); - -TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") - .set_body_method(&BijectiveLayout::ForwardShape); - -TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") - .set_body_method(&BijectiveLayout::BackwardShape); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index f6657451e511..3014f471baba 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -21,6 +21,7 @@ * \file expr.cc */ #include +#include #include #include #include @@ -79,8 +80,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ * `expr.dtype` field), this function allows the FFI conversions to be * explicitly invoked. 
*/ -TVM_FFI_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { - return expr; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.convert", + [](Variant> expr) { return expr; }); }); #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ @@ -163,13 +166,15 @@ Var Var::copy_with_dtype(DataType dtype) const { return Var(new_ptr); } -TVM_FFI_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, ffi::AnyView type, - Span span) { - if (type.as()) { - return Var(name_hint, type.cast(), span); - } else { - return Var(name_hint, type.cast(), span); - } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Var", [](String name_hint, ffi::AnyView type, Span span) { + if (type.as()) { + return Var(name_hint, type.cast(), span); + } else { + return Var(name_hint, type.cast(), span); + } + }); }); TVM_REGISTER_NODE_TYPE(VarNode); @@ -193,8 +198,10 @@ SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { - return SizeVar(s, t, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.SizeVar", + [](String s, DataType t, Span span) { return SizeVar(s, t, span); }); }); TVM_REGISTER_NODE_TYPE(SizeVarNode); @@ -219,10 +226,13 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("tir.IterVar") - .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag, Span span) { - return IterVar(dom, var, static_cast(iter_type), thread_tag, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.IterVar", [](Range dom, Var var, int iter_type, String thread_tag, Span span) { + return IterVar(dom, var, static_cast(iter_type), thread_tag, span); + }); +}); TVM_REGISTER_NODE_TYPE(IterVarNode); @@ -235,8 +245,10 @@ StringImm::StringImm(String value, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) { - return StringImm(value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.StringImm", + [](String value, Span span) { return StringImm(value, span); }); }); TVM_REGISTER_NODE_TYPE(StringImmNode); @@ -253,8 +265,11 @@ Cast::Cast(DataType t, PrimExpr value, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) { - return Cast(dtype, value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Cast", [](DataType dtype, PrimExpr value, Span span) { + return Cast(dtype, value, span); + }); }); TVM_REGISTER_NODE_TYPE(CastNode); @@ -262,8 +277,10 @@ TVM_REGISTER_NODE_TYPE(CastNode); // Add TVM_DEFINE_BINOP_CONSTRUCTOR(Add); -TVM_FFI_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Add(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Add", + [](PrimExpr a, PrimExpr b, Span span) { return Add(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(AddNode); @@ -271,8 +288,10 @@ TVM_REGISTER_NODE_TYPE(AddNode); // Sub TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); -TVM_FFI_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return 
Sub(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Sub", + [](PrimExpr a, PrimExpr b, Span span) { return Sub(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(SubNode); @@ -280,8 +299,10 @@ TVM_REGISTER_NODE_TYPE(SubNode); // Mul TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); -TVM_FFI_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Mul(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Mul", + [](PrimExpr a, PrimExpr b, Span span) { return Mul(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(MulNode); @@ -289,8 +310,10 @@ TVM_REGISTER_NODE_TYPE(MulNode); // Div TVM_DEFINE_BINOP_CONSTRUCTOR(Div); -TVM_FFI_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Div(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Div", + [](PrimExpr a, PrimExpr b, Span span) { return Div(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(DivNode); @@ -298,8 +321,10 @@ TVM_REGISTER_NODE_TYPE(DivNode); // Mod TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); -TVM_FFI_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Mod(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Mod", + [](PrimExpr a, PrimExpr b, Span span) { return Mod(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(ModNode); @@ -307,8 +332,10 @@ TVM_REGISTER_NODE_TYPE(ModNode); // FloorDiv TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); -TVM_FFI_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return FloorDiv(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.FloorDiv", + [](PrimExpr a, PrimExpr b, Span span) { return FloorDiv(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(FloorDivNode); @@ -316,8 +343,10 @@ TVM_REGISTER_NODE_TYPE(FloorDivNode); // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); -TVM_FFI_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return FloorMod(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.FloorMod", + [](PrimExpr a, PrimExpr b, Span span) { return FloorMod(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(FloorModNode); @@ -325,8 +354,10 @@ TVM_REGISTER_NODE_TYPE(FloorModNode); // Min TVM_DEFINE_BINOP_CONSTRUCTOR(Min); -TVM_FFI_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Min(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Min", + [](PrimExpr a, PrimExpr b, Span span) { return Min(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(MinNode); @@ -334,8 +365,10 @@ TVM_REGISTER_NODE_TYPE(MinNode); // Max TVM_DEFINE_BINOP_CONSTRUCTOR(Max); -TVM_FFI_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Max(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Max", + [](PrimExpr a, PrimExpr b, Span span) { return Max(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(MaxNode); @@ -343,8 +376,9 @@ TVM_REGISTER_NODE_TYPE(MaxNode); // EQ TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); -TVM_FFI_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return EQ(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef().def("tir.EQ", [](PrimExpr a, PrimExpr b, Span span) { return EQ(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(EQNode); @@ -352,8 +386,9 @@ TVM_REGISTER_NODE_TYPE(EQNode); // NE TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); -TVM_FFI_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return NE(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.NE", [](PrimExpr a, PrimExpr b, Span span) { return NE(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(NENode); @@ -361,8 +396,9 @@ TVM_REGISTER_NODE_TYPE(NENode); // LT TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); -TVM_FFI_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return LT(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.LT", [](PrimExpr a, PrimExpr b, Span span) { return LT(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(LTNode); @@ -370,8 +406,9 @@ TVM_REGISTER_NODE_TYPE(LTNode); // LE TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); -TVM_FFI_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return LE(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.LE", [](PrimExpr a, PrimExpr b, Span span) { return LE(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(LENode); @@ -379,8 +416,9 @@ TVM_REGISTER_NODE_TYPE(LENode); // GT TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); -TVM_FFI_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return GT(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.GT", [](PrimExpr a, PrimExpr b, Span span) { return GT(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(GTNode); @@ -388,8 +426,9 @@ TVM_REGISTER_NODE_TYPE(GTNode); // GE TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); -TVM_FFI_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return GE(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.GE", [](PrimExpr a, PrimExpr b, Span span) { return GE(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(GENode); @@ -411,8 +450,10 @@ And::And(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return And(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.And", + [](PrimExpr a, PrimExpr b, Span span) { return And(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(AndNode); @@ -434,8 +475,9 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return Or(a, b, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Or", [](PrimExpr a, PrimExpr b, Span span) { return Or(a, b, span); }); }); TVM_REGISTER_NODE_TYPE(OrNode); @@ -453,8 +495,9 @@ Not::Not(PrimExpr a, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { - return Not(a, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Not", [](PrimExpr a, Span span) { return Not(a, span); }); }); TVM_REGISTER_NODE_TYPE(NotNode); @@ -481,10 +524,13 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp data_ = 
std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Select") - .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { - return Select(condition, true_value, false_value, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.Select", [](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { + return Select(condition, true_value, false_value, span); + }); +}); TVM_REGISTER_NODE_TYPE(SelectNode); @@ -520,10 +566,12 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Ramp") - .set_body_typed([](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { - return Ramp(base, stride, lanes, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Ramp", [](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { + return Ramp(base, stride, lanes, span); + }); +}); TVM_REGISTER_NODE_TYPE(RampNode); @@ -553,10 +601,12 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { data_ = node; } -TVM_FFI_REGISTER_GLOBAL("tir.Broadcast") - .set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { - return Broadcast(value, lanes, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Broadcast", [](PrimExpr value, PrimExpr lanes, Span span) { + return Broadcast(value, lanes, span); + }); +}); TVM_REGISTER_NODE_TYPE(BroadcastNode); @@ -575,9 +625,11 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body, - Span span) { - return Let(var, value, body, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Let", [](Var var, PrimExpr value, PrimExpr body, Span span) { + return Let(var, value, body, span); + }); }); TVM_REGISTER_NODE_TYPE(LetNode); @@ -596,37 +648,40 @@ Call::Call(DataType dtype, RelaxExpr op, Array args, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](Optional dtype, RelaxExpr op, - Array> args, - Span span) { - Array prim_expr_args; - for (const auto& it : args) { - if (auto opt_str = it.as()) { - prim_expr_args.push_back(StringImm(opt_str.value())); - } else if (auto opt_dtype = it.as()) { - prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); - } else if (const auto* iter_var = it.as()) { - prim_expr_args.push_back(iter_var->var); - } else if (const auto* br = it.as()) { - Array indices; - for (Range r : br->region) { - if (is_one(r->extent)) { - indices.push_back(r->min); - } else if (r->extent.as()) { - indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); - } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " - << GetRef(br); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.Call", + [](Optional dtype, RelaxExpr op, + Array> args, Span span) { + Array prim_expr_args; + for (const auto& it : args) { + if (auto opt_str = it.as()) { + prim_expr_args.push_back(StringImm(opt_str.value())); + } else if (auto opt_dtype = it.as()) { + prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); + } else if (const auto* iter_var = it.as()) { + prim_expr_args.push_back(iter_var->var); + } else if (const 
auto* br = it.as()) { + Array indices; + for (Range r : br->region) { + if (is_one(r->extent)) { + indices.push_back(r->min); + } else if (r->extent.as()) { + indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); + } else { + LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " + << GetRef(br); + } } + prim_expr_args.push_back(BufferLoad(br->buffer, indices)); + } else { + prim_expr_args.push_back(Downcast(it)); } - prim_expr_args.push_back(BufferLoad(br->buffer, indices)); - } else { - prim_expr_args.push_back(Downcast(it)); } - } - return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); - }); + return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); + }); +}); TVM_REGISTER_NODE_TYPE(CallNode); @@ -671,10 +726,11 @@ PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { return Shuffle({vector}, {Integer(index)}, span); } -TVM_FFI_REGISTER_GLOBAL("tir.Shuffle") - .set_body_typed([](Array vectors, Array indices, Span span) { - return Shuffle(vectors, indices, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Shuffle", [](Array vectors, Array indices, + Span span) { return Shuffle(vectors, indices, span); }); +}); TVM_REGISTER_NODE_TYPE(ShuffleNode); @@ -731,14 +787,15 @@ Array CommReducerNode::operator()(Array a, Array b return Substitute(this->result, value_map); } -TVM_FFI_REGISTER_GLOBAL("tir.CommReducer") - .set_body_typed([](Array lhs, Array rhs, Array result, - Array identity_element, Span span) { - return CommReducer(lhs, rhs, result, identity_element, span); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.CommReducerCombine") - .set_body_method(&tir::CommReducerNode::operator()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.CommReducer", + [](Array lhs, Array rhs, Array result, + Array identity_element, + Span span) { return CommReducer(lhs, rhs, result, identity_element, span); }) + .def_method("tir.CommReducerCombine", &tir::CommReducerNode::operator()); +}); TVM_REGISTER_NODE_TYPE(CommReducerNode); @@ -777,11 +834,14 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("tir.Reduce") - .set_body_typed([](CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init, Span span) { - return Reduce(combiner, source, axis, condition, value_index, init, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Reduce", + [](CommReducer combiner, Array source, Array axis, + PrimExpr condition, int value_index, Array init, Span span) { + return Reduce(combiner, source, axis, condition, value_index, init, span); + }); +}); TVM_REGISTER_NODE_TYPE(ReduceNode); @@ -852,9 +912,12 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional indices, Optional predicate, - Span span) { return BufferLoad(buffer, indices, predicate, span); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.BufferLoad", + [](Buffer buffer, Array indices, Optional predicate, + Span span) { return BufferLoad(buffer, indices, predicate, span); }); +}); TVM_REGISTER_NODE_TYPE(BufferLoadNode); @@ -868,10 +931,13 @@ ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.ProducerLoad") - .set_body_typed([](DataProducer producer, Array indices, Span span) { - 
return ProducerLoad(producer, indices, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.ProducerLoad", + [](DataProducer producer, Array indices, Span span) { + return ProducerLoad(producer, indices, span); + }); +}); TVM_REGISTER_NODE_TYPE(ProducerLoadNode); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 3996435e8e84..a05efac775df 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -161,19 +161,20 @@ Optional TensorIntrin::Get(String name, bool allow_missing) { TVM_REGISTER_NODE_TYPE(TensorIntrinNode); -TVM_FFI_REGISTER_GLOBAL("tir.PrimFunc") - .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { - return PrimFunc(params, body, ret_type, buffer_map, attrs, span); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrin") - .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { - return TensorIntrin(desc_func, intrin_func); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); -TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.PrimFunc", + [](Array params, Stmt body, Type ret_type, Map buffer_map, + DictAttrs attrs, + Span span) { return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }) + .def("tir.TensorIntrin", + [](PrimFunc desc_func, PrimFunc intrin_func) { + return TensorIntrin(desc_func, intrin_func); + }) + .def("tir.TensorIntrinRegister", TensorIntrin::Register) + .def("tir.TensorIntrinGet", TensorIntrin::Get); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 1596be567fc9..01802dd4e444 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -422,39 +422,37 @@ IndexMap Substitute(const IndexMap& index_map, TVM_REGISTER_NODE_TYPE(IndexMapNode); -TVM_FFI_REGISTER_GLOBAL("tir.IndexMap") - .set_body_typed([](Array initial_indices, Array final_indices, - Optional inverse_index_map) { - return IndexMap(initial_indices, final_indices, inverse_index_map); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapIndices") - .set_body_typed([](IndexMap map, Array indices) { - arith::Analyzer analyzer; - return map->MapIndices(indices, &analyzer); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapShape") - .set_body_typed([](IndexMap map, Array shape) { - arith::Analyzer analyzer; - return map->MapShape(shape, &analyzer); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.IndexMapInverse") - .set_body_typed([](IndexMap map, Array initial_ranges) { - arith::Analyzer analyzer; - return map.Inverse(initial_ranges, &analyzer); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapNDArray") - .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }); - -TVM_FFI_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") - .set_body_typed([](IndexMap forward, Array initial_ranges) { - arith::Analyzer analyzer; - auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); - return Array{result.first, result.second}; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.IndexMap", + [](Array initial_indices, Array final_indices, + Optional inverse_index_map) { + return IndexMap(initial_indices, final_indices, inverse_index_map); + }) + .def("tir.IndexMapMapIndices", + [](IndexMap map, Array indices) { + arith::Analyzer 
analyzer; + return map->MapIndices(indices, &analyzer); + }) + .def("tir.IndexMapMapShape", + [](IndexMap map, Array shape) { + arith::Analyzer analyzer; + return map->MapShape(shape, &analyzer); + }) + .def("tir.IndexMapInverse", + [](IndexMap map, Array initial_ranges) { + arith::Analyzer analyzer; + return map.Inverse(initial_ranges, &analyzer); + }) + .def("tir.IndexMapMapNDArray", + [](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }) + .def("tir.IndexMapNonSurjectiveInverse", [](IndexMap forward, Array initial_ranges) { + arith::Analyzer analyzer; + auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); + return Array{result.first, result.second}; + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 6382a779927f..4859e46807be 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -828,44 +828,44 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_NODE_TYPE(PyStmtExprVisitorNode); TVM_REGISTER_NODE_TYPE(PyStmtExprMutatorNode); -TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprVisitor") - .set_body_typed(PyStmtExprVisitor::MakePyStmtExprVisitor); -TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprMutator") - .set_body_typed(PyStmtExprMutator::MakePyStmtExprMutator); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) + .def("tir.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); +}); // StmtExprVisitor -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitExpr") - .set_body_typed([](PyStmtExprVisitor visitor, const PrimExpr& expr) { - visitor->DefaultVisitExpr(expr); - }); -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitStmt") - .set_body_typed([](PyStmtExprVisitor visitor, const Stmt& stmt) { - visitor->DefaultVisitStmt(stmt); - }); -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitStmt") - .set_body_typed([](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }); -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitExpr") - .set_body_typed([](PyStmtExprVisitor visitor, const PrimExpr& expr) { - visitor->VisitExpr(expr); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.PyStmtExprVisitorDefaultVisitExpr", + [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->DefaultVisitExpr(expr); }) + .def("tir.PyStmtExprVisitorDefaultVisitStmt", + [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->DefaultVisitStmt(stmt); }) + .def("tir.PyStmtExprVisitorVisitStmt", + [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }) + .def("tir.PyStmtExprVisitorVisitExpr", + [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->VisitExpr(expr); }); +}); // StmtExprMutator -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitExpr") - .set_body_typed([](PyStmtExprMutator mutator, const PrimExpr& expr) { - return mutator->DefaultVisitExpr(expr); - }); -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitStmt") - .set_body_typed([](PyStmtExprMutator mutator, const Stmt& stmt) { - return mutator->DefaultVisitStmt(stmt); - }); -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitExpr") - .set_body_typed([](PyStmtExprMutator mutator, const PrimExpr& expr) { - return mutator->VisitExpr(expr); - }); -TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitStmt") - .set_body_typed([](PyStmtExprMutator mutator, const Stmt& stmt) { - return 
mutator->VisitStmt(stmt); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.PyStmtExprMutatorDefaultVisitExpr", + [](PyStmtExprMutator mutator, const PrimExpr& expr) { + return mutator->DefaultVisitExpr(expr); + }) + .def("tir.PyStmtExprMutatorDefaultVisitStmt", + [](PyStmtExprMutator mutator, const Stmt& stmt) { + return mutator->DefaultVisitStmt(stmt); + }) + .def("tir.PyStmtExprMutatorVisitExpr", + [](PyStmtExprMutator mutator, const PrimExpr& expr) { return mutator->VisitExpr(expr); }) + .def("tir.PyStmtExprMutatorVisitStmt", + [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->VisitStmt(stmt); }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index 00313bfd0227..bf8289036a30 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -25,6 +25,7 @@ #include "./script_complete.h" #include +#include #include #include @@ -160,7 +161,10 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } } -TVM_FFI_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("script.Complete", ScriptComplete); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 86ed65c4905d..ac83477b1daa 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -22,6 +22,7 @@ * \brief Specialize parameters of PrimFunc. */ #include +#include #include #include #include @@ -432,7 +433,10 @@ PrimFunc Specialize(PrimFunc func, const Map>& pa /**************** FFI ****************/ -TVM_FFI_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Specialize", Specialize); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6be07368972d..bb8d6c9fd1f6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -73,10 +74,12 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.LetStmt") - .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) { - return LetStmt(var, value, body, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.LetStmt", [](Var var, PrimExpr value, Stmt body, Span span) { + return LetStmt(var, value, body, span); + }); +}); TVM_REGISTER_NODE_TYPE(LetStmtNode); @@ -91,14 +94,18 @@ AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, S data_ = std::move(n); } -TVM_FFI_REGISTER_GLOBAL("tir.AttrStmt") - .set_body_typed([](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { - // when node is a POD data type like int or bool, first convert to primexpr. - if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return AttrStmt(node.cast(), attr_key, value, body, span); - } - return AttrStmt(node.cast(), attr_key, value, body, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.AttrStmt", + [](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { + // when node is a POD data type like int or bool, first convert to + // primexpr. 
+ if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + return AttrStmt(node.cast(), attr_key, value, body, span); + } + return AttrStmt(node.cast(), attr_key, value, body, span); + }); +}); TVM_REGISTER_NODE_TYPE(AttrStmtNode); @@ -121,10 +128,13 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_FFI_REGISTER_GLOBAL("tir.AssertStmt") - .set_body_typed([](PrimExpr condition, StringImm message, Stmt body, Span span) { - return AssertStmt(condition, message, body, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.AssertStmt", + [](PrimExpr condition, StringImm message, Stmt body, Span span) { + return AssertStmt(condition, message, body, span); + }); +}); // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, @@ -176,12 +186,15 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.For").set_body_typed( - [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, - Optional thread_binding, Optional> annotations, Span span) { - return For(loop_var, min, extent, static_cast(kind), body, thread_binding, - annotations.value_or(Map()), span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, + Stmt body, Optional thread_binding, + Optional> annotations, Span span) { + return For(loop_var, min, extent, static_cast(kind), body, thread_binding, + annotations.value_or(Map()), span); + }); +}); TVM_REGISTER_NODE_TYPE(ForNode); @@ -220,8 +233,11 @@ While::While(PrimExpr condition, Stmt body, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { - return While(condition, body, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.While", [](PrimExpr condition, Stmt body, Span span) { + return While(condition, body, span); + }); }); TVM_REGISTER_NODE_TYPE(WhileNode); @@ -270,11 +286,14 @@ int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { return static_cast(result); } -TVM_FFI_REGISTER_GLOBAL("tir.Allocate") - .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body, Map annotations, Span span) { - return Allocate(buffer_var, type, extents, condition, body, annotations, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.Allocate", [](Var buffer_var, DataType type, Array extents, PrimExpr condition, + Stmt body, Map annotations, Span span) { + return Allocate(buffer_var, type, extents, condition, body, annotations, span); + }); +}); TVM_REGISTER_NODE_TYPE(AllocateNode); @@ -331,13 +350,16 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents } return static_cast(result); } -TVM_FFI_REGISTER_GLOBAL("tir.AllocateConst") - .set_body_typed([](Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, Optional> annotations, - Span span) { - return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations.value_or({}), - span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.AllocateConst", + [](Var buffer_var, DataType dtype, Array extents, 
ObjectRef data_or_idx, Stmt body, + Optional> annotations, Span span) { + return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, + annotations.value_or({}), span); + }); +}); TVM_REGISTER_NODE_TYPE(AllocateConstNode); @@ -350,8 +372,11 @@ DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { - return DeclBuffer(buffer, body, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.DeclBuffer", [](Buffer buffer, Stmt body, Span span) { + return DeclBuffer(buffer, body, span); + }); }); TVM_REGISTER_NODE_TYPE(DeclBufferNode); @@ -383,8 +408,10 @@ SeqStmt::SeqStmt(Array seq, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { - return SeqStmt(std::move(seq), span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.SeqStmt", + [](Array seq, Span span) { return SeqStmt(std::move(seq), span); }); }); TVM_REGISTER_NODE_TYPE(SeqStmtNode); @@ -404,10 +431,13 @@ IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_c TVM_REGISTER_NODE_TYPE(IfThenElseNode); -TVM_FFI_REGISTER_GLOBAL("tir.IfThenElse") - .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { - return IfThenElse(condition, then_case, else_case, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.IfThenElse", + [](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { + return IfThenElse(condition, then_case, else_case, span); + }); +}); // Evaluate Evaluate::Evaluate(PrimExpr value, Span span) { @@ -419,8 +449,10 @@ Evaluate::Evaluate(PrimExpr value, Span span) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { - return Evaluate(value, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.Evaluate", + [](PrimExpr value, Span span) { return Evaluate(value, span); }); }); TVM_REGISTER_NODE_TYPE(EvaluateNode); @@ -501,10 +533,13 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, - Optional predicate, - Span span) { return BufferStore(buffer, value, indices, predicate, span); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.BufferStore", + [](Buffer buffer, PrimExpr value, Array indices, Optional predicate, + Span span) { return BufferStore(buffer, value, indices, predicate, span); }); +}); TVM_REGISTER_NODE_TYPE(BufferStoreNode); @@ -514,9 +549,13 @@ BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condit data_ = make_object(buffer, bounds, condition, body, span); } -TVM_FFI_REGISTER_GLOBAL("tir.BufferRealize") - .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body, - Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.BufferRealize", [](Buffer buffer, Array bounds, + PrimExpr condition, Stmt body, Span span) { + return BufferRealize(buffer, bounds, condition, body, span); + }); +}); 
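The hunks above and below all apply the same mechanical rewrite: per-symbol TVM_FFI_REGISTER_GLOBAL("...").set_body_typed(...) / .set_body_packed(...) / .set_body_method(...) registrations are folded into a single TVM_FFI_STATIC_INIT_BLOCK that chains .def / .def_packed / .def_method on one reflection GlobalDef, so each translation unit ends up with one static initializer instead of one per exported symbol. The sketch below shows the new pattern in isolation; the header path and the "example.*" names are assumptions added for illustration only and are not symbols touched by this patch.

#include <cstdint>
#include <tvm/ffi/reflection/registry.h>  // assumed location of tvm::ffi::reflection::GlobalDef

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      // typed form: argument and return types are deduced from the lambda signature
      .def("example.AddOne", [](int64_t x) -> int64_t { return x + 1; })
      // packed form: raw PackedArgs in, Any out, for variadic or type-dispatching entry points
      .def_packed("example.First", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) {
        *rv = args[0].cast<int64_t>();
      });
});

Member functions of reflected classes follow the same chain via .def_method("name", &Class::Method), as the schedule.cc hunks later in this patch show.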
TVM_REGISTER_NODE_TYPE(BufferRealizeNode); @@ -568,8 +607,11 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { return BufferRegion(buffer, region); } -TVM_FFI_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { - return BufferRegion(buffer, region); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.BufferRegion", [](Buffer buffer, Array region) { + return BufferRegion(buffer, region); + }); }); TVM_REGISTER_NODE_TYPE(BufferRegionNode); @@ -625,10 +667,12 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.MatchBufferRegion") - .set_body_typed([](Buffer buffer, BufferRegion source) { - return MatchBufferRegion(buffer, source); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.MatchBufferRegion", [](Buffer buffer, BufferRegion source) { + return MatchBufferRegion(buffer, source); + }); +}); TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode); @@ -650,14 +694,17 @@ Block::Block(Array iter_vars, Array reads, Array iter_vars, Array reads, - Array writes, String name_hint, Stmt body, Optional init, - Array alloc_buffers, Array match_buffers, - Map annotations, Span span) { - return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, - annotations, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.Block", + [](Array iter_vars, Array reads, Array writes, + String name_hint, Stmt body, Optional init, Array alloc_buffers, + Array match_buffers, Map annotations, Span span) { + return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, + annotations, span); + }); +}); TVM_REGISTER_NODE_TYPE(BlockNode); @@ -674,10 +721,13 @@ BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block blo data_ = std::move(node); } -TVM_FFI_REGISTER_GLOBAL("tir.BlockRealize") - .set_body_typed([](Array iter_values, PrimExpr predicate, Block block, Span span) { - return BlockRealize(iter_values, predicate, block, span); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.BlockRealize", [](Array iter_values, PrimExpr predicate, + Block block, Span span) { + return BlockRealize(iter_values, predicate, block, span); + }); +}); TVM_REGISTER_NODE_TYPE(BlockRealizeNode); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 7ecb4558506c..0aa699472d29 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -20,6 +20,7 @@ * \file stmt_functor.cc */ #include +#include #include #include #include @@ -832,24 +833,26 @@ PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); } -TVM_FFI_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); - -TVM_FFI_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { - tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); -}); - -TVM_FFI_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { - tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.IRTransform", IRTransform) + .def("tir.PostOrderVisit", + [](ObjectRef node, ffi::Function f) { + 
tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); + }) + .def("tir.PreOrderVisit", + [](ObjectRef node, ffi::Function f) { + tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); + }) + .def("tir.Substitute", [](ObjectRef node, Map vmap) -> ObjectRef { + if (node->IsInstance()) { + return Substitute(Downcast(node), vmap); + } else { + return Substitute(Downcast(node), vmap); + } + }); }); -TVM_FFI_REGISTER_GLOBAL("tir.Substitute") - .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { - if (node->IsInstance()) { - return Substitute(Downcast(node), vmap); - } else { - return Substitute(Downcast(node), vmap); - } - }); - } // namespace tir } // namespace tvm diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index b8a91cfc127e..704d6310d0c5 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -150,15 +150,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ PrimFuncPassNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); -TVM_FFI_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") - .set_body_typed( - [](ffi::TypedFunction, IRModule, PassContext)> pass_func, - PassInfo pass_info) { - auto wrapped_pass_func = [pass_func](PrimFunc func, IRModule mod, PassContext ctx) { - return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); - }; - return PrimFuncPass(wrapped_pass_func, pass_info); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.transform.CreatePrimFuncPass", + [](ffi::TypedFunction, IRModule, PassContext)> pass_func, + PassInfo pass_info) { + auto wrapped_pass_func = [pass_func](PrimFunc func, IRModule mod, PassContext ctx) { + return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); + }; + return PrimFuncPass(wrapped_pass_func, pass_info); + }); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 41e090c58542..9a520073a669 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -24,6 +24,7 @@ */ #include +#include #include #include #include @@ -245,7 +246,10 @@ PrimExpr ret(PrimExpr value, Span span) { return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } -TVM_FFI_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.ret", ret); +}); // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { @@ -802,8 +806,10 @@ PrimExpr bitwise_neg(PrimExpr a, Span span) { return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); } -TVM_FFI_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) { - return bitwise_neg(a, span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.bitwise_not", + [](PrimExpr a, Span span) { return bitwise_neg(a, span); }); }); // pow @@ -1112,49 +1118,39 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace -TVM_FFI_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - if (auto opt = args[0].try_cast()) { - *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); - } else if (auto opt = args[0].try_cast()) { - *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); - } else { - LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " - << "but instead 
received argument with type code " << args[0].GetTypeKey(); - } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("node._const", + [](ffi::PackedArgs args, ffi::Any* ret) { + if (auto opt = args[0].try_cast()) { + *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); + } else if (auto opt = args[0].try_cast()) { + *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); + } else { + LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " + << "but instead received argument with type code " + << args[0].GetTypeKey(); + } + }) + .def("node.LargeUIntImm", LargeUIntImm) + .def("tir.min_value", min_value) + .def("tir.max_value", max_value) + .def("tir.infinity", infinity) + .def("tir.abs", tvm::abs) + .def("tir.likely", tvm::likely) + .def("tir.isnan", tvm::isnan) + .def("tir.isfinite", tvm::isfinite) + .def("tir.isinf", tvm::isinf) + .def("tir.floor", tvm::floor) + .def("tir.ceil", tvm::ceil) + .def("tir.round", tvm::round) + .def("tir.nearbyint", tvm::nearbyint) + .def("tir.trunc", tvm::trunc) + .def("tir._cast", tvm::cast) + .def("tir.reinterpret", tvm::reinterpret); }); -TVM_FFI_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); - -TVM_FFI_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); - -TVM_FFI_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); - -TVM_FFI_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity); - -TVM_FFI_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); - -TVM_FFI_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely); - -TVM_FFI_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); - -TVM_FFI_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); - -TVM_FFI_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); - -TVM_FFI_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); - -TVM_FFI_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); - -TVM_FFI_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); - -TVM_FFI_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); - -TVM_FFI_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); - -TVM_FFI_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); - -TVM_FFI_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); - // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_FFI_REGISTER_GLOBAL("tir." 
#Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ @@ -1204,13 +1200,14 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor); REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, right_shift); -TVM_FFI_REGISTER_GLOBAL("tir._OpIfThenElse") - .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { - return if_then_else(cond, true_value, false_value, span); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { - return const_true(t.lanes(), span); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir._OpIfThenElse", + [](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { + return if_then_else(cond, true_value, false_value, span); + }) + .def("tir.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); }); }); PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 9d23661bace3..ea4a2bf41f0e 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -335,10 +335,13 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsReductionBlock") - .set_body_typed([](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) { - return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.schedule.IsReductionBlock", [](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) { + return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); + }); +}); void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { @@ -871,10 +874,12 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetBlockRealize") - .set_body_typed([](Schedule sch, BlockRV block_rv) { - return GetBlockRealize(sch->state(), sch->GetSRef(block_rv)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.schedule.GetBlockRealize", [](Schedule sch, BlockRV block_rv) { + return GetBlockRealize(sch->state(), sch->GetSRef(block_rv)); + }); +}); IterVarType GetLoopIterType(const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); @@ -1490,10 +1495,12 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { return true; } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding") - .set_body_typed([](Schedule sch, BlockRV block_rv) { - return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.schedule.IsTrivialBinding", [](Schedule sch, BlockRV block_rv) { + return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); + }); +}); bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { if (HasBeenMultiLevelTiled(block_sref)) { @@ -1898,11 +1905,15 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, return TensorizeInfo(ret); } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc); 
-TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") - .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) { - return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.IsSpatialPrimFunc", IsSpatialPrimFunc) + .def("tir.schedule.GetTensorizeLoopMapping", [](Schedule sch, BlockRV block, + PrimFunc desc_func, bool allow_padding) { + return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); + }); +}); /******** Auto Tensorization ********/ @@ -2128,30 +2139,31 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") - .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { - return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsOutputBlock") - .set_body_typed([](Schedule sch, BlockRV block) { - auto state = sch->state(); - auto block_sref = sch->GetSRef(block); - return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); - }); - -TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") - .set_body_typed([](Schedule sch, LoopRV loop) -> String { - IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); - if (kind == kDataPar) { - return "S"; - } else if (kind == kCommReduce) { - return "R"; - } else { - return "O"; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.GetAutoTensorizeMappingInfo", + [](Schedule sch, BlockRV block, PrimFunc desc_func) { + return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); + }) + .def("tir.schedule.HasBlock", HasBlock) + .def("tir.schedule.IsOutputBlock", + [](Schedule sch, BlockRV block) { + auto state = sch->state(); + auto block_sref = sch->GetSRef(block); + return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); + }) + .def("tir.schedule.GetLoopIterType", [](Schedule sch, LoopRV loop) -> String { + IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); + if (kind == kDataPar) { + return "S"; + } else if (kind == kCommReduce) { + return "R"; + } else { + return "O"; + } + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 13b35582eefc..9f1a68e92e84 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -238,12 +240,14 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") - .set_body_typed([](Buffer buffer, Array indices, Array loops, - PrimExpr predicate) { - arith::Analyzer analyzer; - return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.schedule.SuggestIndexMap", [](Buffer buffer, Array indices, + Array loops, PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 7851c697a144..63da75851b63 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -105,12 +107,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(InstructionNode); TVM_REGISTER_NODE_TYPE(InstructionKindNode); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.Instruction") - .set_body_typed([](InstructionKind kind, Array inputs, Array attrs, - Array outputs) -> Instruction { - return Instruction(kind, inputs, attrs, outputs); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.InstructionKindGet", InstructionKind::Get) + .def("tir.schedule.Instruction", + [](InstructionKind kind, Array inputs, Array attrs, Array outputs) + -> Instruction { return Instruction(kind, inputs, attrs, outputs); }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index ea081721a3a1..276c0697f131 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../../transforms/ir_utils.h" #include "../utils.h" @@ -531,10 +533,13 @@ bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, /******** FFI ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding") - .set_body_typed([](Schedule self, BlockRV block_rv, LoopRV loop_rv) { - return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.schedule.CanDecomposePadding", [](Schedule self, BlockRV block_rv, LoopRV loop_rv) { + return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); + }); +}); /******** InstructionKind Registration ********/ diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 326d373d6e70..a14c485b7996 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -1344,12 +1346,15 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); /******** FFI ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.RegisterReducer") - .set_body_typed([](int n_buffers, ffi::Function combiner_getter, - ffi::Function identity_getter) { - ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), - std::move(identity_getter)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tir.schedule.RegisterReducer", + [](int n_buffers, ffi::Function combiner_getter, ffi::Function identity_getter) { + ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), + std::move(identity_getter)); + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 7ac47e136983..282e24666680 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -50,267 +50,311 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode); TVM_REGISTER_NODE_TYPE(LoopRVNode); TVM_REGISTER_OBJECT_TYPE(ScheduleNode); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // - .set_body_method(&ScheduleNode::mod); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // - .set_body_method(&ScheduleNode::state); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // - .set_body_method(&ScheduleNode::trace); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // - .set_body_method(&ScheduleNode::func_working_on); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // - .set_body_method(&ScheduleNode::Copy); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // - .set_body_method(&ScheduleNode::Seed); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // - .set_body_method(&ScheduleNode::ForkSeed); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") // - .set_body_method(&ScheduleNode::WorkOn); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleGetMod", &ScheduleNode::mod) + .def_method("tir.schedule.ScheduleGetState", &ScheduleNode::state) + .def_method("tir.schedule.ScheduleGetTrace", &ScheduleNode::trace) + .def_method("tir.schedule.ScheduleGetFuncWorkingOn", &ScheduleNode::func_working_on) + .def_method("tir.schedule.ScheduleCopy", &ScheduleNode::Copy) + .def_method("tir.schedule.ScheduleSeed", &ScheduleNode::Seed) + .def_method("tir.schedule.ScheduleForkSeed", &ScheduleNode::ForkSeed) + .def_method("tir.schedule.ScheduleWorkOn", &ScheduleNode::WorkOn); +}); /**************** (FFI) Constructor ****************/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, int error_render_level, bool enable_check) -> Schedule { - return Schedule::Concrete(mod, debug_mask, seed, - static_cast(error_render_level), - enable_check); - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TracedSchedule") - .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, int error_render_level, bool enable_check) -> Schedule { - return Schedule::Traced(mod, seed, debug_mask, - static_cast(error_render_level), - enable_check); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = 
tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.BlockRV", []() { return BlockRV(); }) + .def("tir.schedule.LoopRV", []() { return LoopRV(); }) + .def("tir.schedule.ConcreteSchedule", + [](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, + int error_render_level, bool enable_check) -> Schedule { + return Schedule::Concrete(mod, debug_mask, seed, + static_cast(error_render_level), + enable_check); + }) + .def("tir.schedule.TracedSchedule", + [](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, + int error_render_level, bool enable_check) -> Schedule { + return Schedule::Traced(mod, seed, debug_mask, + static_cast(error_render_level), + enable_check); + }); +}); /******** (FFI) Lookup random variables ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGet") - .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef { - if (auto loop_rv = obj.as()) { - return self->Get(loop_rv.value()); - } - if (auto block_rv = obj.as()) { - return self->Get(block_rv.value()); - } - if (auto expr_rv = obj.as()) { - return self->Get(expr_rv.value()); - } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey() - << ". Its value is: " << obj; - throw; - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") - .set_body_typed([](Schedule self, ObjectRef obj) -> Optional { - if (auto loop_rv = obj.as()) { - return self->GetSRef(loop_rv.value()); - } - if (auto block_rv = obj.as()) { - return self->GetSRef(block_rv.value()); - } - if (auto stmt = obj.as()) { - return self->GetSRef(stmt.value()); - } - LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); - throw; - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") - .set_body_typed([](Schedule self, ObjectRef obj) -> void { - if (auto loop_rv = obj.as()) { - return self->RemoveRV(loop_rv.value()); - } - if (auto block_rv = obj.as()) { - return self->RemoveRV(block_rv.value()); - } - if (auto expr_rv = obj.as()) { - return self->RemoveRV(expr_rv.value()); - } - LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); - throw; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.ScheduleGet", + [](Schedule self, ObjectRef obj) -> ObjectRef { + if (auto loop_rv = obj.as()) { + return self->Get(loop_rv.value()); + } + if (auto block_rv = obj.as()) { + return self->Get(block_rv.value()); + } + if (auto expr_rv = obj.as()) { + return self->Get(expr_rv.value()); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " + << obj->GetTypeKey() << ". 
Its value is: " << obj; + throw; + }) + .def("tir.schedule.ScheduleGetSRef", + [](Schedule self, ObjectRef obj) -> Optional { + if (auto loop_rv = obj.as()) { + return self->GetSRef(loop_rv.value()); + } + if (auto block_rv = obj.as()) { + return self->GetSRef(block_rv.value()); + } + if (auto stmt = obj.as()) { + return self->GetSRef(stmt.value()); + } + LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); + throw; + }) + .def("tir.schedule.ScheduleRemoveRV", [](Schedule self, ObjectRef obj) -> void { + if (auto loop_rv = obj.as()) { + return self->RemoveRV(loop_rv.value()); + } + if (auto block_rv = obj.as()) { + return self->RemoveRV(block_rv.value()); + } + if (auto expr_rv = obj.as()) { + return self->RemoveRV(expr_rv.value()); + } + LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); + throw; + }); +}); /******** (FFI) Sampling ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") - .set_body_method(&ScheduleNode::SampleCategorical); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") - .set_body_method(&ScheduleNode::SamplePerfectTile); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePartitionedTile") - .set_body_method(&ScheduleNode::SamplePartitionedTile); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") - .set_body_method(&ScheduleNode::SampleComputeLocation); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleSampleCategorical", &ScheduleNode::SampleCategorical) + .def_method("tir.schedule.ScheduleSamplePerfectTile", &ScheduleNode::SamplePerfectTile) + .def_method("tir.schedule.ScheduleSamplePartitionedTile", + &ScheduleNode::SamplePartitionedTile) + .def_method("tir.schedule.ScheduleSampleComputeLocation", + &ScheduleNode::SampleComputeLocation); +}); /******** (FFI) Get blocks & loops ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock").set_body_method(&ScheduleNode::GetBlock); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops").set_body_method(&ScheduleNode::GetLoops); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") - .set_body_typed([](Schedule self, ObjectRef rv) { - if (auto block_rv = rv.as()) { - return self->GetChildBlocks(block_rv.value()); - } - if (auto loop_rv = rv.as()) { - return self->GetChildBlocks(loop_rv.value()); - } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() - << ". Its value is: " << rv; - throw; - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") - .set_body_method(&ScheduleNode::GetProducers); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") - .set_body_method(&ScheduleNode::GetConsumers); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") - .set_body_method(&ScheduleNode::GetOutputBlocks); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleGetBlock", &ScheduleNode::GetBlock) + .def_method("tir.schedule.ScheduleGetLoops", &ScheduleNode::GetLoops) + .def("tir.schedule.ScheduleGetChildBlocks", + [](Schedule self, ObjectRef rv) { + if (auto block_rv = rv.as()) { + return self->GetChildBlocks(block_rv.value()); + } + if (auto loop_rv = rv.as()) { + return self->GetChildBlocks(loop_rv.value()); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " + << rv->GetTypeKey() << ". 
Its value is: " << rv; + throw; + }) + .def_method("tir.schedule.ScheduleGetProducers", &ScheduleNode::GetProducers) + .def_method("tir.schedule.ScheduleGetConsumers", &ScheduleNode::GetConsumers) + .def_method("tir.schedule.ScheduleGetOutputBlocks", &ScheduleNode::GetOutputBlocks); +}); /******** (FFI) Transform loops ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleLoopPartition") - .set_body_method(&ScheduleNode::LoopPartition); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReorder").set_body_method(&ScheduleNode::Reorder); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") - .set_body_method(&ScheduleNode::ReorderBlockIterVar); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") - .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { - if (auto loop_rv = rv.as()) { - return self->AddUnitLoop(loop_rv.value()); - } else if (auto block_rv = rv.as()) { - return self->AddUnitLoop(block_rv.value()); - } else { - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() - << ". Its value is: " << rv; - throw; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleMerge", &ScheduleNode::Merge) + .def_method("tir.schedule.ScheduleFuse", &ScheduleNode::Fuse) + .def_method("tir.schedule.ScheduleSplit", &ScheduleNode::Split) + .def_method("tir.schedule.ScheduleLoopPartition", &ScheduleNode::LoopPartition) + .def_method("tir.schedule.ScheduleReorder", &ScheduleNode::Reorder) + .def_method("tir.schedule.ScheduleReorderBlockIterVar", &ScheduleNode::ReorderBlockIterVar) + .def("tir.schedule.ScheduleAddUnitLoop", [](Schedule self, ObjectRef rv) -> LoopRV { + if (auto loop_rv = rv.as()) { + return self->AddUnitLoop(loop_rv.value()); + } else if (auto block_rv = rv.as()) { + return self->AddUnitLoop(block_rv.value()); + } else { + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " + << rv->GetTypeKey() << ". 
Its value is: " << rv; + throw; + } + }); +}); /******** (FFI) Manipulate ForKind ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleParallel").set_body_method(&ScheduleNode::Parallel); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize").set_body_method(&ScheduleNode::Vectorize); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleParallel", &ScheduleNode::Parallel) + .def_method("tir.schedule.ScheduleVectorize", &ScheduleNode::Vectorize) + .def_method("tir.schedule.ScheduleBind", &ScheduleNode::Bind) + .def_method("tir.schedule.ScheduleUnroll", &ScheduleNode::Unroll); +}); /******** (FFI) Insert cache stages ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead").set_body_method(&ScheduleNode::CacheRead); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") - .set_body_method(&ScheduleNode::CacheWrite); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") - .set_body_method(&ScheduleNode::ReindexCacheRead); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") - .set_body_method(&ScheduleNode::ReindexCacheWrite); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") - .set_body_method(&ScheduleNode::CacheInplace); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") - .set_body_method(&ScheduleNode::CacheIndex); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") - .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type) { - return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleCacheRead", &ScheduleNode::CacheRead) + .def_method("tir.schedule.ScheduleCacheWrite", &ScheduleNode::CacheWrite) + .def_method("tir.schedule.ScheduleReindexCacheRead", &ScheduleNode::ReindexCacheRead) + .def_method("tir.schedule.ScheduleReindexCacheWrite", &ScheduleNode::ReindexCacheWrite) + .def_method("tir.schedule.ScheduleCacheInplace", &ScheduleNode::CacheInplace) + .def_method("tir.schedule.ScheduleCacheIndex", &ScheduleNode::CacheIndex) + .def("tir.schedule.ScheduleReIndex", + [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { + return self->ReIndex(block_rv, buffer_index, + static_cast(buffer_index_type)); + }); +}); /******** (FFI) Data movement ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt").set_body_method(&ScheduleNode::WriteAt); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleReadAt", &ScheduleNode::ReadAt) + .def_method("tir.schedule.ScheduleWriteAt", &ScheduleNode::WriteAt); +}); /******** (FFI) Compute location ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt").set_body_method(&ScheduleNode::ComputeAt); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt") - .set_body_method(&ScheduleNode::ReverseComputeAt); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") - .set_body_method(&ScheduleNode::ComputeInline); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") - 
.set_body_method(&ScheduleNode::ReverseComputeInline); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleComputeAt", &ScheduleNode::ComputeAt) + .def_method("tir.schedule.ScheduleReverseComputeAt", &ScheduleNode::ReverseComputeAt) + .def_method("tir.schedule.ScheduleComputeInline", &ScheduleNode::ComputeInline) + .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline); +}); /******** (FFI) Reduction ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction") - .set_body_method(&ScheduleNode::DecomposeReduction); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor").set_body_method(&ScheduleNode::RFactor); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleDecomposeReduction", &ScheduleNode::DecomposeReduction) + .def_method("tir.schedule.ScheduleRFactor", &ScheduleNode::RFactor); +}); /******** (FFI) Block annotation ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") - .set_body_method(&ScheduleNode::StorageAlign); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope").set_body_method(&ScheduleNode::SetScope); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") - .set_body_method(&ScheduleNode::UnsafeSetDType); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleStorageAlign", &ScheduleNode::StorageAlign) + .def_method("tir.schedule.ScheduleSetScope", &ScheduleNode::SetScope) + .def_method("tir.schedule.ScheduleUnsafeSetDType", &ScheduleNode::UnsafeSetDType); +}); /******** (FFI) Blockize & Tensorize ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") - .set_body_typed([](Schedule self, ObjectRef target, bool preserve_unit_iters) { - if (auto loop_rv = target.as()) { - return self->Blockize(loop_rv.value(), preserve_unit_iters); - } else if (auto blocks = target.as>()) { - return self->Blockize(blocks.value(), preserve_unit_iters); - } - LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") - .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { - if (auto block_rv = rv.as()) { - self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); - } else if (auto loop_rv = rv.as()) { - self->Tensorize(loop_rv.value(), intrin, preserve_unit_iters); - } else { - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() - << ". 
Its value is: " << rv; - } - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.ScheduleBlockize", + [](Schedule self, ObjectRef target, bool preserve_unit_iters) { + if (auto loop_rv = target.as()) { + return self->Blockize(loop_rv.value(), preserve_unit_iters); + } else if (auto blocks = target.as>()) { + return self->Blockize(blocks.value(), preserve_unit_iters); + } + LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); + }) + .def("tir.schedule.ScheduleTensorize", + [](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { + if (auto block_rv = rv.as()) { + self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); + } else if (auto loop_rv = rv.as()) { + self->Tensorize(loop_rv.value(), intrin, preserve_unit_iters); + } else { + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " + << rv->GetTypeKey() << ". Its value is: " << rv; + } + }); +}); /******** (FFI) Annotation ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") - .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, const Any& ann_val) { - if (auto block_rv = rv.as()) { - return self->Annotate(block_rv.value(), ann_key, ann_val); - } - if (auto loop_rv = rv.as()) { - return self->Annotate(loop_rv.value(), ann_key, ann_val); - } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() - << ". Its value is: " << rv; - throw; - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") - .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { - if (auto block_rv = rv.as()) { - return self->Unannotate(block_rv.value(), ann_key); - } - if (auto loop_rv = rv.as()) { - return self->Unannotate(loop_rv.value(), ann_key); - } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() - << ". Its value is: " << rv; - throw; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.ScheduleAnnotate", + [](Schedule self, ObjectRef rv, const String& ann_key, const Any& ann_val) { + if (auto block_rv = rv.as()) { + return self->Annotate(block_rv.value(), ann_key, ann_val); + } + if (auto loop_rv = rv.as()) { + return self->Annotate(loop_rv.value(), ann_key, ann_val); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " + << rv->GetTypeKey() << ". Its value is: " << rv; + throw; + }) + .def("tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, + const String& ann_key) { + if (auto block_rv = rv.as()) { + return self->Unannotate(block_rv.value(), ann_key); + } + if (auto loop_rv = rv.as()) { + return self->Unannotate(loop_rv.value(), ann_key); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". 
Its value is: " << rv; + throw; + }); +}); /******** (FFI) Layout transformation ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") - .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type, const IndexMap& index_map, - const Optional& pad_value, bool assume_injective_transform) { - return self->TransformLayout(block_rv, buffer_index, - static_cast(buffer_index_type), index_map, - pad_value, assume_injective_transform); - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") - .set_body_method(&ScheduleNode::TransformBlockLayout); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") - .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type, const Array& axis_separators) { - return self->SetAxisSeparator( - block_rv, buffer_index, static_cast(buffer_index_type), axis_separators); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.ScheduleTransformLayout", + [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, + const IndexMap& index_map, const Optional& pad_value, + bool assume_injective_transform) { + return self->TransformLayout(block_rv, buffer_index, + static_cast(buffer_index_type), + index_map, pad_value, assume_injective_transform); + }) + .def_method("tir.schedule.ScheduleTransformBlockLayout", &ScheduleNode::TransformBlockLayout) + .def("tir.schedule.ScheduleSetAxisSeparator", + [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, + const Array& axis_separators) { + return self->SetAxisSeparator(block_rv, buffer_index, + static_cast(buffer_index_type), + axis_separators); + }); +}); /******** (FFI) Padding decomposition ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") - .set_body_method(&ScheduleNode::DecomposePadding); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum").set_body_method(&ScheduleNode::PadEinsum); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleDecomposePadding", &ScheduleNode::DecomposePadding) + .def_method("tir.schedule.SchedulePadEinsum", &ScheduleNode::PadEinsum); +}); /******** (FFI) Buffer transformation ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") - .set_body_method(&ScheduleNode::RollingBuffer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("tir.schedule.ScheduleRollingBuffer", &ScheduleNode::RollingBuffer); +}); /******** (FFI) Misc ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") - .set_body_method(&ScheduleNode::EnterPostproc); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") - .set_body_method(&ScheduleNode::UnsafeHideBufferAccess); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_method("tir.schedule.ScheduleEnterPostproc", &ScheduleNode::EnterPostproc) + .def_method("tir.schedule.ScheduleUnsafeHideBufferAccess", + &ScheduleNode::UnsafeHideBufferAccess); +}); /******** (FFI) Annotate buffer access ********/ -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") - .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type, const IndexMap& index_map) { - return self->AnnotateBufferAccess(block_rv, buffer_index, - static_cast(buffer_index_type), 
index_map); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.schedule.ScheduleAnnotateBufferAccess", + [](Schedule self, const BlockRV& block_rv, int buffer_index, + int buffer_index_type, const IndexMap& index_map) { + return self->AnnotateBufferAccess( + block_rv, buffer_index, + static_cast(buffer_index_type), index_map); + }); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 7bda23f1df70..2d2390e60c21 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include "./utils.h" namespace tvm { @@ -1014,20 +1015,22 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleState") - .set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { - return ScheduleState(mod, debug_mask, enable_check); - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") - .set_body_method(&ScheduleStateNode::GetBlockScope); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") - .set_body_method(&ScheduleStateNode::Replace); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") - .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional { - auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.ScheduleState", + [](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { + return ScheduleState(mod, debug_mask, enable_check); + }) + .def_method("tir.schedule.ScheduleStateGetBlockScope", &ScheduleStateNode::GetBlockScope) + .def_method("tir.schedule.ScheduleStateReplace", &ScheduleStateNode::Replace) + .def("tir.schedule.ScheduleStateGetSRef", + [](ScheduleState self, Stmt stmt) -> Optional { + auto it = self->stmt2ref.find(stmt.get()); + return it != self->stmt2ref.end() ? 
it->second : Optional(std::nullopt); + }) + .def("tir.schedule.ScheduleStateGetCachedFlags", GetCachedFlags); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 574ee3aab625..605bc53d127a 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -567,29 +567,30 @@ TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TraceNode); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.Trace") - .set_body_typed([](Optional> insts, - Optional> decisions) { - return Trace(insts.value_or(Array()), decisions.value_or({})); - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceGetDecision").set_body_method(&TraceNode::GetDecision); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAppend") - .set_body_typed([](Trace self, Instruction inst, Optional decision) { - if (decision.defined()) { - return self->Append(inst, decision.value()); - } else { - return self->Append(inst); - } - }); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") - .set_body_method(&TraceNode::ApplyToSchedule); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceWithDecision").set_body_method(&TraceNode::WithDecision); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") - .set_body_typed(Trace::ApplyJSONToSchedule); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.schedule.Trace", + [](Optional> insts, Optional> decisions) { + return Trace(insts.value_or(Array()), decisions.value_or({})); + }) + .def_method("tir.schedule.TraceGetDecision", &TraceNode::GetDecision) + .def("tir.schedule.TraceAppend", + [](Trace self, Instruction inst, Optional decision) { + if (decision.defined()) { + return self->Append(inst, decision.value()); + } else { + return self->Append(inst); + } + }) + .def_method("tir.schedule.TracePop", &TraceNode::Pop) + .def_method("tir.schedule.TraceApplyToSchedule", &TraceNode::ApplyToSchedule) + .def_method("tir.schedule.TraceAsJSON", &TraceNode::AsJSON) + .def_method("tir.schedule.TraceAsPython", &TraceNode::AsPython) + .def_method("tir.schedule.TraceWithDecision", &TraceNode::WithDecision) + .def_method("tir.schedule.TraceSimplified", &TraceNode::Simplified) + .def("tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index fc284fec20e6..8e402194ecef 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -17,6 +17,8 @@ * under the License. 
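
The hunks above and below all converge on the same registration idiom: one TVM_FFI_STATIC_INIT_BLOCK per file, a single refl::GlobalDef() chain, and one .def/.def_method entry per exported global name. A minimal self-contained sketch of that idiom, with hypothetical names (example.AddOne, example.AddTwo) and an assumed include path for the reflection registry header:

#include <tvm/ffi/reflection/registry.h>  // assumed path of the registry header added by this patch

namespace {

// Ordinary typed callable exposed through the FFI.
int AddOne(int x) { return x + 1; }

}  // namespace

// One static-init block per translation unit, one GlobalDef() chain,
// one .def(...) per exported global, mirroring the hunks in this patch.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("example.AddOne", AddOne)                        // typed free function
      .def("example.AddTwo", [](int x) { return x + 2; });  // typed lambda
});

Member functions are bound the same way through .def_method("name", &SomeNode::Method), as the schedule bindings earlier in this patch show.
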
*/ +#include + #include "../transforms/ir_utils.h" #include "./utils.h" @@ -439,7 +441,10 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block return reorder_suffix[0]; } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.schedule.TileWithTensorIntrin", TileWithTensorIntrin); +}); /******** BlockBufferAccessSimplifier ********/ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_access_regions) { @@ -557,7 +562,10 @@ Optional NormalizePrimFunc(Schedule sch) { return Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; } -TVM_FFI_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.schedule.NormalizePrimFunc", NormalizePrimFunc); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index f8adcf4f5010..77e0c80713c6 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -22,6 +22,7 @@ * \brief Split device function from host. */ #include +#include #include #include #include @@ -74,8 +75,10 @@ Pass AnnotateDeviceRegions() { return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions") - .set_body_typed(AnnotateDeviceRegions); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.AnnotateDeviceRegions", AnnotateDeviceRegions); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/bind_target.cc b/src/tir/transforms/bind_target.cc index 7cb010aa9bc6..fc8e6ede2bbc 100644 --- a/src/tir/transforms/bind_target.cc +++ b/src/tir/transforms/bind_target.cc @@ -34,6 +34,7 @@ * with appropriate targets and updates call sites accordingly. 
*/ +#include #include #include #include @@ -370,7 +371,10 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreateModulePass(fpass, 0, "tir.BindTarget", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.BindTarget", BindTarget); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 15728e846224..b5d5e70feee9 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -255,8 +256,10 @@ Pass InstrumentBoundCheckers() { return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") - .set_body_typed(InstrumentBoundCheckers); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InstrumentBoundCheckers", InstrumentBoundCheckers); +}); } // namespace transform diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 38fe86a9e2ac..15e7dbcb408e 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -23,6 +23,7 @@ * \file combine_context_call.cc */ #include +#include #include #include #include @@ -112,7 +113,10 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.CombineContextCall", CombineContextCall); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 3fd78a523301..3f89f82d099b 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -30,6 +30,7 @@ #include "common_subexpr_elim.h" #include +#include #include #include // For the class Pass and the class PassContext #include // For the analysis which gives the size of an expr @@ -637,7 +638,10 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) { } // The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it -TVM_FFI_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.CommonSubexprElimTIR", CommonSubexprElimTIR); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index b75d73bc0f5f..3d781664681f 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -756,8 +757,10 @@ Pass CompactBufferAllocation(bool is_strict) { return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation") - .set_body_typed(CompactBufferAllocation); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + 
refl::GlobalDef().def("tir.transform.CompactBufferAllocation", CompactBufferAllocation); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index 09c2762efab5..434a9c495495 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -22,6 +22,7 @@ * \brief Convert the blocks to opaque blocks which do not have block vars. */ +#include #include #include @@ -122,8 +123,10 @@ Pass ConvertBlocksToOpaque() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque") - .set_body_typed(ConvertBlocksToOpaque); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ConvertBlocksToOpaque", ConvertBlocksToOpaque); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc index 4c992163df04..e4827bdc51b2 100644 --- a/src/tir/transforms/convert_for_loops_serial.cc +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -22,6 +22,7 @@ * \brief Convert all for loops to serial for lesser memory consumption */ #include +#include #include #include #include @@ -66,8 +67,10 @@ Pass ConvertForLoopsToSerial() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") - .set_body_typed(ConvertForLoopsToSerial); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ConvertForLoopsToSerial", ConvertForLoopsToSerial); +}); } // namespace transform diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index 3b382850559a..5267e76d4db8 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -21,6 +21,7 @@ * \file decorate_device_scope.cc */ #include +#include #include #include #include @@ -44,7 +45,10 @@ Pass DecorateDeviceScope() { return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.DecorateDeviceScope", DecorateDeviceScope); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 398b00092d08..2ac1b1047bd0 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -17,6 +17,8 @@ * under the License. 
*/ +#include + #include "../../meta_schedule/utils.h" namespace tvm { @@ -162,7 +164,10 @@ Pass DefaultGPUSchedule() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.DefaultGPUSchedule", DefaultGPUSchedule); +}); } // namespace transform diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 509efb8d06fd..3d1cb15d7554 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -26,6 +26,7 @@ */ #include #include +#include #include #include @@ -105,8 +106,10 @@ tvm::transform::Pass ExtractPrimFuncConstants() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants") - .set_body_typed(ExtractPrimFuncConstants); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ExtractPrimFuncConstants", ExtractPrimFuncConstants); +}); } // namespace transform diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index a0c39c8fcb68..8a9470093f62 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -22,6 +22,7 @@ */ #include +#include #include #include #include @@ -279,7 +280,10 @@ Pass FlattenBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.FlattenBuffer", FlattenBuffer); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc index bd33e564e5c2..2e2833e5ca8e 100644 --- a/src/tir/transforms/force_narrow_index_to_i32.cc +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -23,6 +23,7 @@ * \note This pass is not used in default cases. 
*/ +#include #include #include #include @@ -86,8 +87,10 @@ Pass ForceNarrowIndexToInt32() { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32") - .set_body_typed(ForceNarrowIndexToInt32); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ForceNarrowIndexToInt32", ForceNarrowIndexToInt32); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 326dbde39de3..57a724727178 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -563,7 +564,10 @@ Pass HoistExpression() { "tir.HoistExpression"); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.HoistExpression", HoistExpression); +}); Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -598,7 +602,10 @@ Pass HoistIfThenElse() { "tir.HoistIfThenElse"); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.HoistIfThenElse", HoistIfThenElse); +}); Pass HoistIfThenElseBasic() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -618,7 +625,10 @@ Pass HoistIfThenElseBasic() { "tir.HoistIfThenElseBasic"); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.HoistIfThenElseBasic", HoistIfThenElseBasic); +}); } // namespace transform diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 8c4e526e5175..c28bb8933a5b 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -22,6 +22,7 @@ * \file inject_double_buffer.cc */ #include +#include #include #include #include @@ -327,7 +328,10 @@ Pass InjectDoubleBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectDoubleBuffer", InjectDoubleBuffer); +}); } // namespace transform diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index 00e29061ba3a..26aee259359f 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -22,6 +22,7 @@ * \brief The pass injects permuted layout for shared memory buffers to avoid bank conflicts. 
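
On the consumer side, a name defined in any of these blocks is resolved through the global function table once static initialization has run. A hedged caller-side sketch; the function.h include path and the optional-style return of GetGlobal are assumptions, not something this patch defines:

#include <tvm/ffi/function.h>  // assumed header for tvm::ffi::Function

// Hypothetical check that a global registered above resolved correctly.
void CheckConvertBlocksToOpaqueRegistered() {
  // GetGlobal is assumed to return an optional-like handle to the ffi::Function.
  if (auto fn = tvm::ffi::Function::GetGlobal("tir.transform.ConvertBlocksToOpaque")) {
    auto pass = (*fn)();  // this particular pass constructor takes no arguments
    (void)pass;           // silence unused-variable warnings in the sketch
  }
}
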
*/ #include +#include #include #include #include @@ -295,7 +296,10 @@ Pass InjectPermutedLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPermutedLayout").set_body_typed(InjectPermutedLayout); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectPermutedLayout", InjectPermutedLayout); +}); } // namespace transform diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 04bcecac36b0..75c8e2ec46ed 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -21,6 +21,7 @@ * \brief Replace copy from global to shared with async copy * \file inject_ptx_async_copy.cc */ +#include #include #include #include @@ -199,7 +200,10 @@ Pass InjectPTXAsyncCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy); +}); } // namespace transform diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index c3a6cf50b828..37a3e933232a 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -123,7 +124,10 @@ Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) { // The pass can now be invoked via the pass infrastructure, but we also add a // Python binding for it -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectPTXLDG32", InjectPTXLDG32); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index ed35bdb0655f..a1d84456eb0a 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -35,6 +35,7 @@ */ #include #include +#include #include #include @@ -315,7 +316,10 @@ Pass InjectRollingBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectRollingBuffer", InjectRollingBuffer); +}); } // namespace transform diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index d5f69315b149..0a79bc63cd6b 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -21,6 +21,7 @@ * \file inject_software_pipeline.cc * \brief Transform annotated loops into pipelined one that parallelize producers and consumers */ +#include #include #include #include @@ -1259,8 +1260,10 @@ Pass InjectSoftwarePipeline() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline") - .set_body_typed(InjectSoftwarePipeline); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectSoftwarePipeline", 
InjectSoftwarePipeline); +}); } // namespace transform diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index d54b7d3953cd..12b332feb916 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -21,6 +21,7 @@ * \file inject_virtual_thread.cc */ #include +#include #include #include #include @@ -523,7 +524,10 @@ Pass InjectVirtualThread() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InjectVirtualThread", InjectVirtualThread); +}); } // namespace transform diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index eae2e29ef686..89ee9e7c7db3 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -22,6 +22,7 @@ * \brief Inline private functions to their callsite */ #include +#include #include #include #include @@ -292,8 +293,10 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions") - .set_body_typed(InlinePrivateFunctions); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InlinePrivateFunctions", InlinePrivateFunctions); +}); } // namespace transform diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 72ee656d1bac..37ea799df950 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -850,7 +851,10 @@ Pass ConvertSSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertSSA").set_body_typed(ConvertSSA); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ConvertSSA", ConvertSSA); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index b30a47c84fe9..f9493abbf1ea 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -22,6 +22,7 @@ * \brief Convert the blocks to opaque blocks which do not have block vars. 
*/ +#include #include #include @@ -183,7 +184,10 @@ Pass LiftThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LiftThreadBinding", LiftThreadBinding); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 324c5819b36b..5a55d4149898 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -819,7 +820,10 @@ Pass LoopPartition() { return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LoopPartition", LoopPartition); +}); } // namespace transform diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index c3358e1c9207..bbde22e4800e 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -175,7 +176,10 @@ Pass LowerAsyncDMA() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerAsyncDMA", LowerAsyncDMA); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index a10937a2b7c9..7054bc77fa06 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -21,6 +21,7 @@ * \file lower_cross_thread_reduction.cc */ #include +#include #include #include #include @@ -935,8 +936,10 @@ Pass LowerCrossThreadReduction() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") - .set_body_typed(LowerCrossThreadReduction); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerCrossThreadReduction", LowerCrossThreadReduction); +}); } // namespace transform diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index e0863b865d15..b651c39bb899 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -22,6 +22,7 @@ */ #include +#include #include #include #include @@ -250,7 +251,10 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerCustomDatatypes", LowerCustomDatatypes); +}); } // namespace transform diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index c32c0c3debf3..1774d97f64f1 100644 --- 
a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -22,6 +22,7 @@ * \brief Split device function from host. */ #include +#include #include #include #include @@ -369,8 +370,10 @@ Pass LowerDeviceKernelLaunch() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") - .set_body_typed(LowerDeviceKernelLaunch); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index fe5e0389676d..7e0849100d1c 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -130,8 +131,10 @@ Pass LowerDeviceStorageAccessInfo() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") - .set_body_typed(LowerDeviceStorageAccessInfo); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerDeviceStorageAccessInfo", LowerDeviceStorageAccessInfo); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index 03188fb6c907..188d856f9fb8 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -21,6 +21,7 @@ * Lower block init stmt into branch stmt * \file lower_reduction.cc */ +#include #include #include #include @@ -79,7 +80,10 @@ Pass LowerInitBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerInitBlock", LowerInitBlock); +}); } // namespace transform diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 8fe9bedce9f0..dab0e672a60b 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -22,6 +22,7 @@ * \file lower_intrin.cc */ #include +#include #include #include #include @@ -394,7 +395,10 @@ Pass LowerIntrin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerIntrin", LowerIntrin); +}); } // namespace transform diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index e0cb7cf80acc..ea1055896ade 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -267,7 +268,10 @@ Pass LowerMatchBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerMatchBuffer", 
LowerMatchBuffer); +}); } // namespace transform diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 64de12263c3e..d3fe46b5d1e8 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -21,6 +21,7 @@ * \file lower_opaque_block.cc */ +#include #include #include @@ -214,7 +215,10 @@ Pass LowerOpaqueBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerOpaqueBlock", LowerOpaqueBlock); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 81023d5471f3..ef9e9f3b87b6 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -809,7 +810,10 @@ Pass LowerThreadAllreduce() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerThreadAllreduce", LowerThreadAllreduce); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 095bd321c937..d364092ec9e4 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -22,6 +22,7 @@ * \file tir/transforms/lower_tvm_buildin.cc */ #include +#include #include #include #include @@ -673,7 +674,10 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerTVMBuiltin", LowerTVMBuiltin); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_vtcm_alloc.cc b/src/tir/transforms/lower_vtcm_alloc.cc index eac2a21b4917..f5be570aa312 100644 --- a/src/tir/transforms/lower_vtcm_alloc.cc +++ b/src/tir/transforms/lower_vtcm_alloc.cc @@ -17,6 +17,7 @@ * under the License. 
*/ +#include #include #include #include @@ -72,7 +73,10 @@ Pass LowerVtcmAlloc() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerVtcmAlloc", LowerVtcmAlloc); +}); } // namespace transform diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 0cf6f9d152d1..e2bef4328622 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -461,7 +462,10 @@ Pass LowerWarpMemory() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerWarpMemory", LowerWarpMemory); +}); } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 48cf7bbad1f5..258db2afad10 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -21,6 +21,7 @@ * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ #include +#include #include #include #include @@ -438,8 +439,9 @@ Pass MakePackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { - return MakePackedAPI(); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.MakePackedAPI", []() { return MakePackedAPI(); }); }); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 989b39f0e370..3357dc35c0e1 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -21,6 +21,7 @@ * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API. */ #include +#include #include #include #include @@ -200,7 +201,10 @@ Pass MakeUnpackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.MakeUnpackedAPI", MakeUnpackedAPI); +}); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 7c9fe2b9aacd..c4405844d9e0 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -27,6 +27,7 @@ * of requiring buffer access to be contiguous in each dimension. 
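
One wrinkle visible in the make_packed_api.cc hunk above: MakePackedAPI is registered through a zero-argument lambda rather than a bare function pointer. A common reason for such wrapping is that overloads and default arguments are not part of a function pointer's type, so the lambda pins the exported signature. A short hypothetical sketch of the same choice (Scale and example.ScaleByTwo are made-up names; the header path is assumed):

#include <tvm/ffi/reflection/registry.h>  // assumed registry header, as above

namespace {

// A default argument is not part of a function pointer's type, so exporting
// Scale directly would force callers to pass both arguments explicitly.
int Scale(int x, int factor = 2) { return x * factor; }

}  // namespace

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      // The lambda fixes the exported signature to a single int parameter.
      .def("example.ScaleByTwo", [](int x) { return Scale(x); });
});
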
*/ #include +#include #include #include #include @@ -275,8 +276,11 @@ Pass ManifestSharedMemoryLocalStage() { return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") - .set_body_typed(ManifestSharedMemoryLocalStage); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ManifestSharedMemoryLocalStage", + ManifestSharedMemoryLocalStage); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 334a44df069c..64526b1715c3 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -776,7 +777,10 @@ Pass LowerAutoCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.LowerAutoCopy", LowerAutoCopy); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index eaf0aab391ec..09bf8b113c45 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -24,6 +24,7 @@ * allocation. */ #include +#include #include #include #include @@ -695,8 +696,10 @@ Pass MergeSharedMemoryAllocations() { return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations") - .set_body_typed(MergeSharedMemoryAllocations); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.MergeSharedMemoryAllocations", MergeSharedMemoryAllocations); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 8183b2fd8f45..768e24e2696a 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -320,7 +321,10 @@ Pass NarrowDataType(int target_bits) { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.NarrowDataType", NarrowDataType); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index f4547a57581a..fea7ae9d4eaf 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -22,6 +22,7 @@ * \file plan_update_buffer_allocation_location.cc */ +#include #include #include #include @@ -257,8 +258,11 @@ Pass PlanAndUpdateBufferAllocationLocation() { return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation") - .set_body_typed(PlanAndUpdateBufferAllocationLocation); +TVM_FFI_STATIC_INIT_BLOCK({ 
+ namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.PlanAndUpdateBufferAllocationLocation", + PlanAndUpdateBufferAllocationLocation); +}); } // namespace transform diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 00751e3a9ad8..3d12a849d029 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -22,6 +22,7 @@ * \brief Passes that serve as helper functions. */ +#include #include namespace tvm { @@ -78,8 +79,12 @@ transform::Pass Filter(ffi::TypedFunction fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); -TVM_FFI_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc) + .def("tir.transform.Filter", Filter); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/profile_instrumentation.cc b/src/tir/transforms/profile_instrumentation.cc index 4061a2abf53f..7dd85cf04d65 100644 --- a/src/tir/transforms/profile_instrumentation.cc +++ b/src/tir/transforms/profile_instrumentation.cc @@ -24,6 +24,7 @@ // these instruction can be replaced with a call to a target specific handler // and can be used to capture profiling information such as processor cycles. +#include #include #include #include @@ -283,8 +284,10 @@ Pass InstrumentProfileIntrinsics() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics") - .set_body_typed(InstrumentProfileIntrinsics); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InstrumentProfileIntrinsics", InstrumentProfileIntrinsics); +}); } // namespace transform diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc index 5015d2418a47..b2ec1a464d37 100644 --- a/src/tir/transforms/reduce_branching_through_overcompute.cc +++ b/src/tir/transforms/reduce_branching_through_overcompute.cc @@ -24,6 +24,7 @@ * extra computations that do not impact the final results. 
*/ +#include #include #include @@ -176,8 +177,11 @@ Pass ReduceBranchingThroughOvercompute() { return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute") - .set_body_typed(ReduceBranchingThroughOvercompute); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ReduceBranchingThroughOvercompute", + ReduceBranchingThroughOvercompute); +}); } // namespace transform diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index 6afaa0c61583..2a5cb5859df5 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -21,6 +21,7 @@ * \file remap_thread_axis.cc */ #include +#include #include #include #include @@ -103,7 +104,10 @@ Pass RemapThreadAxis(Map thread_map) { return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RemapThreadAxis", RemapThreadAxis); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc index ce7176e8cc46..1be33fd63145 100644 --- a/src/tir/transforms/remove_assume.cc +++ b/src/tir/transforms/remove_assume.cc @@ -22,6 +22,7 @@ * \brief Remove stores of tir::builtin::undef */ #include +#include #include #include #include @@ -61,7 +62,10 @@ Pass RemoveAssume() { return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume"); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RemoveAssume", RemoveAssume); +}); } // namespace transform diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index a67b2bf17878..b410f9697de6 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -333,7 +334,10 @@ Pass RemoveNoOp() { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RemoveNoOp", RemoveNoOp); +}); } // namespace transform diff --git a/src/tir/transforms/remove_store_undef.cc b/src/tir/transforms/remove_store_undef.cc index 31b4a558c600..25ede67af953 100644 --- a/src/tir/transforms/remove_store_undef.cc +++ b/src/tir/transforms/remove_store_undef.cc @@ -22,6 +22,7 @@ * \brief Remove stores of tir::builtin::undef */ #include +#include #include #include #include @@ -171,7 +172,10 @@ Pass RemoveStoreUndef() { "tir.RemoveStoreUndef"); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RemoveStoreUndef", RemoveStoreUndef); +}); } // namespace transform diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 0ca4262fc119..d0dda28e4adf 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ 
b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -22,6 +22,7 @@ * \brief Remove weight layout rewrite block before benchmark */ +#include #include #include #include @@ -285,8 +286,11 @@ Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock") - .set_body_typed(RemoveWeightLayoutRewriteBlock); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RemoveWeightLayoutRewriteBlock", + RemoveWeightLayoutRewriteBlock); +}); } // namespace transform diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index cd1517b11c2a..8acd8508e7c9 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -22,6 +22,7 @@ * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. */ +#include #include #include @@ -290,7 +291,10 @@ class RenewDefMutator : public StmtExprMutator { PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } -TVM_FFI_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.RenewDefs", RenewDefs); +}); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc index 0fb24c62500a..c79d33eb4b05 100644 --- a/src/tir/transforms/renormalize_split_pattern.cc +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -22,6 +22,7 @@ * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) */ #include +#include #include #include #include @@ -205,8 +206,10 @@ Pass RenormalizeSplitPattern() { return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") - .set_body_typed(RenormalizeSplitPattern); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RenormalizeSplitPattern", RenormalizeSplitPattern); +}); } // namespace transform diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 624e2d9921a9..ceb117ad8ed6 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -22,6 +22,7 @@ * \brief Rewrite uinsafe select expression. 
*/ #include +#include #include #include #include @@ -139,7 +140,10 @@ Pass RewriteUnsafeSelect() { return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.RewriteUnsafeSelect", RewriteUnsafeSelect); +}); } // namespace transform diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 1d8ff9dab0c8..e0a3ce7a1a2a 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -363,7 +364,10 @@ Pass Simplify() { return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.Simplify", Simplify); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index 98aea3da99d5..f4ae81b75451 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -47,7 +48,10 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.SkipAssert", SkipAssert); +}); } // namespace transform diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 160d80b4dab3..57188ec7659b 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,6 +22,7 @@ * \brief Split device function from host. 
*/ #include +#include #include #include #include @@ -165,7 +166,10 @@ Pass SplitHostDevice() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.SplitHostDevice", SplitHostDevice); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index bd3afe1d3f84..e6bc94a3d7ba 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include #include @@ -1762,7 +1763,10 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.StorageRewrite", StorageRewrite); +}); Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -1771,8 +1775,10 @@ Pass PointerValueTypeRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") - .set_body_typed(PointerValueTypeRewrite); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.PointerValueTypeRewrite", PointerValueTypeRewrite); +}); } // namespace transform diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 3c6a6fc9be86..82b573b6d41f 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -22,6 +22,7 @@ * \file tensorcore_fragment.cc */ #include +#include #include #include #include @@ -217,7 +218,10 @@ Pass InferFragment() { return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.InferFragment", InferFragment); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 34878d1b333d..6f46dbeb8211 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -21,6 +21,7 @@ * \file thread_storage_sync.cc */ #include +#include #include #include #include @@ -471,7 +472,10 @@ Pass ThreadSync(String storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.ThreadSync", ThreadSync); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index 797caae31100..dc81f56d7567 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -186,8 +187,10 @@ Pass TransformMmaBufferLayout() { return CreatePrimFuncPass(pass_func, 0, 
"tir.TransformMmaBufferLayout", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.TransformMmaBufferLayout") - .set_body_typed(TransformMmaBufferLayout); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.TransformMmaBufferLayout", TransformMmaBufferLayout); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index c83f00d25b82..e6b71a99f02c 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -22,6 +22,7 @@ */ #include +#include #include #include #include @@ -200,7 +201,10 @@ Pass UnifyThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.UnifyThreadBinding", UnifyThreadBinding); +}); } // namespace transform diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 035925b8080c..70375c319191 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -24,6 +24,7 @@ // Unrolls the loop as in Halide pipeline. #include #include +#include #include #include #include @@ -292,7 +293,10 @@ Pass UnrollLoop() { return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.UnrollLoop", UnrollLoop); +}); } // namespace transform diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 8ee1656b3fe4..f25c51160cde 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -22,6 +22,7 @@ * \brief legalize bf16/fp8 type by adding cast_to_fp32 */ #include +#include #include #include #include @@ -758,7 +759,10 @@ Pass BF16ComputeLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.BF16ComputeLegalize", BF16ComputeLegalize); +}); Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -771,7 +775,10 @@ Pass BF16StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize); +}); Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -784,7 +791,10 @@ Pass FP8ComputeLegalize(String promote_dtype_str) { return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.FP8ComputeLegalize", FP8ComputeLegalize); +}); Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, 
IRModule m, PassContext ctx) { @@ -797,7 +807,10 @@ Pass FP8StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.FP8StorageLegalize", FP8StorageLegalize); +}); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index a1195cfef81f..9d3feb846693 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -35,6 +35,7 @@ * 4. This pass currently works for op_pattern kElemWise and kBroadcast. */ +#include #include #include #include @@ -381,8 +382,10 @@ Pass UseAssumeToReduceBranches() { return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") - .set_body_typed(UseAssumeToReduceBranches); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.UseAssumeToReduceBranches", UseAssumeToReduceBranches); +}); } // namespace transform diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index cfe0145d9278..e6637073ec95 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -23,6 +23,7 @@ // Loop vectorizer as in Halide pipeline. #include #include +#include #include #include #include @@ -1024,7 +1025,10 @@ Pass VectorizeLoop(bool enable_vectorize) { return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } -TVM_FFI_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.VectorizeLoop", VectorizeLoop); +}); } // namespace transform diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 1ee85e7b8c95..db819f4bf766 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -22,6 +22,7 @@ * \file broadcast.cc */ #include +#include #include #include @@ -72,10 +73,12 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_FFI_REGISTER_GLOBAL("topi.broadcast_to") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = broadcast_to(args[0].cast(), args[1].cast>()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.broadcast_to", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = broadcast_to(args[0].cast(), args[1].cast>()); + }); +}); } // namespace topi } // namespace tvm diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 40c8332ab725..01adc1744ea7 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -21,6 +21,7 @@ * \file topi/einsum.cc * \brief Einstein summation op */ +#include #include #include @@ -355,8 +356,11 @@ Array InferEinsumShape(const std::string& subscripts, return einsum_builder.InferShape(); } -TVM_FFI_REGISTER_GLOBAL("topi.einsum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = einsum(args[0].cast(), args[1].cast>()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.einsum", [](ffi::PackedArgs args, 
ffi::Any* rv) { + *rv = einsum(args[0].cast(), args[1].cast>()); + }); }); } // namespace topi diff --git a/src/topi/elemwise.cc b/src/topi/elemwise.cc index 13947abcf604..47fe13bfb44e 100644 --- a/src/topi/elemwise.cc +++ b/src/topi/elemwise.cc @@ -22,6 +22,7 @@ * \file elemwise.cc */ #include +#include #include namespace tvm { @@ -30,141 +31,94 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_REGISTER_GLOBAL("topi.acos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = acos(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.acosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = acosh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.asin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = asin(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.asinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = asinh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.atanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = atanh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = exp(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = fast_exp(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = erf(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = fast_erf(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.tan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = tan(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.cos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = cos(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.cosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = cosh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.sin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sin(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.sinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sinh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = tanh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = fast_tanh(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.atan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = atan(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sigmoid(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sqrt(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = rsqrt(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.log").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = log(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.log2").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = log2(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.log10").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = log10(args[0].cast()); -}); - 
-TVM_FFI_REGISTER_GLOBAL("topi.identity").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = identity(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.negative").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = negative(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.clip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = clip(args[0].cast(), args[1].cast(), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.cast").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = cast(args[0].cast(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = reinterpret(args[0].cast(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.elemwise_sum") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = elemwise_sum(args[0].cast>()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.sign").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sign(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.full").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = full(args[0].cast>(), args[1].cast(), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.full_like").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = full_like(args[0].cast(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = logical_not(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = bitwise_not(args[0].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.acos", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = acos(args[0].cast()); }) + .def_packed("topi.acosh", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = acosh(args[0].cast()); }) + .def_packed("topi.asin", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = asin(args[0].cast()); }) + .def_packed("topi.asinh", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = asinh(args[0].cast()); }) + .def_packed("topi.atanh", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = atanh(args[0].cast()); }) + .def_packed("topi.exp", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = exp(args[0].cast()); }) + .def_packed("topi.fast_exp", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = fast_exp(args[0].cast()); }) + .def_packed("topi.erf", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = erf(args[0].cast()); }) + .def_packed("topi.fast_erf", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = fast_erf(args[0].cast()); }) + .def_packed("topi.tan", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = tan(args[0].cast()); }) + .def_packed("topi.cos", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = cos(args[0].cast()); }) + .def_packed("topi.cosh", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = cosh(args[0].cast()); }) + .def_packed("topi.sin", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = sin(args[0].cast()); }) + .def_packed("topi.sinh", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = sinh(args[0].cast()); }) + .def_packed("topi.tanh", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = tanh(args[0].cast()); }) + .def_packed( + "topi.fast_tanh", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_tanh(args[0].cast()); }) + .def_packed("topi.atan", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = atan(args[0].cast()); }) + .def_packed("topi.sigmoid", [](ffi::PackedArgs args, + ffi::Any* rv) { 
*rv = sigmoid(args[0].cast()); }) + .def_packed("topi.sqrt", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = sqrt(args[0].cast()); }) + .def_packed("topi.rsqrt", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = rsqrt(args[0].cast()); }) + .def_packed("topi.log", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = log(args[0].cast()); }) + .def_packed("topi.log2", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = log2(args[0].cast()); }) + .def_packed("topi.log10", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = log10(args[0].cast()); }) + .def_packed("topi.identity", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = identity(args[0].cast()); }) + .def_packed("topi.negative", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = negative(args[0].cast()); }) + .def_packed("topi.clip", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = clip(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.cast", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = cast(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.reinterpret", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = reinterpret(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.elemwise_sum", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = elemwise_sum(args[0].cast>()); + }) + .def_packed("topi.sign", [](ffi::PackedArgs args, + ffi::Any* rv) { *rv = sign(args[0].cast()); }) + .def_packed("topi.full", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = full(args[0].cast>(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.full_like", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = full_like(args[0].cast(), args[1].cast()); + }) + .def_packed( + "topi.logical_not", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = logical_not(args[0].cast()); }) + .def_packed("topi.bitwise_not", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = bitwise_not(args[0].cast()); + }); }); } // namespace topi diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 4b2095a53868..a658785bd691 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -22,6 +22,7 @@ * \file nn.cc */ #include +#include #include #include #include @@ -44,191 +45,232 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = relu(args[0].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.leaky_relu") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = leaky_relu(args[0].cast(), args[1].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = prelu(args[0].cast(), args[1].cast(), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = pad(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.space_to_batch_nd") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = space_to_batch_nd(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.batch_to_space_nd") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = batch_to_space_nd(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nll_loss(args[0].cast(), args[1].cast(), args[2].cast(), - args[3].cast(), args[4].cast()); 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed( + "topi.nn.relu", + [](ffi::PackedArgs args, ffi::Any* rv) { *rv = relu(args[0].cast()); }) + .def_packed("topi.nn.leaky_relu", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = leaky_relu(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.nn.prelu", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = prelu(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.nn.pad", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = pad(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast()); + }) + .def_packed("topi.nn.space_to_batch_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = space_to_batch_nd( + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast()); + }) + .def_packed("topi.nn.batch_to_space_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = batch_to_space_nd( + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast()); + }) + .def_packed("topi.nn.nll_loss", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = + nll_loss(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast(), args[4].cast()); + }); }); /* Ops from nn/dense.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::dense(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.dense", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::dense(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast()); + }); }); /* Ops from nn/bias_add.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::bias_add(args[0].cast(), args[1].cast(), args[2].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.bias_add", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::bias_add(args[0].cast(), args[1].cast(), args[2].cast()); + }); }); /* Ops from nn/dilate.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::dilate(args[0].cast(), args[1].cast>(), - args[2].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.dilate", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::dilate(args[0].cast(), args[1].cast>(), + args[2].cast()); + }); }); /* Ops from nn/flatten.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::flatten(args[0].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.flatten", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::flatten(args[0].cast()); + }); }); /* Ops from nn/mapping.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::scale_shift_nchw(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - -/* Ops from nn/pooling.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.pool_grad") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = 
nn::pool_grad(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), - static_cast(args[5].cast()), args[6].cast(), - args[7].cast(), args[8].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.global_pool") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::global_pool(args[0].cast(), - static_cast(args[1].cast()), - args[2].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool1d") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::adaptive_pool1d(args[0].cast(), args[1].cast>(), - static_cast(args[2].cast()), - args[3].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::adaptive_pool(args[0].cast(), args[1].cast>(), - static_cast(args[2].cast()), - args[3].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::adaptive_pool3d(args[0].cast(), args[1].cast>(), - static_cast(args[2].cast()), - args[3].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::pool1d(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), static_cast(args[5].cast()), - args[6].cast(), args[7].cast(), args[8].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::pool2d(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), static_cast(args[5].cast()), - args[6].cast(), args[7].cast(), args[8].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.nn.scale_shift_nchw", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = + nn::scale_shift_nchw(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.nn.scale_shift_nhwc", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), + args[2].cast()); + }); }); -TVM_FFI_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::pool3d(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), static_cast(args[5].cast()), - args[6].cast(), args[7].cast(), args[8].cast()); +/* Ops from nn/pooling.h */ +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.nn.pool_grad", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::pool_grad( + args[0].cast(), args[1].cast(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), + static_cast(args[5].cast()), args[6].cast(), + args[7].cast(), args[8].cast()); + }) + .def_packed("topi.nn.global_pool", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::global_pool(args[0].cast(), + static_cast(args[1].cast()), + args[2].cast()); + }) + .def_packed("topi.nn.adaptive_pool1d", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::adaptive_pool1d(args[0].cast(), + args[1].cast>(), + static_cast(args[2].cast()), + args[3].cast()); + }) + .def_packed("topi.nn.adaptive_pool", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::adaptive_pool(args[0].cast(), + args[1].cast>(), + static_cast(args[2].cast()), + args[3].cast()); + }) + .def_packed("topi.nn.adaptive_pool3d", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::adaptive_pool3d(args[0].cast(), + args[1].cast>(), + static_cast(args[2].cast()), + 
args[3].cast()); + }) + .def_packed("topi.nn.pool1d", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::pool1d( + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), + static_cast(args[5].cast()), args[6].cast(), + args[7].cast(), args[8].cast()); + }) + .def_packed("topi.nn.pool2d", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::pool2d( + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), + static_cast(args[5].cast()), args[6].cast(), + args[7].cast(), args[8].cast()); + }) + .def_packed("topi.nn.pool3d", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::pool3d(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), + static_cast(args[5].cast()), args[6].cast(), + args[7].cast(), args[8].cast()); + }); }); /* Ops from nn/softmax.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::softmax(args[0].cast(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.log_softmax") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::log_softmax(args[0].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::lrn(args[0].cast(), args[1].cast(), args[2].cast(), - args[3].cast(), args[4].cast(), args[5].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.nn.softmax", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::softmax(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.nn.log_softmax", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::log_softmax(args[0].cast()); + }) + .def_packed("topi.nn.lrn", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::lrn(args[0].cast(), args[1].cast(), args[2].cast(), + args[3].cast(), args[4].cast(), args[5].cast()); + }); }); /* Ops from nn/bnn.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.binarize_pack") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::binarize_pack(args[0].cast(), args[1].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.nn.binary_dense") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::binary_dense(args[0].cast(), args[1].cast()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.nn.binarize_pack", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::binarize_pack(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.nn.binary_dense", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::binary_dense(args[0].cast(), args[1].cast()); + }); +}); /* Ops from nn/layer_norm.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.layer_norm") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::layer_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), - args[4].cast()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.layer_norm", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::layer_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast>(), + args[4].cast()); + }); +}); /* Ops from nn/group_norm.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.group_norm") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::group_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast(), args[4].cast(), - args[5].cast>(), args[6].cast()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + 
namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.group_norm", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::group_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast(), args[4].cast(), + args[5].cast>(), args[6].cast()); + }); +}); /* Ops from nn/instance_norm.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.instance_norm") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::instance_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast(), - args[4].cast>(), args[5].cast()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.instance_norm", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::instance_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast(), + args[4].cast>(), args[5].cast()); + }); +}); /* Ops from nn/rms_norm.h */ -TVM_FFI_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::rms_norm(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.nn.rms_norm", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::rms_norm(args[0].cast(), args[1].cast(), + args[2].cast>(), args[3].cast()); + }); }); } // namespace topi diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index f8920bdefd46..42988ae762d1 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -22,6 +22,7 @@ * \file reduction.cc */ #include +#include #include #include @@ -31,44 +32,53 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_REGISTER_GLOBAL("topi.sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::sum(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.sum", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::sum(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast()); + }) + .def_packed("topi.min", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::min(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast()); + }) + .def_packed("topi.max", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::max(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast()); + }) + .def_packed("topi.argmin", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::argmin(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast(), false, args[3].cast()); + }) + .def_packed("topi.argmax", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::argmax(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast(), false, args[3].cast()); + }) + .def_packed("topi.prod", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::prod(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast()); + }) + .def_packed("topi.all", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::all(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast()); + }) + .def_packed("topi.any", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::any(args[0].cast(), ArrayOrInt(args[1]), + args[2].cast()); + }) + .def_packed("topi.collapse_sum", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); + }); }); -TVM_FFI_REGISTER_GLOBAL("topi.min").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::min(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); -}); - 
-TVM_FFI_REGISTER_GLOBAL("topi.max").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::max(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.argmin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::argmin(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, - args[3].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.argmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::argmax(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, - args[3].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.prod").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::prod(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.all").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::all(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.any").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::any(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.collapse_sum") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); - }); - } // namespace topi } // namespace tvm diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 5826fdac864f..62e9783b4dc9 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -22,6 +22,7 @@ * \file transform.cc */ #include +#include #include #include #include @@ -36,226 +37,237 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = expand_dims(args[0].cast(), args[1].cast(), args[2].cast()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.expand_dims", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = expand_dims(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.transpose", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = transpose(args[0].cast(), + args[1].cast>>()); + }) + .def_packed("topi.flip", + [](ffi::PackedArgs args, ffi::Any* rv) { + // pass empty seq_lengths tensor to reverse_sequence + *rv = + reverse_sequence(args[0].cast(), Tensor(), args[1].cast()); + }) + .def_packed("topi.reverse_sequence", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = reverse_sequence(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.reshape", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = reshape(args[0].cast(), args[1].cast>()); + }) + .def_packed("topi.sliding_window", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = sliding_window(args[0].cast(), args[1].cast(), + args[2].cast>(), + args[3].cast>()); + }) + .def_packed("topi.squeeze", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = squeeze(args[0].cast(), ArrayOrInt(args[1])); + }) + .def_packed("topi.concatenate", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = concatenate(args[0].cast>(), args[1].cast()); + }) + .def_packed("topi.stack", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = stack(args[0].cast>(), args[1].cast()); + }) + .def_packed("topi.shape", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = shape(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.ndarray_size", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ndarray_size(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.split", + 
[](ffi::PackedArgs args, ffi::Any* rv) { + if (args[1].try_cast()) { + *rv = split_n_sections(args[0].cast(), args[1].cast(), + args[2].cast()); + } else { + *rv = + split_indices_array(args[0].cast(), + args[1].cast>(), args[2].cast()); + } + }) + .def_packed("topi.layout_transform", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = + layout_transform(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast()); + }) + .def_packed( + "topi.take", + [](ffi::PackedArgs args, ffi::Any* rv) { + if (args.size() == 4) { + auto mode = args[3].cast(); + int batch_dims = args[2].cast(); + *rv = take(args[0].cast(), args[1].cast(), batch_dims, mode); + } else { + ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments"; + int batch_dims = args[2].cast(); + int axis = args[3].cast(); + auto mode = args[4].cast(); + *rv = + take(args[0].cast(), + args[1].cast>(), batch_dims, axis, mode); + } + }) + .def_packed("topi.sequence_mask", + [](ffi::PackedArgs args, ffi::Any* rv) { + double pad_val = args[2].cast(); + int axis = args[3].cast(); + *rv = sequence_mask(args[0].cast(), args[1].cast(), + pad_val, axis); + }) + .def_packed("topi.where", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = where(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.arange", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = arange(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast()); + }) + .def_packed("topi.meshgrid", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = meshgrid(args[0].cast>(), args[1].cast()); + }) + .def_packed("topi.repeat", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = repeat(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.tile", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = tile(args[0].cast(), args[1].cast>()); + }) + .def_packed("topi.gather", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = gather(args[0].cast(), args[1].cast(), + args[2].cast()); + }) + .def_packed("topi.gather_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + int batch_dims = args[2].cast(); + *rv = gather_nd(args[0].cast(), args[1].cast(), + batch_dims); + }) + .def_packed("topi.unravel_index", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = unravel_index(args[0].cast(), args[1].cast()); + }) + .def_packed("topi.sparse_to_dense", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = + sparse_to_dense(args[0].cast(), args[1].cast>(), + args[2].cast(), args[3].cast()); + }) + .def_packed("topi.matmul", + [](ffi::PackedArgs args, ffi::Any* rv) { + switch (args.size()) { + case 2: + *rv = matmul(args[0].cast(), args[1].cast()); + break; + case 3: + *rv = matmul(args[0].cast(), args[1].cast(), + args[2].cast()); + break; + case 4: + *rv = matmul(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast()); + break; + default: + ICHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; + } + }) + .def_packed("topi.tensordot", + [](ffi::PackedArgs args, ffi::Any* rv) { + if (args.size() == 2) { + *rv = tensordot(args[0].cast(), args[1].cast()); + } else if (args.size() == 3) { + *rv = tensordot(args[0].cast(), args[1].cast(), + args[2].cast()); + } else { + Array axes = args[3].cast>(); + *rv = tensordot(args[0].cast(), args[1].cast(), + args[2].cast>(), axes); + } + }) + .def_packed( + "topi.strided_slice", + [](ffi::PackedArgs args, ffi::Any* rv) { + Tensor x = args[0].cast(); + Array begin = args[1].cast>(); + Array end = args[2].cast>(); + Array strides = args[3].cast>(); + Array axes = args[4].cast>(); + bool assume_inbound = 
args[6].cast(); + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && + IsConstIntArray(x->shape)) { + Array begin_static = args[1].cast>(); + Array end_static = args[2].cast>(); + Array strides_static = args[3].cast>(); + auto slice_mode = args[5].cast(); + if (axes.size()) { + *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, + slice_mode); + } else { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } + } else { + if (axes.size()) { + *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); + } else { + *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); + } + } + }) + .def_packed("topi.dynamic_strided_slice", + [](ffi::PackedArgs args, ffi::Any* rv) { + te::Tensor begin = args[1].cast(); + te::Tensor end = args[2].cast(); + te::Tensor strides = args[3].cast(); + *rv = dynamic_strided_slice(args[0].cast(), begin, end, strides); + }) + .def("topi.relax_dynamic_strided_slice", + [](te::Tensor x, te::Tensor begin, te::Tensor end, te::Tensor strides, + Array output_shape) { + return relax::dynamic_strided_slice(x, begin, end, strides, output_shape); + }) + .def_packed("topi.one_hot", + [](ffi::PackedArgs args, ffi::Any* rv) { + int depth = args[3].cast(); + int axis = args[4].cast(); + DataType dtype = args[5].cast(); + *rv = one_hot(args[0].cast(), args[1].cast(), + args[2].cast(), depth, axis, dtype); + }) + .def_packed("topi.matrix_set_diag", + [](ffi::PackedArgs args, ffi::Any* rv) { + int k1 = args[2].cast(); + int k2 = args[3].cast(); + bool super_diag_right_align = args[4].cast(); + bool sub_diag_right_align = args[5].cast(); + *rv = matrix_set_diag(args[0].cast(), args[1].cast(), + k1, k2, super_diag_right_align, sub_diag_right_align); + }) + .def("topi.adv_index", + [](te::Tensor x, Array indices) { return adv_index(x, indices); }); }); -TVM_FFI_REGISTER_GLOBAL("topi.transpose").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = transpose(args[0].cast(), args[1].cast>>()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.flip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - // pass empty seq_lengths tensor to reverse_sequence - *rv = reverse_sequence(args[0].cast(), Tensor(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.reverse_sequence") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = reverse_sequence(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.reshape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = reshape(args[0].cast(), args[1].cast>()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.sliding_window") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sliding_window(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = squeeze(args[0].cast(), ArrayOrInt(args[1])); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = concatenate(args[0].cast>(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.stack").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = stack(args[0].cast>(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.shape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = shape(args[0].cast(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.ndarray_size") - 
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = ndarray_size(args[0].cast(), args[1].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args[1].try_cast()) { - *rv = split_n_sections(args[0].cast(), args[1].cast(), args[2].cast()); - } else { - *rv = split_indices_array(args[0].cast(), args[1].cast>(), - args[2].cast()); - } -}); - -TVM_FFI_REGISTER_GLOBAL("topi.layout_transform") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = layout_transform(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() == 4) { - auto mode = args[3].cast(); - int batch_dims = args[2].cast(); - *rv = take(args[0].cast(), args[1].cast(), batch_dims, mode); - } else { - ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments"; - int batch_dims = args[2].cast(); - int axis = args[3].cast(); - auto mode = args[4].cast(); - *rv = take(args[0].cast(), args[1].cast>(), - batch_dims, axis, mode); - } -}); - -TVM_FFI_REGISTER_GLOBAL("topi.sequence_mask") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - double pad_val = args[2].cast(); - int axis = args[3].cast(); - *rv = sequence_mask(args[0].cast(), args[1].cast(), pad_val, axis); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.where").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = where(args[0].cast(), args[1].cast(), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.arange").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = arange(args[0].cast(), args[1].cast(), args[2].cast(), - args[3].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = meshgrid(args[0].cast>(), args[1].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.repeat").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = repeat(args[0].cast(), args[1].cast(), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.tile").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = tile(args[0].cast(), args[1].cast>()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.gather").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = gather(args[0].cast(), args[1].cast(), args[2].cast()); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int batch_dims = args[2].cast(); - *rv = gather_nd(args[0].cast(), args[1].cast(), batch_dims); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.unravel_index") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = unravel_index(args[0].cast(), args[1].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.sparse_to_dense") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sparse_to_dense(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - switch (args.size()) { - case 2: - *rv = matmul(args[0].cast(), args[1].cast()); - break; - case 3: - *rv = matmul(args[0].cast(), args[1].cast(), args[2].cast()); - break; - case 4: - *rv = matmul(args[0].cast(), args[1].cast(), args[2].cast(), - args[3].cast()); - break; - default: - ICHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; - } -}); - -TVM_FFI_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() == 2) { - *rv = 
tensordot(args[0].cast(), args[1].cast()); - } else if (args.size() == 3) { - *rv = tensordot(args[0].cast(), args[1].cast(), args[2].cast()); - } else { - Array axes = args[3].cast>(); - *rv = tensordot(args[0].cast(), args[1].cast(), - args[2].cast>(), axes); - } -}); - -TVM_FFI_REGISTER_GLOBAL("topi.strided_slice") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Tensor x = args[0].cast(); - Array begin = args[1].cast>(); - Array end = args[2].cast>(); - Array strides = args[3].cast>(); - Array axes = args[4].cast>(); - bool assume_inbound = args[6].cast(); - if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && - IsConstIntArray(x->shape)) { - Array begin_static = args[1].cast>(); - Array end_static = args[2].cast>(); - Array strides_static = args[3].cast>(); - auto slice_mode = args[5].cast(); - if (axes.size()) { - *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, - slice_mode); - } else { - *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); - } - } else { - if (axes.size()) { - *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); - } else { - *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); - } - } - }); - -TVM_FFI_REGISTER_GLOBAL("topi.dynamic_strided_slice") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - te::Tensor begin = args[1].cast(); - te::Tensor end = args[2].cast(); - te::Tensor strides = args[3].cast(); - *rv = dynamic_strided_slice(args[0].cast(), begin, end, strides); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice") - .set_body_typed([](te::Tensor x, te::Tensor begin, te::Tensor end, te::Tensor strides, - Array output_shape) { - return relax::dynamic_strided_slice(x, begin, end, strides, output_shape); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int depth = args[3].cast(); - int axis = args[4].cast(); - DataType dtype = args[5].cast(); - *rv = one_hot(args[0].cast(), args[1].cast(), args[2].cast(), - depth, axis, dtype); -}); - -TVM_FFI_REGISTER_GLOBAL("topi.matrix_set_diag") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int k1 = args[2].cast(); - int k2 = args[3].cast(); - bool super_diag_right_align = args[4].cast(); - bool sub_diag_right_align = args[5].cast(); - *rv = matrix_set_diag(args[0].cast(), args[1].cast(), k1, k2, - super_diag_right_align, sub_diag_right_align); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.adv_index") - .set_body_typed([](te::Tensor x, Array indices) { return adv_index(x, indices); }); - } // namespace topi } // namespace tvm diff --git a/src/topi/utils.cc b/src/topi/utils.cc index 9a668ad2ac17..de4b9d0a54f5 100644 --- a/src/topi/utils.cc +++ b/src/topi/utils.cc @@ -23,28 +23,30 @@ */ #include +#include #include namespace tvm { namespace topi { -TVM_FFI_REGISTER_GLOBAL("topi.utils.is_empty_shape") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::detail::is_empty_shape(args[0].cast>()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = - detail::bilinear_sample_nchw(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); - }); - -TVM_FFI_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = - detail::bilinear_sample_nhwc(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); - }); 
+TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("topi.utils.is_empty_shape", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::detail::is_empty_shape(args[0].cast>()); + }) + .def_packed("topi.utils.bilinear_sample_nchw", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = detail::bilinear_sample_nchw( + args[0].cast(), args[1].cast>(), + args[2].cast(), args[3].cast()); + }) + .def_packed("topi.utils.bilinear_sample_nhwc", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = detail::bilinear_sample_nhwc(args[0].cast(), + args[1].cast>(), + args[2].cast(), args[3].cast()); + }); +}); } // namespace topi } // namespace tvm diff --git a/src/topi/vision.cc b/src/topi/vision.cc index 57d936268010..36388076e914 100644 --- a/src/topi/vision.cc +++ b/src/topi/vision.cc @@ -22,6 +22,7 @@ * \file vision.cc */ #include +#include #include namespace tvm { @@ -30,10 +31,12 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_REGISTER_GLOBAL("topi.vision.reorg") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = vision::reorg(args[0].cast(), args[1].cast()); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("topi.vision.reorg", [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = vision::reorg(args[0].cast(), args[1].cast()); + }); +}); } // namespace topi } // namespace tvm diff --git a/tests/cpp-runtime/hexagon/run_all_tests.cc b/tests/cpp-runtime/hexagon/run_all_tests.cc index cf8160971a51..f84dd74dbfba 100644 --- a/tests/cpp-runtime/hexagon/run_all_tests.cc +++ b/tests/cpp-runtime/hexagon/run_all_tests.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -37,31 +38,33 @@ namespace tvm { namespace runtime { namespace hexagon { -TVM_FFI_REGISTER_GLOBAL("hexagon.run_all_tests") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - // gtest args are passed into this packed func as a singular string - // split gtest args using delimiter and build argument vector - std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); - std::vector argv; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("hexagon.run_all_tests", [](ffi::PackedArgs args, ffi::Any* rv) { + // gtest args are passed into this packed func as a singular string + // split gtest args using delimiter and build argument vector + std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); + std::vector argv; - // add executable name - argv.push_back(const_cast("hexagon_run_all_tests")); + // add executable name + argv.push_back(const_cast("hexagon_run_all_tests")); - // add parsed arguments - for (int i = 0; i < parsed_args.size(); ++i) { - argv.push_back(const_cast(parsed_args[i].data())); - } + // add parsed arguments + for (int i = 0; i < parsed_args.size(); ++i) { + argv.push_back(const_cast(parsed_args[i].data())); + } - // end of parsed arguments - argv.push_back(nullptr); + // end of parsed arguments + argv.push_back(nullptr); - // set argument count - int argc = argv.size() - 1; + // set argument count + int argc = argv.size() - 1; - // initialize gtest with arguments and run - ::testing::InitGoogleTest(&argc, argv.data()); - *rv = RUN_ALL_TESTS(); - }); + // initialize gtest with arguments and run + ::testing::InitGoogleTest(&argc, argv.data()); + *rv = RUN_ALL_TESTS(); + }); +}); } // namespace hexagon } // namespace runtime diff --git a/tests/cpp-runtime/hexagon/run_unit_tests.cc 
b/tests/cpp-runtime/hexagon/run_unit_tests.cc index a4c613b41140..ebcbd49e9734 100644 --- a/tests/cpp-runtime/hexagon/run_unit_tests.cc +++ b/tests/cpp-runtime/hexagon/run_unit_tests.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -79,43 +80,45 @@ class GtestPrinter : public testing::EmptyTestEventListener { std::string GetOutput() { return gtest_out_.str(); } }; -TVM_FFI_REGISTER_GLOBAL("hexagon.run_unit_tests") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - // gtest args are passed into this packed func as a singular string - // split gtest args using delimiter and build argument vector - std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); - std::vector argv; - - // add executable name - argv.push_back(const_cast("hexagon_run_unit_tests")); - - // add parsed arguments - for (int i = 0; i < parsed_args.size(); ++i) { - argv.push_back(const_cast(parsed_args[i].data())); - } - - // end of parsed arguments - argv.push_back(nullptr); - - // set argument count - int argc = argv.size() - 1; - - // initialize gtest with arguments and run - ::testing::InitGoogleTest(&argc, argv.data()); - - // add printer to capture gtest output in a string - GtestPrinter* gprinter = new GtestPrinter(); - testing::TestEventListeners& listeners = testing::UnitTest::GetInstance()->listeners(); - listeners.Append(gprinter); - - int gtest_error_code = RUN_ALL_TESTS(); - std::string gtest_output = gprinter->GetOutput(); - std::stringstream gtest_error_code_and_output; - gtest_error_code_and_output << gtest_error_code << std::endl; - gtest_error_code_and_output << gtest_output; - *rv = gtest_error_code_and_output.str(); - delete gprinter; - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("hexagon.run_unit_tests", [](ffi::PackedArgs args, ffi::Any* rv) { + // gtest args are passed into this packed func as a singular string + // split gtest args using delimiter and build argument vector + std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); + std::vector argv; + + // add executable name + argv.push_back(const_cast("hexagon_run_unit_tests")); + + // add parsed arguments + for (int i = 0; i < parsed_args.size(); ++i) { + argv.push_back(const_cast(parsed_args[i].data())); + } + + // end of parsed arguments + argv.push_back(nullptr); + + // set argument count + int argc = argv.size() - 1; + + // initialize gtest with arguments and run + ::testing::InitGoogleTest(&argc, argv.data()); + + // add printer to capture gtest output in a string + GtestPrinter* gprinter = new GtestPrinter(); + testing::TestEventListeners& listeners = testing::UnitTest::GetInstance()->listeners(); + listeners.Append(gprinter); + + int gtest_error_code = RUN_ALL_TESTS(); + std::string gtest_output = gprinter->GetOutput(); + std::stringstream gtest_error_code_and_output; + gtest_error_code_and_output << gtest_error_code << std::endl; + gtest_error_code_and_output << gtest_output; + *rv = gtest_error_code_and_output.str(); + delete gprinter; + }); +}); } // namespace hexagon } // namespace runtime diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh index 22e583377576..70b3c5b4b968 100755 --- a/tests/lint/git-clang-format.sh +++ b/tests/lint/git-clang-format.sh @@ -77,16 +77,16 @@ ${CLANG_FORMAT} --version if [[ "$INPLACE_FORMAT" == "true" ]]; then echo "Running inplace git-clang-format against $REVISION" - git-${CLANG_FORMAT} --extensions h,mm,c,cc --binary=${CLANG_FORMAT} "$REVISION" + git-${CLANG_FORMAT} 
--extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" exit 0 fi if [[ "$LINT_ALL_FILES" == "true" ]]; then echo "Running git-clang-format against all C++ files" - git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc --binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt + git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt else echo "Running git-clang-format against $REVISION" - git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc --binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt + git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt fi if grep --quiet -E "diff" < /tmp/$$.clang-format.txt; then diff --git a/tests/python/relax/test_vm_callback_function.py b/tests/python/relax/test_vm_callback_function.py index 1cee0b57d801..c8f3f2945ede 100644 --- a/tests/python/relax/test_vm_callback_function.py +++ b/tests/python/relax/test_vm_callback_function.py @@ -100,6 +100,7 @@ def relax_func( ) vm = tvm.relax.VirtualMachine(ex, dev) + # custom callback that raises an error in python def custom_callback(): local_var = 42 raise RuntimeError("Error thrown from callback") diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 922b25b0d74b..8fecaad19d1a 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -32,6 +32,7 @@ #define DMLC_USE_LOGGING_LIBRARY #include +#include #include #include "../../src/runtime/rpc/rpc_local_session.h" @@ -301,8 +302,11 @@ class AsyncLocalSession : public LocalSession { } }; -TVM_FFI_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { - return CreateRPCSessionModule(std::make_shared()); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("wasm.LocalSession", []() { + return CreateRPCSessionModule(std::make_shared()); + }); }); } // namespace runtime diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 3c1c15a86123..22c38effef65 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -30,6 +30,7 @@ #define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY +#include #include #include "src/runtime/contrib/sort/sort.cc" @@ -104,23 +105,21 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.call") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - (args[0].cast()).CallPacked(args.Slice(1), ret); - }); - -TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.log_info_str") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - LOG(INFO) << args[0].cast(); - }); - -TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.add_one").set_body_typed([](int x) { return x + 1; }); - -TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.wrap_callback") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ffi::Function pf = args[0].cast(); - *ret = ffi::TypedFunction([pf]() { pf(); }); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvmjs.testing.call", + [](ffi::PackedArgs args, ffi::Any* ret) { + (args[0].cast()).CallPacked(args.Slice(1), ret); + }) + .def_packed("tvmjs.testing.log_info_str", + [](ffi::PackedArgs args, ffi::Any* ret) { LOG(INFO) << args[0].cast(); }) + .def("tvmjs.testing.add_one", [](int x) { return x + 1; }) + .def_packed("tvmjs.testing.wrap_callback", [](ffi::PackedArgs args, ffi::Any* ret) { + ffi::Function pf = args[0].cast(); + *ret = 
ffi::TypedFunction([pf]() { pf(); }); + }); +}); void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) { if (format == "f32-to-bf16" && dtype == "float32") { @@ -143,23 +142,29 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, } } -TVM_FFI_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tvmjs.array.decode_storage", ArrayDecodeStorage); +}); // Concatenate n TVMArrays -TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - // Get i-th TVMArray - auto* arr_i = args[i].as(); - ICHECK(arr_i != nullptr); - for (size_t j = 0; j < arr_i->size(); ++j) { - // Push back each j-th element of the i-th array - data.push_back(arr_i->at(j)); - } - } - *ret = Array(data); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("tvmjs.runtime.ArrayConcat", + [](ffi::PackedArgs args, ffi::Any* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + // Get i-th TVMArray + auto* arr_i = args[i].as(); + ICHECK(arr_i != nullptr); + for (size_t j = 0; j < arr_i->size(); ++j) { + // Push back each j-th element of the i-th array + data.push_back(arr_i->at(j)); + } + } + *ret = Array(data); + }); +}); NDArray ConcatEmbeddings(const std::vector& embeddings) { // Get output shape @@ -196,29 +201,28 @@ NDArray ConcatEmbeddings(const std::vector& embeddings) { } // Concatenate n NDArrays -TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - std::vector embeddings; - for (int i = 0; i < args.size(); ++i) { - embeddings.push_back(args[i].cast()); - } - NDArray result = ConcatEmbeddings(std::move(embeddings)); - *ret = result; - }); - -TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyFromBytes") - .set_body_typed([](NDArray nd, TVMFFIByteArray* bytes) { - nd.CopyFromBytes(bytes->data, bytes->size); - }); - -TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyToBytes") - .set_body_typed([](NDArray nd) -> ffi::Bytes { - size_t size = GetDataSize(*(nd.operator->())); - std::string bytes; - bytes.resize(size); - nd.CopyToBytes(bytes.data(), size); - return ffi::Bytes(bytes); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("tvmjs.runtime.ConcatEmbeddings", + [](ffi::PackedArgs args, ffi::Any* ret) { + std::vector embeddings; + for (int i = 0; i < args.size(); ++i) { + embeddings.push_back(args[i].cast()); + } + NDArray result = ConcatEmbeddings(std::move(embeddings)); + *ret = result; + }) + .def("tvmjs.runtime.NDArrayCopyFromBytes", + [](NDArray nd, TVMFFIByteArray* bytes) { nd.CopyFromBytes(bytes->data, bytes->size); }) + .def("tvmjs.runtime.NDArrayCopyToBytes", [](NDArray nd) -> ffi::Bytes { + size_t size = GetDataSize(*(nd.operator->())); + std::string bytes; + bytes.resize(size); + nd.CopyToBytes(bytes.data(), size); + return ffi::Bytes(bytes); + }); +}); } // namespace runtime } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 00b1db266a0b..72ea65ca858d 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -30,6 +30,7 @@ #define DMLC_USE_LOGGING_LIBRARY #include +#include #include #include @@ -242,13 +243,15 @@ Module 
WebGPUModuleLoadBinary(void* strm) { } // for now webgpu is hosted via a vulkan module. -TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); - -TVM_FFI_REGISTER_GLOBAL("device_api.webgpu") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = WebGPUDeviceAPI::Global(); - *rv = static_cast<void*>(ptr); - }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.module.loadbinary_webgpu", WebGPUModuleLoadBinary) + .def_packed("device_api.webgpu", [](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = WebGPUDeviceAPI::Global(); + *rv = static_cast<void*>(ptr); + }); +}); } // namespace runtime } // namespace tvm
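
Note for reviewers: every hunk above applies the same mechanical migration, replacing standalone TVM_FFI_REGISTER_GLOBAL(...).set_body_typed(...) / .set_body_packed(...) calls with a grouped TVM_FFI_STATIC_INIT_BLOCK that chains registrations off tvm::ffi::reflection::GlobalDef. The condensed sketch below shows the before/after shape once, outside any real file; the header path and the "demo.*" global names are illustrative assumptions, not identifiers taken from this patch.

// Sketch only: assumes the reflection registry header is <tvm/ffi/reflection/registry.h>;
// "demo.add_one" / "demo.add_one_packed" are placeholder global names, not part of this patch.
#include <tvm/ffi/reflection/registry.h>

// Old pattern (removed throughout this diff):
//   TVM_FFI_REGISTER_GLOBAL("demo.add_one").set_body_typed([](int x) { return x + 1; });

// New pattern: registrations are grouped in one static-init block and chained off a
// single GlobalDef(); .def(...) takes a typed callable, while .def_packed(...) uses the
// packed (PackedArgs / Any*) calling convention.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("demo.add_one", [](int x) { return x + 1; })
      .def_packed("demo.add_one_packed", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) {
        // Same behavior as the typed version, written against the packed convention.
        *rv = args[0].cast<int>() + 1;
      });
});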