Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/18_flatmm/flatmm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ static constexpr inline auto is_row_major(Layout layout_)

// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b_v0(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/18_flatmm/run_flatmm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ int run_flatmm_example_with_layouts(int argc,
}
else
{
return shuffle_b<FlatmmConfig>(b_origin_host);
return shuffle_b_v0<FlatmmConfig>(b_origin_host);
}
}();
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
Expand Down
16 changes: 10 additions & 6 deletions include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,17 +662,21 @@ struct FlatmmKernel

const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(
kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: splitk_batch_offset.splitted_k /
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(ScaleGranularityKB == 0
? 1
: (splitk_batch_offset.splitted_k /
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
Expand Down