@@ -11,14 +11,16 @@ template <typename GridwiseBinEltwise,
1111 typename ADataType,
1212 typename BDataType,
1313 typename CDataType,
14- typename GridDesc_M0,
14+ typename AGridDesc_M0,
15+ typename BGridDesc_M0,
16+ typename CGridDesc_M0,
1517 typename ElementwiseFunctor>
1618__global__ void kernel_binary_elementwise_1d (const ADataType* __restrict__ p_a_global,
1719 const BDataType* __restrict__ p_b_global,
1820 CDataType* __restrict__ p_c_global,
19- const GridDesc_M0 a_grid_desc_m0,
20- const GridDesc_M0 b_grid_desc_m0,
21- const GridDesc_M0 c_grid_desc_m0,
21+ const AGridDesc_M0 a_grid_desc_m0,
22+ const BGridDesc_M0 b_grid_desc_m0,
23+ const CGridDesc_M0 c_grid_desc_m0,
2224 const ElementwiseFunctor functor)
2325{
2426 GridwiseBinEltwise::Run (p_a_global,
@@ -34,7 +36,9 @@ template <typename ADataType,
3436 typename BDataType,
3537 typename CDataType,
3638 typename ComputeDataType,
37- typename GridDesc_M0,
39+ typename AGridDesc_M0,
40+ typename BGridDesc_M0,
41+ typename CGridDesc_M0,
3842 typename ElementwiseFunctor,
3943 index_t M0PerThread,
4044 index_t AScalarPerVector,
@@ -57,9 +61,9 @@ struct GridwiseBinaryElementwise_1D
5761 __device__ static void Run (const ADataType* __restrict__ p_a_global,
5862 const BDataType* __restrict__ p_b_global,
5963 CDataType* __restrict__ p_c_global,
60- const GridDesc_M0 a_grid_desc_m0,
61- const GridDesc_M0 b_grid_desc_m0,
62- const GridDesc_M0 c_grid_desc_m0,
64+ const AGridDesc_M0 a_grid_desc_m0,
65+ const BGridDesc_M0 b_grid_desc_m0,
66+ const CGridDesc_M0 c_grid_desc_m0,
6367 const ElementwiseFunctor functor)
6468 {
6569 const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -78,7 +82,7 @@ struct GridwiseBinaryElementwise_1D
7882 auto a_global_load =
7983 ThreadwiseTensorSliceTransfer_v2<ADataType,
8084 ComputeDataType,
81- GridDesc_M0 ,
85+ AGridDesc_M0 ,
8286 decltype (thread_desc_m0),
8387 Sequence<M0PerThread>, // SliceLengths
8488 Sequence<0 >, // DimAccessOrder
@@ -90,7 +94,7 @@ struct GridwiseBinaryElementwise_1D
9094 auto b_global_load =
9195 ThreadwiseTensorSliceTransfer_v2<BDataType,
9296 ComputeDataType,
93- GridDesc_M0 ,
97+ BGridDesc_M0 ,
9498 decltype (thread_desc_m0),
9599 Sequence<M0PerThread>, // SliceLengths
96100 Sequence<0 >, // DimAccessOrder
@@ -103,7 +107,7 @@ struct GridwiseBinaryElementwise_1D
103107 ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
104108 CDataType,
105109 decltype (thread_desc_m0),
106- GridDesc_M0 ,
110+ CGridDesc_M0 ,
107111 PassThrough,
108112 Sequence<M0PerThread>, // SliceLengths
109113 Sequence<0 >, // DimAccessOrder
0 commit comments