@@ -15,67 +15,67 @@ template <typename ADataType,
1515 typename CDataType,
1616 typename ComputeDataType,
1717 typename ElementwiseFunctor,
18- index_t Dim ,
19- index_t M0PerThread ,
18+ index_t NDim ,
19+ index_t MPerThread ,
2020 index_t AScalarPerVector,
2121 index_t BScalarPerVector,
2222 index_t CScalarPerVector>
2323struct DeviceBinaryElementwise : public BaseOperator
2424{
2525 static constexpr auto I0 = Number<0 >{};
2626
27- template <typename Desc_M0 >
28- static auto PadDescriptor_M0_1d (Desc_M0 desc_m0 , index_t gridSize, index_t blockSize)
27+ template <typename Desc_M >
28+ static auto PadDescriptor_M_1d (Desc_M desc_m , index_t gridSize, index_t blockSize)
2929 {
30- const auto m0 = desc_m0 .GetLength (I0);
31- const index_t loop_step = gridSize * blockSize * M0PerThread ;
32- const auto pad = math::integer_least_multiple (m0 , loop_step) - m0 ;
33- const auto desc_m0_pad =
34- transform_tensor_descriptor (desc_m0 ,
35- make_tuple (make_right_pad_transform (m0 , pad)),
30+ const auto m = desc_m .GetLength (I0);
31+ const index_t loop_step = gridSize * blockSize * MPerThread ;
32+ const auto pad = math::integer_least_multiple (m , loop_step) - m ;
33+ const auto desc_m_pad =
34+ transform_tensor_descriptor (desc_m ,
35+ make_tuple (make_right_pad_transform (m , pad)),
3636 make_tuple (Sequence<0 >{}),
3737 make_tuple (Sequence<0 >{}));
38- return desc_m0_pad ;
38+ return desc_m_pad ;
3939 }
4040
41- static auto MakeDescriptor_M0 (const std::vector<index_t >& lengths,
42- const std::vector<index_t >& strides,
43- index_t gridSize,
44- index_t blockSize)
41+ static auto MakeDescriptor_M (const std::vector<index_t >& lengths,
42+ const std::vector<index_t >& strides,
43+ index_t gridSize,
44+ index_t blockSize)
4545 {
46- auto tupleOfShape = generate_tuple ([&](auto I) { return lengths[I]; }, Number<Dim >{});
47- auto tupleOfStride = generate_tuple ([&](auto I) { return strides[I]; }, Number<Dim >{});
46+ auto tupleOfShape = generate_tuple ([&](auto I) { return lengths[I]; }, Number<NDim >{});
47+ auto tupleOfStride = generate_tuple ([&](auto I) { return strides[I]; }, Number<NDim >{});
4848
4949 // nd desc - [s0, s1, s2, ...]
5050 const auto desc = make_naive_tensor_descriptor (tupleOfShape, tupleOfStride);
5151
5252 // merge nd to 1d desc - [s0 * s1 * ...]
53- if constexpr (Dim > 1 )
53+ if constexpr (NDim > 1 )
5454 {
55- const auto desc_m0 = transform_tensor_descriptor (
55+ const auto desc_m = transform_tensor_descriptor (
5656 desc,
5757 make_tuple (make_merge_transform (tupleOfShape)),
58- make_tuple (generate_sequence_v2 ([&](auto I) { return I; }, Number<Dim >{})),
58+ make_tuple (generate_sequence_v2 ([&](auto I) { return I; }, Number<NDim >{})),
5959 make_tuple (Sequence<0 >{}));
6060
61- return PadDescriptor_M0_1d (desc_m0 , gridSize, blockSize);
61+ return PadDescriptor_M_1d (desc_m , gridSize, blockSize);
6262 }
6363 else
64- return PadDescriptor_M0_1d (desc, gridSize, blockSize);
64+ return PadDescriptor_M_1d (desc, gridSize, blockSize);
6565 }
6666
67- using AGridDesc_M0 = decltype (MakeDescriptor_M0 ({1 , 1 }, {1 , 1 }, 1 , 1 ));
68- using BGridDesc_M0 = decltype (MakeDescriptor_M0 ({1 , 1 }, {1 , 1 }, 1 , 1 ));
69- using CGridDesc_M0 = decltype (MakeDescriptor_M0 ({1 , 1 }, {1 , 1 }, 1 , 1 ));
67+ using AGridDesc_M = decltype (MakeDescriptor_M ({1 , 1 }, {1 , 1 }, 1 , 1 ));
68+ using BGridDesc_M = decltype (MakeDescriptor_M ({1 , 1 }, {1 , 1 }, 1 , 1 ));
69+ using CGridDesc_M = decltype (MakeDescriptor_M ({1 , 1 }, {1 , 1 }, 1 , 1 ));
7070 using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
7171 BDataType,
7272 CDataType,
7373 ComputeDataType,
74- AGridDesc_M0 ,
75- BGridDesc_M0 ,
76- CGridDesc_M0 ,
74+ AGridDesc_M ,
75+ BGridDesc_M ,
76+ CGridDesc_M ,
7777 ElementwiseFunctor,
78- M0PerThread ,
78+ MPerThread ,
7979 AScalarPerVector,
8080 BScalarPerVector,
8181 CScalarPerVector>;
@@ -101,18 +101,18 @@ struct DeviceBinaryElementwise : public BaseOperator
101101 blockSize_(256 ),
102102 gridSize_(120 ) // FIXME - Calculate the grid size by number of CU in the future
103103 {
104- a_grid_desc_m0_ = MakeDescriptor_M0 (lengths, a_strides, gridSize_, blockSize_);
105- b_grid_desc_m0_ = MakeDescriptor_M0 (lengths, b_strides, gridSize_, blockSize_);
106- c_grid_desc_m0_ = MakeDescriptor_M0 (lengths, c_strides, gridSize_, blockSize_);
104+ a_grid_desc_m_ = MakeDescriptor_M (lengths, a_strides, gridSize_, blockSize_);
105+ b_grid_desc_m_ = MakeDescriptor_M (lengths, b_strides, gridSize_, blockSize_);
106+ c_grid_desc_m_ = MakeDescriptor_M (lengths, c_strides, gridSize_, blockSize_);
107107 }
108108
109109 const ADataType* p_a_;
110110 const BDataType* p_b_;
111111 CDataType* p_c_;
112112 std::vector<int > lengths_;
113- AGridDesc_M0 a_grid_desc_m0_ ;
114- BGridDesc_M0 b_grid_desc_m0_ ;
115- CGridDesc_M0 c_grid_desc_m0_ ;
113+ AGridDesc_M a_grid_desc_m_ ;
114+ BGridDesc_M b_grid_desc_m_ ;
115+ CGridDesc_M c_grid_desc_m_ ;
116116 std::vector<index_t > a_strides_;
117117 std::vector<index_t > b_strides_;
118118 std::vector<index_t > c_strides_;
@@ -129,9 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
129129 ADataType,
130130 BDataType,
131131 CDataType,
132- AGridDesc_M0 ,
133- BGridDesc_M0 ,
134- CGridDesc_M0 ,
132+ AGridDesc_M ,
133+ BGridDesc_M ,
134+ CGridDesc_M ,
135135 ElementwiseFunctor>;
136136
137137 float elapsed_time = launch_and_time_kernel (stream_config,
@@ -142,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
142142 arg.p_a_ ,
143143 arg.p_b_ ,
144144 arg.p_c_ ,
145- arg.a_grid_desc_m0_ ,
146- arg.b_grid_desc_m0_ ,
147- arg.c_grid_desc_m0_ ,
145+ arg.a_grid_desc_m_ ,
146+ arg.b_grid_desc_m_ ,
147+ arg.c_grid_desc_m_ ,
148148 arg.functor_ );
149149 return elapsed_time;
150150 }
@@ -164,19 +164,19 @@ struct DeviceBinaryElementwise : public BaseOperator
164164 if (pArg == nullptr )
165165 return false ;
166166
167- if (pArg->lengths_ .size () != Dim )
167+ if (pArg->lengths_ .size () != NDim )
168168 return false ;
169169
170- if (pArg->lengths_ .back () % M0PerThread != 0 )
170+ if (pArg->lengths_ .back () % MPerThread != 0 )
171171 return false ;
172172
173- auto IsScalarPerVectorValid = [](bool isFastestAxisCoalesce , int scalarPerVector) {
173+ auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced , int scalarPerVector) {
174174 bool ret = true ;
175175
176- if (!isFastestAxisCoalesce )
176+ if (!isLastDimensionCoalesced )
177177 ret = scalarPerVector == 1 ;
178178 else
179- ret = M0PerThread % scalarPerVector == 0 ;
179+ ret = MPerThread % scalarPerVector == 0 ;
180180
181181 return ret;
182182 };
@@ -221,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
221221 // clang-format off
222222 str << " DeviceBinaryElementwise"
223223 << " <"
224- << " M0PerThread = " << M0PerThread
224+ << " MPerThread = " << MPerThread
225225 << " >" ;
226226 // clang-format on
227227
0 commit comments