11#include " simple_binary.hh"
22#include " common.h"
33#include " computation/operators/simple_binary.h"
4+ #include " kernel/collectors/simple_binary.h"
5+ #include " runtime/resource.h"
6+ #include < execution>
47
58namespace refactor ::onnx {
69 using Op = SimpleBinary;
@@ -10,7 +13,7 @@ namespace refactor::onnx {
1013 : Operator(), type(type_) {}
1114
1215 auto Op::build (ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
13- auto fmod = attributes.getOrInsert ( " fmod" , {0 }).int_ ();
16+ auto fmod = attributes.getOrInsert (" fmod" , {0 }).int_ ();
1417 // clang-format off
1518 auto type =
1619 opType == " onnx::Add" ? Ty::Add :
@@ -93,30 +96,6 @@ namespace refactor::onnx {
9396 // clang-format on
9497 }
9598
96- template <decltype (DataType::internal) T>
97- void calculate (Ty ty, void *dst, void const *a, void const *b) {
98- using T_ = typename primitive<T>::type;
99- auto a_ = *reinterpret_cast <T_ const *>(a);
100- auto b_ = *reinterpret_cast <T_ const *>(b);
101- auto dst_ = reinterpret_cast <T_ *>(dst);
102- switch (ty) {
103- case Ty::Add:
104- *dst_ = a_ + b_;
105- break ;
106- case Ty::Sub:
107- *dst_ = a_ - b_;
108- break ;
109- case Ty::Mul:
110- *dst_ = a_ * b_;
111- break ;
112- case Ty::Div:
113- *dst_ = a_ / b_;
114- break ;
115- default :
116- UNREACHABLE ();
117- }
118- }
119-
12099 auto Op::infer (TensorRefs inputs, InferOptions const &options) const -> InferResult {
121100 EXPECT_SIZE (2 )
122101
@@ -139,35 +118,36 @@ namespace refactor::onnx {
139118 return Ok (Tensors{std::move (ans)});
140119 }
141120
142- auto eleSize = dataType. size ();
143- auto dst = reinterpret_cast < uint8_t *>(ans-> malloc ()) ;
144- for ( auto i : range0_ (ans-> elementsSize ())) {
145- auto indices = locateN (ans-> shape , i) ;
146- auto a_ = locate1 (a, indices),
147- b_ = locate1 (b, indices );
148- auto dst_ = dst + i * eleSize ;
149- // -------------------------------------
150- # define CASE ( T ) \
151- case DataType::T: \
152- calculate<DataType::T>(type, dst_, a_, b_); \
153- break
154- // -------------------------------------
155- switch (dataType. internal ) {
156- CASE (F32 );
157- CASE (F64 );
158- CASE (I32);
159- CASE (I64);
160- CASE (I8 );
161- CASE (I16 );
162- CASE (U8) ;
163- CASE (U16 );
164- CASE (U32 );
165- CASE (U64) ;
166- default :
167- ans-> free () ;
168- break ;
169- }
121+ {
122+ using Shape = kernel::Shape ;
123+ using Tensor = kernel::Tensor;
124+ using LayoutType = kernel::LayoutType ;
125+
126+ Shape t1Shape (a. shape . size (), 1 );
127+ Shape t2Shape (b. shape . size (), 1 ) ;
128+ Shape oShape (ans-> shape . size (), 1 );
129+ std::transform (std::execution::unseq,
130+ a. shape . begin (), a. shape . end (), t1Shape. begin (),
131+ []( auto const &i) { return static_cast < dim_t >(i. value ()); });
132+ std::transform (std::execution::unseq,
133+ b. shape . begin (), b. shape . end (), t2Shape. begin (),
134+ []( auto const &i ) { return static_cast < dim_t >(i. value ()); });
135+ auto t1 = Tensor::share (a. dataType , t1Shape, LayoutType::Others, a. data );
136+ auto t2 = Tensor::share (b. dataType , t2Shape, LayoutType::Others, b. data );
137+ std::transform (std::execution::unseq,
138+ ans-> shape . begin (), ans-> shape . end (), oShape. begin (),
139+ []( auto const &i) { return static_cast < dim_t >(i. value ()); } );
140+ auto o = Tensor::share (a. dataType , oShape, LayoutType::Others );
141+ runtime::Resources res ;
142+ auto type_ = static_cast <kernel::SimpleBinaryType>(type );
143+ const auto collector = kernel::SimpleBinaryCollector (computation::Target::Cpu, type_ );
144+ auto routine = std::move (collector. filter ({*t1, *t2}, {*o}). at ( 0 ))-> lower (res). routine ;
145+ void const *inputsCpu[]{*t1-> data , *t2-> data };
146+ void *outputsCpu[]{o-> malloc ()} ;
147+ routine (res, nullptr , inputsCpu, outputsCpu) ;
148+ ans-> data = o-> data ;
170149 }
150+
171151 return Ok (Tensors{std::move (ans)});
172152 }
173153
0 commit comments