Skip to content

Commit 5912fa0

Browse files
author
rocking
committed
Separate GridDesc_M0 into A, B and C
1 parent 0334a92 commit 5912fa0

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ struct DeviceBinaryElementwise : public BaseOperator
6464
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
6565
}
6666

67-
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
67+
using AGridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
68+
using BGridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
69+
using CGridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
6870
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
6971
BDataType,
7072
CDataType,
7173
ComputeDataType,
72-
GridDesc_M0,
74+
AGridDesc_M0,
75+
BGridDesc_M0,
76+
CGridDesc_M0,
7377
ElementwiseFunctor,
7478
M0PerThread,
7579
AScalarPerVector,
@@ -106,9 +110,9 @@ struct DeviceBinaryElementwise : public BaseOperator
106110
const BDataType* p_b_;
107111
CDataType* p_c_;
108112
std::vector<int> lengths_;
109-
GridDesc_M0 a_grid_desc_m0_;
110-
GridDesc_M0 b_grid_desc_m0_;
111-
GridDesc_M0 c_grid_desc_m0_;
113+
AGridDesc_M0 a_grid_desc_m0_;
114+
BGridDesc_M0 b_grid_desc_m0_;
115+
CGridDesc_M0 c_grid_desc_m0_;
112116
std::vector<index_t> a_strides_;
113117
std::vector<index_t> b_strides_;
114118
std::vector<index_t> c_strides_;
@@ -125,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
125129
ADataType,
126130
BDataType,
127131
CDataType,
128-
GridDesc_M0,
132+
AGridDesc_M0,
133+
BGridDesc_M0,
134+
CGridDesc_M0,
129135
ElementwiseFunctor>;
130136

131137
float elapsed_time = launch_and_time_kernel(stream_config,

include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@ template <typename GridwiseBinEltwise,
1111
typename ADataType,
1212
typename BDataType,
1313
typename CDataType,
14-
typename GridDesc_M0,
14+
typename AGridDesc_M0,
15+
typename BGridDesc_M0,
16+
typename CGridDesc_M0,
1517
typename ElementwiseFunctor>
1618
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
1719
const BDataType* __restrict__ p_b_global,
1820
CDataType* __restrict__ p_c_global,
19-
const GridDesc_M0 a_grid_desc_m0,
20-
const GridDesc_M0 b_grid_desc_m0,
21-
const GridDesc_M0 c_grid_desc_m0,
21+
const AGridDesc_M0 a_grid_desc_m0,
22+
const BGridDesc_M0 b_grid_desc_m0,
23+
const CGridDesc_M0 c_grid_desc_m0,
2224
const ElementwiseFunctor functor)
2325
{
2426
GridwiseBinEltwise::Run(p_a_global,
@@ -34,7 +36,9 @@ template <typename ADataType,
3436
typename BDataType,
3537
typename CDataType,
3638
typename ComputeDataType,
37-
typename GridDesc_M0,
39+
typename AGridDesc_M0,
40+
typename BGridDesc_M0,
41+
typename CGridDesc_M0,
3842
typename ElementwiseFunctor,
3943
index_t M0PerThread,
4044
index_t AScalarPerVector,
@@ -57,9 +61,9 @@ struct GridwiseBinaryElementwise_1D
5761
__device__ static void Run(const ADataType* __restrict__ p_a_global,
5862
const BDataType* __restrict__ p_b_global,
5963
CDataType* __restrict__ p_c_global,
60-
const GridDesc_M0 a_grid_desc_m0,
61-
const GridDesc_M0 b_grid_desc_m0,
62-
const GridDesc_M0 c_grid_desc_m0,
64+
const AGridDesc_M0 a_grid_desc_m0,
65+
const BGridDesc_M0 b_grid_desc_m0,
66+
const CGridDesc_M0 c_grid_desc_m0,
6367
const ElementwiseFunctor functor)
6468
{
6569
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -78,7 +82,7 @@ struct GridwiseBinaryElementwise_1D
7882
auto a_global_load =
7983
ThreadwiseTensorSliceTransfer_v2<ADataType,
8084
ComputeDataType,
81-
GridDesc_M0,
85+
AGridDesc_M0,
8286
decltype(thread_desc_m0),
8387
Sequence<M0PerThread>, // SliceLengths
8488
Sequence<0>, // DimAccessOrder
@@ -90,7 +94,7 @@ struct GridwiseBinaryElementwise_1D
9094
auto b_global_load =
9195
ThreadwiseTensorSliceTransfer_v2<BDataType,
9296
ComputeDataType,
93-
GridDesc_M0,
97+
BGridDesc_M0,
9498
decltype(thread_desc_m0),
9599
Sequence<M0PerThread>, // SliceLengths
96100
Sequence<0>, // DimAccessOrder
@@ -103,7 +107,7 @@ struct GridwiseBinaryElementwise_1D
103107
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
104108
CDataType,
105109
decltype(thread_desc_m0),
106-
GridDesc_M0,
110+
CGridDesc_M0,
107111
PassThrough,
108112
Sequence<M0PerThread>, // SliceLengths
109113
Sequence<0>, // DimAccessOrder

0 commit comments

Comments
 (0)