@@ -25,115 +25,76 @@ struct PassThrough
2525
2626struct Relu
2727{
28- float alpha = 0.1 ;
29-
30- // ReLU
3128 template <typename T>
3229 __host__ __device__ constexpr T operator ()(T v) const
3330 {
34- T tmp = alpha * v;
35- return tmp > 0 ? tmp : 0 ;
31+ return v > 0 ? v : 0 ;
3632 }
3733};
3834
39- template <typename ADataType,
40- typename BDataType,
41- typename CDataType,
42- typename ALayout,
43- typename BLayout,
44- typename CLayout,
35+ template <ck::index_t ... Is>
36+ using S = ck::Sequence<Is...>;
37+
38+ using ADataType = ck::half_t ;
39+ using BDataType = ck::half_t ;
40+ using CDataType = ck::half_t ;
41+ using AccDataType = float ;
42+
43+ using ALayout = ck::tensor_layout::gemm::RowMajor;
44+ using BLayout = ck::tensor_layout::gemm::ColumnMajor;
45+ using CLayout = ck::tensor_layout::gemm::RowMajor;
46+
47+ using AOp = PassThrough;
48+ using BOp = PassThrough;
49+ using COp = Relu;
50+
51+ // Compilation parameters for NT problem
52+ // clang-format off
53+ using DeviceGemmInstance =
54+ // #########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
55+ // #########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
56+ // #########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
57+ // #########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
58+ ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256 , 256 , 128 , 4 , 8 , 32 , 32 , 4 , 2 , S<1 , 4 , 8 >, S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , S<1 , 2 , 8 >, S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , 7 , 1 , true , true >;
59+ // clang-format on
60+
61+ template <typename AType,
62+ typename BType,
63+ typename CType,
4564 typename AElementwiseOperation,
4665 typename BElementwiseOperation,
4766 typename CElementwiseOperation>
48- struct DeviceGemmInstance ;
49-
50- template <typename AElementwiseOperation,
51- typename BElementwiseOperation,
52- typename CElementwiseOperation>
53- struct DeviceGemmInstance <ck::half_t ,
54- ck::half_t ,
55- ck::half_t ,
56- ck::tensor_layout::gemm::RowMajor,
57- ck::tensor_layout::gemm::ColumnMajor,
58- ck::tensor_layout::gemm::RowMajor,
59- AElementwiseOperation,
60- BElementwiseOperation,
61- CElementwiseOperation>
67+ static void host_verify (const Tensor<AType>& a_m_k,
68+ const Tensor<BType>& b_k_n,
69+ Tensor<CType>& c_m_n,
70+ const AElementwiseOperation& a_element_op,
71+ const BElementwiseOperation& b_element_op,
72+ const CElementwiseOperation& c_element_op)
6273{
63- using F16 = ck::half_t ;
64- using F32 = float ;
65-
66- using Row = ck::tensor_layout::gemm::RowMajor;
67- using Col = ck::tensor_layout::gemm::ColumnMajor;
68-
69- template <ck::index_t ... Is>
70- using S = ck::Sequence<Is...>;
71-
72- using AOp = AElementwiseOperation;
73- using BOp = BElementwiseOperation;
74- using COp = CElementwiseOperation;
75-
76- // Compilation parameters for NT problem
77- // clang-format off
78- using type =
79- // ########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
80- // ########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
81- // ########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
82- // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
83- ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256 , 256 , 128 , 4 , 8 , 32 , 32 , 4 , 2 , S<1 , 4 , 8 >, S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , S<1 , 2 , 8 >, S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 8 , 8 , 7 , 1 , true , true >;
84- // clang-format on
85- };
74+ auto f_mk_kn_mn = [&](auto m, auto n) {
75+ const int K = a_m_k.mDesc .GetLengths ()[1 ];
8676
87- template <typename AElementwiseOperation,
88- typename BElementwiseOperation,
89- typename CElementwiseOperation>
90- struct DeviceGemmInstance <float ,
91- float ,
92- float ,
93- ck::tensor_layout::gemm::RowMajor,
94- ck::tensor_layout::gemm::ColumnMajor,
95- ck::tensor_layout::gemm::RowMajor,
96- AElementwiseOperation,
97- BElementwiseOperation,
98- CElementwiseOperation>
99- {
100- using F16 = ck::half_t ;
101- using F32 = float ;
102-
103- using Row = ck::tensor_layout::gemm::RowMajor;
104- using Col = ck::tensor_layout::gemm::ColumnMajor;
105-
106- template <ck::index_t ... Is>
107- using S = ck::Sequence<Is...>;
108-
109- using AOp = AElementwiseOperation;
110- using BOp = BElementwiseOperation;
111- using COp = CElementwiseOperation;
112-
113- // Compilation parameters for NT problem
114- // clang-format off
115- using type =
116- // ########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
117- // ########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
118- // ########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
119- // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
120- ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256 , 256 , 128 , 4 , 4 , 32 , 32 , 4 , 2 , S<1 , 4 , 4 >, S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 4 , 4 , S<1 , 2 , 4 >, S<4 , 64 , 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , 4 , 4 , 7 , 1 , true , true >;
121- // clang-format on
122- };
77+ double v = 0 ;
78+
79+ for (int k = 0 ; k < K; ++k)
80+ {
81+ v += static_cast <const double >(a_element_op (a_m_k (m, k))) *
82+ static_cast <const double >(b_element_op (b_k_n (k, n)));
83+ }
84+
85+ c_m_n (m, n) = c_element_op (v);
86+ };
87+
88+ make_ParallelTensorFunctor (f_mk_kn_mn,
89+ c_m_n.mDesc .GetLengths ()[0 ],
90+ c_m_n.mDesc .GetLengths ()[1 ])(std::thread::hardware_concurrency ());
91+ }
12392
12493int main (int argc, char * argv[])
12594{
126- if (argc != 4 )
127- {
128- printf (" arg1: verification (0=no, 1=yes)\n " );
129- printf (" arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n " );
130- printf (" arg3: run kernel # of times (>1)\n " );
131- exit (0 );
132- }
133-
134- const bool do_verification = std::stoi (argv[1 ]);
135- const int init_method = std::stoi (argv[2 ]);
136- const int nrepeat = std::stoi (argv[3 ]);
95+ bool do_verification = 0 ;
96+ int init_method = 0 ;
97+ int nrepeat = 5 ;
13798
13899 // GEMM shape
139100 ck::index_t M = 3840 ;
@@ -144,15 +105,34 @@ int main(int argc, char* argv[])
144105 ck::index_t StrideB = 4096 ;
145106 ck::index_t StrideC = 4096 ;
146107
147- // matrix data type
148- using ADataType = ck::half_t ;
149- using BDataType = ck::half_t ;
150- using CDataType = ck::half_t ;
108+ if (argc == 4 )
109+ {
110+ M = std::stoi (argv[4 ]);
111+ N = std::stoi (argv[5 ]);
112+ K = std::stoi (argv[6 ]);
113+ }
114+ else if (argc == 10 )
115+ {
116+ do_verification = std::stoi (argv[1 ]);
117+ init_method = std::stoi (argv[2 ]);
118+ nrepeat = std::stoi (argv[3 ]);
119+
120+ M = std::stoi (argv[4 ]);
121+ N = std::stoi (argv[5 ]);
122+ K = std::stoi (argv[6 ]);
151123
152- // matrix layout
153- using ALayout = ck::tensor_layout::gemm::RowMajor;
154- using BLayout = ck::tensor_layout::gemm::ColumnMajor;
155- using CLayout = ck::tensor_layout::gemm::RowMajor;
124+ StrideA = std::stoi (argv[7 ]);
125+ StrideB = std::stoi (argv[8 ]);
126+ StrideC = std::stoi (argv[9 ]);
127+ }
128+ else
129+ {
130+ printf (" arg1: verification (0=no, 1=yes)\n " );
131+ printf (" arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n " );
132+ printf (" arg3: run kernel # of times (>1)\n " );
133+ printf (" arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n " );
134+ exit (0 );
135+ }
156136
157137 auto f_host_tensor_descriptor =
158138 [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -198,16 +178,7 @@ int main(int argc, char* argv[])
198178 c_m_n_device_buf.ToDevice (c_m_n_device_result.mData .data ());
199179
200180 // do GEMM
201- auto gemm = typename DeviceGemmInstance<ADataType,
202- BDataType,
203- CDataType,
204- ALayout,
205- BLayout,
206- CLayout,
207- PassThrough,
208- PassThrough,
209- Relu>::type{};
210-
181+ auto gemm = DeviceGemmInstance{};
211182 auto invoker = gemm.MakeInvoker ();
212183 auto argument = gemm.MakeArgument (static_cast <ADataType*>(a_m_k_device_buf.GetDeviceBuffer ()),
213184 static_cast <BDataType*>(b_k_n_device_buf.GetDeviceBuffer ()),
@@ -218,9 +189,9 @@ int main(int argc, char* argv[])
218189 StrideA,
219190 StrideB,
220191 StrideC,
221- PassThrough {},
222- PassThrough {},
223- Relu {});
192+ AOp {},
193+ BOp {},
194+ COp {});
224195
225196 if (!gemm.IsSupportedArgument (argument))
226197 {
@@ -233,7 +204,7 @@ int main(int argc, char* argv[])
233204
234205 std::size_t flop = std::size_t (2 ) * M * N * K;
235206 std::size_t num_btype =
236- sizeof (ADataType) * M * K + sizeof (BDataType) * K * M + sizeof (CDataType) * M * N;
207+ sizeof (ADataType) * M * K + sizeof (BDataType) * K * N + sizeof (CDataType) * M * N;
237208
238209 float tflops = static_cast <float >(flop) / 1 .E9 / ave_time;
239210
@@ -246,7 +217,7 @@ int main(int argc, char* argv[])
246217
247218 if (do_verification)
248219 {
249- host_gemm_mk_kn_mn (a_m_k, b_k_n, c_m_n_host_result, PassThrough {}, PassThrough {}, Relu {});
220+ host_verify (a_m_k, b_k_n, c_m_n_host_result, AOp {}, BOp {}, COp {});
250221
251222 check_error (c_m_n_host_result, c_m_n_device_result);
252223 }
0 commit comments