Skip to content

Commit 221146a

Browse files
author
rocking
committed
rename var
1 parent 5912fa0 commit 221146a

File tree

2 files changed

+107
-112
lines changed

2 files changed

+107
-112
lines changed

include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,67 +15,67 @@ template <typename ADataType,
1515
typename CDataType,
1616
typename ComputeDataType,
1717
typename ElementwiseFunctor,
18-
index_t Dim,
19-
index_t M0PerThread,
18+
index_t NDim,
19+
index_t MPerThread,
2020
index_t AScalarPerVector,
2121
index_t BScalarPerVector,
2222
index_t CScalarPerVector>
2323
struct DeviceBinaryElementwise : public BaseOperator
2424
{
2525
static constexpr auto I0 = Number<0>{};
2626

27-
template <typename Desc_M0>
28-
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
27+
template <typename Desc_M>
28+
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
2929
{
30-
const auto m0 = desc_m0.GetLength(I0);
31-
const index_t loop_step = gridSize * blockSize * M0PerThread;
32-
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
33-
const auto desc_m0_pad =
34-
transform_tensor_descriptor(desc_m0,
35-
make_tuple(make_right_pad_transform(m0, pad)),
30+
const auto m = desc_m.GetLength(I0);
31+
const index_t loop_step = gridSize * blockSize * MPerThread;
32+
const auto pad = math::integer_least_multiple(m, loop_step) - m;
33+
const auto desc_m_pad =
34+
transform_tensor_descriptor(desc_m,
35+
make_tuple(make_right_pad_transform(m, pad)),
3636
make_tuple(Sequence<0>{}),
3737
make_tuple(Sequence<0>{}));
38-
return desc_m0_pad;
38+
return desc_m_pad;
3939
}
4040

41-
static auto MakeDescriptor_M0(const std::vector<index_t>& lengths,
42-
const std::vector<index_t>& strides,
43-
index_t gridSize,
44-
index_t blockSize)
41+
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
42+
const std::vector<index_t>& strides,
43+
index_t gridSize,
44+
index_t blockSize)
4545
{
46-
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<Dim>{});
47-
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<Dim>{});
46+
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
47+
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});
4848

4949
// nd desc - [s0, s1, s2, ...]
5050
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
5151

5252
// merge nd to 1d desc - [s0 * s1 * ...]
53-
if constexpr(Dim > 1)
53+
if constexpr(NDim > 1)
5454
{
55-
const auto desc_m0 = transform_tensor_descriptor(
55+
const auto desc_m = transform_tensor_descriptor(
5656
desc,
5757
make_tuple(make_merge_transform(tupleOfShape)),
58-
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
58+
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
5959
make_tuple(Sequence<0>{}));
6060

61-
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
61+
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
6262
}
6363
else
64-
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
64+
return PadDescriptor_M_1d(desc, gridSize, blockSize);
6565
}
6666

67-
using AGridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
68-
using BGridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
69-
using CGridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
67+
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
68+
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
69+
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
7070
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
7171
BDataType,
7272
CDataType,
7373
ComputeDataType,
74-
AGridDesc_M0,
75-
BGridDesc_M0,
76-
CGridDesc_M0,
74+
AGridDesc_M,
75+
BGridDesc_M,
76+
CGridDesc_M,
7777
ElementwiseFunctor,
78-
M0PerThread,
78+
MPerThread,
7979
AScalarPerVector,
8080
BScalarPerVector,
8181
CScalarPerVector>;
@@ -101,18 +101,18 @@ struct DeviceBinaryElementwise : public BaseOperator
101101
blockSize_(256),
102102
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
103103
{
104-
a_grid_desc_m0_ = MakeDescriptor_M0(lengths, a_strides, gridSize_, blockSize_);
105-
b_grid_desc_m0_ = MakeDescriptor_M0(lengths, b_strides, gridSize_, blockSize_);
106-
c_grid_desc_m0_ = MakeDescriptor_M0(lengths, c_strides, gridSize_, blockSize_);
104+
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
105+
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
106+
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
107107
}
108108

109109
const ADataType* p_a_;
110110
const BDataType* p_b_;
111111
CDataType* p_c_;
112112
std::vector<int> lengths_;
113-
AGridDesc_M0 a_grid_desc_m0_;
114-
BGridDesc_M0 b_grid_desc_m0_;
115-
CGridDesc_M0 c_grid_desc_m0_;
113+
AGridDesc_M a_grid_desc_m_;
114+
BGridDesc_M b_grid_desc_m_;
115+
CGridDesc_M c_grid_desc_m_;
116116
std::vector<index_t> a_strides_;
117117
std::vector<index_t> b_strides_;
118118
std::vector<index_t> c_strides_;
@@ -129,9 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
129129
ADataType,
130130
BDataType,
131131
CDataType,
132-
AGridDesc_M0,
133-
BGridDesc_M0,
134-
CGridDesc_M0,
132+
AGridDesc_M,
133+
BGridDesc_M,
134+
CGridDesc_M,
135135
ElementwiseFunctor>;
136136

137137
float elapsed_time = launch_and_time_kernel(stream_config,
@@ -142,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
142142
arg.p_a_,
143143
arg.p_b_,
144144
arg.p_c_,
145-
arg.a_grid_desc_m0_,
146-
arg.b_grid_desc_m0_,
147-
arg.c_grid_desc_m0_,
145+
arg.a_grid_desc_m_,
146+
arg.b_grid_desc_m_,
147+
arg.c_grid_desc_m_,
148148
arg.functor_);
149149
return elapsed_time;
150150
}
@@ -164,19 +164,19 @@ struct DeviceBinaryElementwise : public BaseOperator
164164
if(pArg == nullptr)
165165
return false;
166166

167-
if(pArg->lengths_.size() != Dim)
167+
if(pArg->lengths_.size() != NDim)
168168
return false;
169169

170-
if(pArg->lengths_.back() % M0PerThread != 0)
170+
if(pArg->lengths_.back() % MPerThread != 0)
171171
return false;
172172

173-
auto IsScalarPerVectorValid = [](bool isFastestAxisCoalesce, int scalarPerVector) {
173+
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
174174
bool ret = true;
175175

176-
if(!isFastestAxisCoalesce)
176+
if(!isLastDimensionCoalesced)
177177
ret = scalarPerVector == 1;
178178
else
179-
ret = M0PerThread % scalarPerVector == 0;
179+
ret = MPerThread % scalarPerVector == 0;
180180

181181
return ret;
182182
};
@@ -221,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
221221
// clang-format off
222222
str << "DeviceBinaryElementwise"
223223
<< "<"
224-
<< "M0PerThread = " << M0PerThread
224+
<< "MPerThread = " << MPerThread
225225
<< ">";
226226
// clang-format on
227227

0 commit comments

Comments
 (0)