Commit fd3d907
Author: Chao Liu
Parent: 41cdd38

fix ReLU formula (#61)

* fix relu
* clean up
* clean up

3 files changed: +159 -123 lines changed

example/1_gemm_xdl/gemm_xdl.cpp (89 additions, 118 deletions)
@@ -25,115 +25,76 @@ struct PassThrough

 struct Relu
 {
-    float alpha = 0.1;
-
-    // ReLU
     template <typename T>
     __host__ __device__ constexpr T operator()(T v) const
     {
-        T tmp = alpha * v;
-        return tmp > 0 ? tmp : 0;
+        return v > 0 ? v : 0;
     }
 };

-template <typename ADataType,
-          typename BDataType,
-          typename CDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+using AOp = PassThrough;
+using BOp = PassThrough;
+using COp = Relu;
+
+// Compilation parameters for NT problem
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl<
+    // data types, layouts, elementwise operations:
+    ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp,
+    // BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave:
+    256, 256, 128, 4, 8, 32, 32, 4, 2,
+    // ABlockTransfer: ThreadSliceLengths, ThreadClusterLengths, ThreadClusterArrangeOrder,
+    //                 SrcAccessOrder, SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1:
+    S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8,
+    // BBlockTransfer (same fields):
+    S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8,
+    // CThreadTransfer SrcDstVectorDim, DstScalarPerVector; ABlockLds/BBlockLds AddExtraM/N:
+    7, 1, true, true>;
+// clang-format on
+
+template <typename AType,
+          typename BType,
+          typename CType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
-struct DeviceGemmInstance;
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmInstance<ck::half_t,
-                          ck::half_t,
-                          ck::half_t,
-                          ck::tensor_layout::gemm::RowMajor,
-                          ck::tensor_layout::gemm::ColumnMajor,
-                          ck::tensor_layout::gemm::RowMajor,
-                          AElementwiseOperation,
-                          BElementwiseOperation,
-                          CElementwiseOperation>
+static void host_verify(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        Tensor<CType>& c_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
+                        const CElementwiseOperation& c_element_op)
 {
-    using F16 = ck::half_t;
-    using F32 = float;
-
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-    template <ck::index_t... Is>
-    using S = ck::Sequence<Is...>;
-
-    using AOp = AElementwiseOperation;
-    using BOp = BElementwiseOperation;
-    using COp = CElementwiseOperation;
-
-    // Compilation parameters for NT problem
-    // clang-format off
-    using type = ck::tensor_operation::device::DeviceGemmXdl<
-        F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2,
-        S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8,
-        S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
-    // clang-format on
-};
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];

-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmInstance<float,
-                          float,
-                          float,
-                          ck::tensor_layout::gemm::RowMajor,
-                          ck::tensor_layout::gemm::ColumnMajor,
-                          ck::tensor_layout::gemm::RowMajor,
-                          AElementwiseOperation,
-                          BElementwiseOperation,
-                          CElementwiseOperation>
-{
-    using F16 = ck::half_t;
-    using F32 = float;
-
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-    template <ck::index_t... Is>
-    using S = ck::Sequence<Is...>;
-
-    using AOp = AElementwiseOperation;
-    using BOp = BElementwiseOperation;
-    using COp = CElementwiseOperation;
-
-    // Compilation parameters for NT problem
-    // clang-format off
-    using type = ck::tensor_operation::device::DeviceGemmXdl<
-        F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 4, 32, 32, 4, 2,
-        S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4,
-        S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
-    // clang-format on
-};
+        double v = 0;
+
+        for(int k = 0; k < K; ++k)
+        {
+            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                 static_cast<const double>(b_element_op(b_k_n(k, n)));
+        }
+
+        c_m_n(m, n) = c_element_op(v);
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}

 int main(int argc, char* argv[])
 {
-    if(argc != 4)
-    {
-        printf("arg1: verification (0=no, 1=yes)\n");
-        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
-        exit(0);
-    }
-
-    const bool do_verification = std::stoi(argv[1]);
-    const int init_method      = std::stoi(argv[2]);
-    const int nrepeat          = std::stoi(argv[3]);
+    bool do_verification = 0;
+    int init_method      = 0;
+    int nrepeat          = 5;

     // GEMM shape
     ck::index_t M = 3840;
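
Why this is a fix: the old functor computed tmp = alpha * v and clamped that at zero, which evaluates to alpha * max(v, 0), a ReLU whose positive branch is wrongly scaled by 0.1 rather than the standard ReLU max(v, 0) (it is not a leaky ReLU either; that would scale the negative branch). A minimal standalone sketch of the difference, with hypothetical names ReluOld and ReluNew that are not part of the commit:

    #include <cassert>

    struct ReluOld // pre-fix: tmp = alpha * v clamped at zero, i.e. alpha * max(v, 0)
    {
        float alpha = 0.1f;
        float operator()(float v) const
        {
            const float tmp = alpha * v;
            return tmp > 0 ? tmp : 0;
        }
    };

    struct ReluNew // post-fix: the standard ReLU, max(v, 0)
    {
        float operator()(float v) const { return v > 0 ? v : 0; }
    };

    int main()
    {
        assert(ReluOld{}(2.0f) == 0.2f); // positive inputs were scaled by alpha
        assert(ReluNew{}(2.0f) == 2.0f); // now passed through unchanged
        assert(ReluOld{}(-2.0f) == 0 && ReluNew{}(-2.0f) == 0);
        return 0;
    }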
@@ -144,15 +105,34 @@ int main(int argc, char* argv[])
     ck::index_t StrideB = 4096;
     ck::index_t StrideC = 4096;

-    // matrix data type
-    using ADataType = ck::half_t;
-    using BDataType = ck::half_t;
-    using CDataType = ck::half_t;
+    if(argc == 4)
+    {
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);

-    // matrix layout
-    using ALayout = ck::tensor_layout::gemm::RowMajor;
-    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-    using CLayout = ck::tensor_layout::gemm::RowMajor;
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        exit(0);
+    }

     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
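
With the new defaults the example runs with no arguments: verification off, no init, five kernel repeats. One hedged observation on the new parsing: the argc == 4 branch reads argv[4] through argv[6], which only exist once argc >= 7, so the guard was presumably intended to be argc == 7 (or the reads intended for argv[1] through argv[3]). A bounds-checked sketch of the first reading, an assumption about intent rather than commit code:

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    int main(int argc, char* argv[])
    {
        bool do_verification = false;
        int init_method = 0, nrepeat = 5;
        int M = 3840, N = 4096, K = 4096;
        int StrideA = 4096, StrideB = 4096, StrideC = 4096;

        if(argc == 7 || argc == 10) // argv[1..6] exist in both cases
        {
            do_verification = std::stoi(argv[1]) != 0;
            init_method     = std::stoi(argv[2]);
            nrepeat         = std::stoi(argv[3]);
            M               = std::stoi(argv[4]);
            N               = std::stoi(argv[5]);
            K               = std::stoi(argv[6]);
        }
        if(argc == 10) // strides are optional
        {
            StrideA = std::stoi(argv[7]);
            StrideB = std::stoi(argv[8]);
            StrideC = std::stoi(argv[9]);
        }
        else if(argc != 1 && argc != 7)
        {
            std::printf("usage: gemm_xdl [verify init nrepeat M N K [StrideA StrideB StrideC]]\n");
            std::exit(0);
        }

        std::printf("verify=%d init=%d nrepeat=%d M=%d N=%d K=%d strides=%d/%d/%d\n",
                    int(do_verification), init_method, nrepeat, M, N, K, StrideA, StrideB, StrideC);
        return 0;
    }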
@@ -198,16 +178,7 @@ int main(int argc, char* argv[])
     c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());

     // do GEMM
-    auto gemm = typename DeviceGemmInstance<ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            ALayout,
-                                            BLayout,
-                                            CLayout,
-                                            PassThrough,
-                                            PassThrough,
-                                            Relu>::type{};
-
+    auto gemm = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
@@ -218,9 +189,9 @@ int main(int argc, char* argv[])
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      PassThrough{},
-                                      PassThrough{},
-                                      Relu{});
+                                      AOp{},
+                                      BOp{},
+                                      COp{});

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -233,7 +204,7 @@ int main(int argc, char* argv[])

     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

@@ -246,7 +217,7 @@ int main(int argc, char* argv[])

     if(do_verification)
     {
-        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, Relu{});
+        host_verify(a_m_k, b_k_n, c_m_n_host_result, AOp{}, BOp{}, COp{});

         check_error(c_m_n_host_result, c_m_n_device_result);
     }
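
One numerical note on the num_btype fix above: the old code counted B as K * M elements, but B is a K x N matrix, so the data moved is sizeof(BDataType) * K * N; the two only coincide when M == N. A standalone check of the corrected figure for this example's default shape, assuming 2-byte half types (a sketch, not commit code):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        const std::size_t M = 3840, N = 4096, K = 4096;
        const std::size_t elem = 2; // sizeof(ck::half_t)

        // read A (M x K) and B (K x N), write C (M x N)
        const std::size_t num_btype = elem * M * K + elem * K * N + elem * M * N;

        std::printf("%zu bytes moved (~%.1f MB)\n", num_btype, num_btype / 1.0e6);
        return 0;
    }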

example/2_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp (35 additions, 3 deletions)
@@ -20,10 +20,42 @@
 // 0 in the "n" dimension
 // assume C1 and C have same layout C

+struct BiasReluAdd
+{
+    template <typename T1, typename T2>
+    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+        float b = v0 + v1;
+        float c = b > 0 ? b : 0;
+        float d = c + v2;
+
+        return d;
+    }
+
+    template <typename T1, typename T2>
+    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+#if 0
+        float a = v1 + v0;
+        float b = max(a, float(0));
+        float c = b + v2;
+
+        return c;
+#else
+        float a = v1 + v2;
+        float b = v2;
+
+        float c = (v0 > -v1) ? a + v0 : v2;
+
+        return c;
+#endif
+    }
+};
+
 // v0 is from A * B
 // v1 is from C0
 // v2 is from C1
-struct BiasReluAdd
+struct BiasLeakyReluAdd
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
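
The device path of the new BiasReluAdd computes ReLU(v0 + v1) + v2 with a select instead of max(): the result is v0 + v1 + v2 when v0 + v1 > 0 and plain v2 otherwise, which is what (v0 > -v1) ? (v1 + v2) + v0 : v2 evaluates (the local b = v2 in the committed #else branch appears unused). A host-only sketch of the equivalence with hypothetical names; the two forms associate the additions differently, so arbitrary floats can differ in the last bit, though the exactly representable values below match bit for bit:

    #include <algorithm>
    #include <cassert>
    #include <initializer_list>

    // Reference form: bias add, then ReLU, then residual add.
    float bias_relu_add_ref(float v0, float v1, float v2)
    {
        return std::max(v0 + v1, 0.0f) + v2;
    }

    // Select form used in the device path: branch on the sign of v0 + v1.
    float bias_relu_add_select(float v0, float v1, float v2)
    {
        return (v0 > -v1) ? (v1 + v2) + v0 : v2;
    }

    int main()
    {
        for(float v0 : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
            for(float v1 : {-1.0f, 0.0f, 1.0f})
                for(float v2 : {-3.0f, 0.0f, 3.0f})
                    assert(bias_relu_add_ref(v0, v1, v2) == bias_relu_add_select(v0, v1, v2));
        return 0;
    }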
@@ -51,7 +83,7 @@ struct BiasReluAdd
     }
 };

-struct BiasRelu
+struct BiasLeakyRelu
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2) const
@@ -99,7 +131,7 @@ struct BiasAdd
     }
 #elif 0
     float alpha = 0.1;
-    float beta = 0.2;
+    float beta  = 0.2;
     float gamma = 0.3;

     // wrong result

example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp (35 additions, 2 deletions)
@@ -23,7 +23,7 @@ struct PassThrough
     }
 };

-struct BiasReluAdd
+struct BiasLeakyReluAdd
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
@@ -97,7 +97,39 @@ struct BiasReluAdd
     }
 };

-struct BiasRelu
+struct BiasReluAdd
+{
+    template <typename T1, typename T2>
+    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+        float b = v0 + v1;
+        float c = b > 0 ? b : 0;
+        float d = c + v2;
+
+        return d;
+    }
+
+    template <typename T1, typename T2>
+    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+#if 0
+        float a = v1 + v0;
+        float b = max(a, float(0));
+        float c = b + v2;
+
+        return c;
+#else
+        float a = v1 + v2;
+        float b = v2;
+
+        float c = (v0 > -v1) ? a + v0 : v2;
+
+        return c;
+#endif
+    }
+};
+
+struct BiasLeakyRelu
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2) const
@@ -377,6 +409,7 @@ int main(int argc, char* argv[])

     std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
                             sizeof(WeiDataType) * (K * C * Y * X) +
+                            sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) +
                             sizeof(OutDataType) * (N * K * Ho * Wo);

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
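
The added line charges the kernel for the residual-tensor read (N * K * Ho * Wo elements) and the bias read (K elements), both in OutDataType, in addition to the output write counted on the following line; before this change only the input, weight, and output tensors were counted, understating the bytes moved by a bias + ReLU + add convolution.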
