Skip to content

Commit 2ef3ae5

Browse files
committed
Remove old CK Tile Stream-K implementation
The original CK Tile Stream-K implementation used a Tile Partitioner that was based on old CK's Stream-K block-to-C-tile map. However, old CK's implementation did not align with the original Stream-K paper. Thus, we implemented a new Tile Partitioner and an associated Stream-K kernel. The kernel implementation was placed in the `reboot` namespace. Now that all functionality for the new implementation is in place, this commit does the following:
- Removes all uses of the old CK Tile Stream-K implementation.
- Removes the `reboot` namespace so that the new implementation lives in the `ck_tile` namespace only.
- Adds fp8 and bf8 tests for the new implementation, as these were previously in place only for the old implementation.
- Removes the old CK Tile Stream-K Tile Partitioner.
- Removes the `_v2` suffix from the new CK Tile Tile Partitioner derived classes.
1 parent f9f9ae8 commit 2ef3ae5

File tree

156 files changed

+868
-3813
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

156 files changed

+868
-3813
lines changed

example/ck_tile/40_streamk_gemm/run_gemm_example.inc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
7171
bool flush_cache,
7272
ck_tile::StreamKReductionStrategy reduction_strategy)
7373
{
74-
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
75-
b_k_n_dev_buf.GetDeviceBuffer(),
76-
c_m_n_dev_buf.GetDeviceBuffer(),
77-
M,
78-
N,
79-
K,
80-
stride_A,
81-
stride_B,
82-
stride_C};
74+
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
75+
b_k_n_dev_buf.GetDeviceBuffer(),
76+
c_m_n_dev_buf.GetDeviceBuffer(),
77+
M,
78+
N,
79+
K,
80+
stride_A,
81+
stride_B,
82+
stride_C};
8383

8484
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
8585

example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ template <typename GemmConfig,
1616
typename ELayout,
1717
typename CDEElementWise,
1818
ck_tile::StreamKReductionStrategy ReductionStrategy>
19-
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
19+
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
2020
const ck_tile::stream_config& s)
2121
{
2222
using GemmShape = ck_tile::TileGemmShape<
@@ -28,7 +28,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
2828
GemmConfig::PermuteB>;
2929

3030
using TilePartitioner =
31-
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
31+
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
3232

3333
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
3434
GemmConfig::kPadN,
@@ -77,7 +77,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
7777
memory_operation.value,
7878
GemmConfig::NumWaveGroups>>;
7979

80-
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
80+
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
8181

8282
auto kargs = Kernel::MakeKernelArgs(args);
8383
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);

include/ck_tile/ops/common/streamk_common.hpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,4 @@ enum StreamKReductionStrategy : uint32_t
1111
Atomic = 0u,
1212
Reduction = 1u
1313
};
14-
15-
/**
16-
* @brief Estimates the number of Stream-K workgroups per macro tile in the C tensor.
17-
*
18-
* @param sk_ctas Number of Stream-K workgroups.
19-
* @param iters_per_sk_cta Number of iterations per Stream-K workgroup.
20-
* @param iters_per_tile Number of iterations per tile (i.e., the number of macro tiles in the K
21-
* dimension).
22-
* @return ck_tile::index_t An estimate of the number of workgroups per macro tile in the C tensor.
23-
* @note It is assumed that `iters_per_sk_cta` > 0.
24-
*/
25-
template <ck_tile::StreamKReductionStrategy ReductionStrategy>
26-
ck_tile::index_t
27-
estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
28-
{
29-
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
30-
// writing final results to a given macro tile in C.
31-
int num_wgs_per_tile = 1;
32-
33-
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
34-
if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
35-
{
36-
// Estimate the number of workgroups per macro tile.
37-
num_wgs_per_tile =
38-
(iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
39-
}
40-
41-
return std::max(num_wgs_per_tile, 1);
42-
}
4314
} // namespace ck_tile

0 commit comments

Comments
 (0)