Skip to content

Commit d076c20

Browse files
authored
Merge pull request #88 from InfiniTensor/fix-torchvision-models
Fix for applying torchvision models
2 parents 27a8ad6 + b3e89a6 commit d076c20

File tree

6 files changed

+123
-93
lines changed

6 files changed

+123
-93
lines changed

src/04kernel/src/attributes/transpose_info.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace refactor::kernel {
3535
}
3636
}
3737
}
38-
if (rank == 0) {
38+
if (rank <= 1) {
3939
dims = {{1, 1}};
4040
blockSize *= blockCount;
4141
blockCount = 1;
@@ -73,6 +73,12 @@ namespace refactor::kernel {
7373
}
7474
perm.resize(rank);
7575
}
76+
if (rank <= 1) {
77+
dims = {{1, 1}};
78+
blockSize *= blockCount;
79+
blockCount = 1;
80+
return;
81+
}
7682
// 合并末尾连续访存
7783
if (perm.back() == rank - 1) {
7884
blockSize *= shape.back();
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "kernel/attributes/transpose_info.h"
2+
#include <gtest/gtest.h>
3+
4+
using namespace refactor;
5+
using namespace kernel;
6+
7+
TEST(kernel, TransposeInfo) {
8+
{
9+
TransposeInfo info(
10+
DataType::F32,
11+
{1, 2, 3, 2, 1},
12+
{1, 2, 3, 0, 4});
13+
EXPECT_EQ(info.blockSize, 48);
14+
EXPECT_EQ(info.blockCount, 1);
15+
EXPECT_EQ(info.dims.size(), 1);
16+
}
17+
{
18+
TransposeInfo info(
19+
DataType::F32,
20+
{1, 1, 2, 1, 1},
21+
{1, 2, 3, 0, 4});
22+
EXPECT_EQ(info.blockSize, 8);
23+
EXPECT_EQ(info.blockCount, 1);
24+
EXPECT_EQ(info.dims.size(), 1);
25+
}
26+
{
27+
TransposeInfo info(
28+
DataType::F32,
29+
{1, 2, 3, 4, 5},
30+
{2, 3, 1, 0, 4});
31+
EXPECT_EQ(info.blockSize, 20);
32+
EXPECT_EQ(info.blockCount, 24);
33+
EXPECT_EQ(info.dims.size(), 2);
34+
EXPECT_EQ(info.dims[1].strideI, 12);
35+
EXPECT_EQ(info.dims[1].strideO, 1);
36+
EXPECT_EQ(info.dims[0].strideI, 1);
37+
EXPECT_EQ(info.dims[0].strideO, 2);
38+
}
39+
}

src/06frontend/src/graph.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ namespace refactor::frontend {
102102
for (auto i : range0_(inputs.size())) {
103103
auto j = inputs[i];
104104
auto const &input = _internal.edges[j].tensor;
105-
ASSERT(input, "The {}th input of \"{}\" is nullptr", i, _internal.nodes[nodeIdx].name);
105+
ASSERT(input, "The input[{}] of \"{}\" is nullptr", i, _internal.nodes[nodeIdx].name);
106106
auto checked = edgeChanged[2 * j]; // NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
107107
auto changed = edgeChanged[2 * j + 1];// NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
108108
if (!checked) {

src/07onnx/src/operators/gather.cc

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "computation/operators/gather.h"
22
#include "common.h"
33
#include "gather.hh"
4+
#include "kernel/collectors/gather.h"
5+
#include "runtime/resource.h"
46
#include <execution>
57

68
namespace refactor::onnx {
@@ -42,41 +44,34 @@ namespace refactor::onnx {
4244
if (!options.shouldCalculate(inputs, {*ans})) {
4345
return Ok(Tensors{std::move(ans)});
4446
}
47+
{
48+
using Shape = kernel::Shape;
49+
using Tensor = kernel::Tensor;
50+
using LayoutType = kernel::LayoutType;
4551

46-
std::for_each_n(std::execution::unseq, natural_t(0), ans->elementsSize(),
47-
[&data, &indices, &output,
48-
axis_,
49-
q = indices.shape.size(),
50-
ssz = output.size(),
51-
src = data.data->get<uint8_t>(),
52-
dst = reinterpret_cast<uint8_t *>(ans->malloc()),
53-
eleSize = data.dataType.size()](auto const i) {
54-
auto indices_ = locateN(output, i);
55-
int64_t k;
56-
{
57-
size_t ii = 0, mul = 1;
58-
for (auto j : range0_(q).rev()) {
59-
ii += indices_[j] * mul;
60-
mul *= indices.shape[j].value();
61-
}
62-
k = indices.dataType == DataType::I64
63-
? indices.data->get<int64_t>()[ii]
64-
: indices.data->get<int32_t>()[ii];
65-
}
66-
{
67-
size_t ii = 0, mul = 1;
68-
for (auto j : range(static_cast<decltype(q)>(axis_) + q, ssz).rev()) {
69-
ii += indices_[j] * mul;
70-
mul *= data.shape[j - q + 1].value();
71-
}
72-
ii += k * mul;
73-
for (auto j : range0_(axis_).rev()) {
74-
ii += indices_[j] * mul;
75-
mul *= data.shape[j].value();
76-
}
77-
std::memcpy(dst + i * eleSize, src + ii * eleSize, eleSize);
78-
}
79-
});
52+
Shape t1Shape(data.shape.size(), 1);
53+
Shape t2Shape(indices.shape.size(), 1);
54+
Shape oShape(ans->shape.size(), 1);
55+
std::transform(std::execution::unseq,
56+
data.shape.begin(), data.shape.end(), t1Shape.begin(),
57+
[](auto const &i) { return static_cast<dim_t>(i.value()); });
58+
std::transform(std::execution::unseq,
59+
indices.shape.begin(), indices.shape.end(), t2Shape.begin(),
60+
[](auto const &i) { return static_cast<dim_t>(i.value()); });
61+
auto t1 = Tensor::share(data.dataType, t1Shape, LayoutType::Others, data.data);
62+
auto t2 = Tensor::share(indices.dataType, t2Shape, LayoutType::Others, indices.data);
63+
std::transform(std::execution::unseq,
64+
ans->shape.begin(), ans->shape.end(), oShape.begin(),
65+
[](auto const &i) { return static_cast<dim_t>(i.value()); });
66+
auto o = Tensor::share(data.dataType, oShape, LayoutType::Others);
67+
runtime::Resources res;
68+
const auto collector = kernel::GatherCollector(computation::Target::Cpu, axis_);
69+
auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
70+
void const *inputsCpu[]{*t1->data, *t2->data};
71+
void *outputsCpu[]{o->malloc()};
72+
routine(res, nullptr, inputsCpu, outputsCpu);
73+
ans->data = o->data;
74+
}
8075

8176
return Ok(Tensors{std::move(ans)});
8277
}

src/07onnx/src/operators/reduce.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,23 @@ namespace refactor::onnx {
2020

2121
auto noopWithEmptyAxes = false;
2222
decltype(Op::axes) axes = std::nullopt;
23-
if (opsetVer >= 18) {
24-
noopWithEmptyAxes = attributes.getOrInsert( "noop_with_empty_axes", {0}).int_() != 0;
23+
24+
// 针对ReduceSum做特判
25+
if (opType == "onnx::ReduceSum") {
26+
if (opsetVer >= 13) {
27+
noopWithEmptyAxes = attributes.getOrInsert("noop_with_empty_axes", {0}).int_() != 0;
28+
} else {
29+
axes.emplace(attributes.getOrInsert("axes", {{}}).ints());
30+
}
2531
} else {
26-
axes.emplace(attributes.getOrInsert( "axes", {{}}).ints());
32+
if (opsetVer >= 18) {
33+
noopWithEmptyAxes = attributes.getOrInsert("noop_with_empty_axes", {0}).int_() != 0;
34+
} else {
35+
axes.emplace(attributes.getOrInsert("axes", {{}}).ints());
36+
}
2737
}
2838

29-
auto keepDims = attributes.getOrInsert( "keepdims", {1}).int_();
39+
auto keepDims = attributes.getOrInsert("keepdims", {1}).int_();
3040
Ty ty;
3141
if (opType == "onnx::ReduceMean") {
3242
ty = Ty::Mean;

src/07onnx/src/operators/simple_binary.cc

Lines changed: 33 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "simple_binary.hh"
22
#include "common.h"
33
#include "computation/operators/simple_binary.h"
4+
#include "kernel/collectors/simple_binary.h"
5+
#include "runtime/resource.h"
6+
#include <execution>
47

58
namespace refactor::onnx {
69
using Op = SimpleBinary;
@@ -10,7 +13,7 @@ namespace refactor::onnx {
1013
: Operator(), type(type_) {}
1114

1215
auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
13-
auto fmod = attributes.getOrInsert( "fmod", {0}).int_();
16+
auto fmod = attributes.getOrInsert("fmod", {0}).int_();
1417
// clang-format off
1518
auto type =
1619
opType == "onnx::Add" ? Ty::Add :
@@ -93,30 +96,6 @@ namespace refactor::onnx {
9396
// clang-format on
9497
}
9598

96-
template<decltype(DataType::internal) T>
97-
void calculate(Ty ty, void *dst, void const *a, void const *b) {
98-
using T_ = typename primitive<T>::type;
99-
auto a_ = *reinterpret_cast<T_ const *>(a);
100-
auto b_ = *reinterpret_cast<T_ const *>(b);
101-
auto dst_ = reinterpret_cast<T_ *>(dst);
102-
switch (ty) {
103-
case Ty::Add:
104-
*dst_ = a_ + b_;
105-
break;
106-
case Ty::Sub:
107-
*dst_ = a_ - b_;
108-
break;
109-
case Ty::Mul:
110-
*dst_ = a_ * b_;
111-
break;
112-
case Ty::Div:
113-
*dst_ = a_ / b_;
114-
break;
115-
default:
116-
UNREACHABLE();
117-
}
118-
}
119-
12099
auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
121100
EXPECT_SIZE(2)
122101

@@ -139,35 +118,36 @@ namespace refactor::onnx {
139118
return Ok(Tensors{std::move(ans)});
140119
}
141120

142-
auto eleSize = dataType.size();
143-
auto dst = reinterpret_cast<uint8_t *>(ans->malloc());
144-
for (auto i : range0_(ans->elementsSize())) {
145-
auto indices = locateN(ans->shape, i);
146-
auto a_ = locate1(a, indices),
147-
b_ = locate1(b, indices);
148-
auto dst_ = dst + i * eleSize;
149-
//-------------------------------------
150-
#define CASE(T) \
151-
case DataType::T: \
152-
calculate<DataType::T>(type, dst_, a_, b_); \
153-
break
154-
//-------------------------------------
155-
switch (dataType.internal) {
156-
CASE(F32);
157-
CASE(F64);
158-
CASE(I32);
159-
CASE(I64);
160-
CASE(I8);
161-
CASE(I16);
162-
CASE(U8);
163-
CASE(U16);
164-
CASE(U32);
165-
CASE(U64);
166-
default:
167-
ans->free();
168-
break;
169-
}
121+
{
122+
using Shape = kernel::Shape;
123+
using Tensor = kernel::Tensor;
124+
using LayoutType = kernel::LayoutType;
125+
126+
Shape t1Shape(a.shape.size(), 1);
127+
Shape t2Shape(b.shape.size(), 1);
128+
Shape oShape(ans->shape.size(), 1);
129+
std::transform(std::execution::unseq,
130+
a.shape.begin(), a.shape.end(), t1Shape.begin(),
131+
[](auto const &i) { return static_cast<dim_t>(i.value()); });
132+
std::transform(std::execution::unseq,
133+
b.shape.begin(), b.shape.end(), t2Shape.begin(),
134+
[](auto const &i) { return static_cast<dim_t>(i.value()); });
135+
auto t1 = Tensor::share(a.dataType, t1Shape, LayoutType::Others, a.data);
136+
auto t2 = Tensor::share(b.dataType, t2Shape, LayoutType::Others, b.data);
137+
std::transform(std::execution::unseq,
138+
ans->shape.begin(), ans->shape.end(), oShape.begin(),
139+
[](auto const &i) { return static_cast<dim_t>(i.value()); });
140+
auto o = Tensor::share(a.dataType, oShape, LayoutType::Others);
141+
runtime::Resources res;
142+
auto type_ = static_cast<kernel::SimpleBinaryType>(type);
143+
const auto collector = kernel::SimpleBinaryCollector(computation::Target::Cpu, type_);
144+
auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
145+
void const *inputsCpu[]{*t1->data, *t2->data};
146+
void *outputsCpu[]{o->malloc()};
147+
routine(res, nullptr, inputsCpu, outputsCpu);
148+
ans->data = o->data;
170149
}
150+
171151
return Ok(Tensors{std::move(ans)});
172152
}
173153

0 commit comments

Comments (0)