3 changes: 2 additions & 1 deletion example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
@@ -29,6 +29,7 @@ template <typename GemmConfig,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename QuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
@@ -75,7 +76,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
AccDataType,
GemmShape,
GemmUniversalTraits,
128>, // QuantGroupSize
QuantGroupSize>, // QuantGroupSize
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -43,6 +43,7 @@ template <typename GemmConfig,
typename BLayout,
typename BQLayout,
typename CLayout,
typename QuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
@@ -104,6 +105,7 @@ float invoke_gemm(int n_warmup,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);

std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
@@ -134,6 +136,7 @@ template <typename GemmConfig,
typename BQDataType,
typename CDataType,
typename AccDataType,
typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename ALayout,
typename AQLayout,
@@ -159,13 +162,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
};

const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");
const ck_tile::index_t QuantGroupSize = 128;
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");

if(kbatch > 1 && validate && warmup + repeat > 1)
{
@@ -259,9 +261,9 @@
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize != 0)
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
}
@@ -400,6 +402,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
BLayout,
BQLayout,
CLayout,
QuantGroupSize,
QuantMode>(warmup, repeat, group_count, gemm_descs);

for(int i = 0; i < group_count; i++)
@@ -481,12 +484,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
using AQDataType = typename Types::AccDataType;
using BQDataType = typename Types::AccDataType;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
using AQDataType = typename Types::AccDataType;
using BQDataType = typename Types::AccDataType;
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
@@ -496,6 +501,7 @@
BQDataType,
CDataType,
AccDataType,
QuantGroupSize,
QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
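The changes above drop the hardcoded `128` quant group size in favor of a `QuantGroupSize` type parameter, instantiated as `ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>` and read via `QuantGroupSize::kK`. Below is a minimal standalone sketch of that pattern, not the ck_tile implementation: the struct and the `kM`/`kN` members are illustrative assumptions, while `kK` and the divisibility check mirror the diff.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Illustrative stand-in for a compile-time quant group shape (not ck_tile::QuantGroupShape).
template <std::int64_t M_, std::int64_t N_, std::int64_t K_>
struct QuantGroupShapeSketch
{
    static constexpr std::int64_t kM = M_; // group extent along M (assumed name)
    static constexpr std::int64_t kN = N_; // group extent along N (assumed name)
    static constexpr std::int64_t kK = K_; // group extent along K, used as QuantGroupSize::kK
};

using GroupShape1x1x128 = QuantGroupShapeSketch<1, 1, 128>;

// Mirrors the BQuantGrouped branch: the scale tensor's K extent is K / QuantGroupSize::kK,
// and K must be divisible by the group size along K.
template <typename QuantGroupSize>
std::int64_t bq_k_extent(std::int64_t K)
{
    if(K % QuantGroupSize::kK != 0)
    {
        throw std::runtime_error("K must be divisible by the quant group size along K");
    }
    return K / QuantGroupSize::kK;
}

int main()
{
    std::cout << bq_k_extent<GroupShape1x1x128>(2048) << '\n'; // prints 16
    return 0;
}
```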
13 changes: 8 additions & 5 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
@@ -10,11 +10,14 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
gemm_quant.cpp
gemm_aquant_quantgrouped.cpp
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
gemm_bquant_quantgrouped_prefill_bf8.cpp
gemm_bquant_quantgrouped_prefill_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
gemm_aquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
gemm_bquant_quantgrouped_fp8i4.cpp
gemm_bquant_quantgrouped_bf8.cpp
gemm_bquant_quantgrouped_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb.cpp
gemm_bquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
gemm_quant_rowcol.cpp
gemm_quant_tensor.cpp
)
71 changes: 37 additions & 34 deletions example/ck_tile/38_block_scale_gemm/README.md
@@ -33,47 +33,50 @@ mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant_basic -j
make tile_example_gemm_quant -j
```
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
This will result in an executable `build/bin/tile_example_gemm_quant`

## example
```
args:
-h Print help message (default:false)
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-a_layout A tensor data layout - Row or Column (default:R)
-b_layout B tensor data layout - Row or Column (default:C)
-bq_layout Bq tensor data layout - Row or Column (default:C)
-c_layout C tensor data layout - Row or Column (default:R)
-stride_a Tensor A stride (default:0)
-stride_q Tensor AQ stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k SplitK value (default:1)
-device Device id that will be used to run the kernel (default:0)
-init 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache Flush cache before running the kernel (default:true)
-rotating_count Rotating count (default:1000)
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
-preshuffleb Enable preshuffle of tensor B (default:false)
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
-h Print help message (default:false)
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-a_layout A tensor data layout - Row or Column (default:R)
-b_layout B tensor data layout - Row or Column (default:C)
-bq_layout Bq tensor data layout - Row or Column (default:C)
-c_layout C tensor data layout - Row or Column (default:R)
-stride_a Tensor A stride (default:0)
-stride_q Tensor AQ stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k SplitK value (default:1)
-device Device id that will be used to run the kernel (default:0)
-init 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache Flush cache before running the kernel (default:true)
-rotating_count Rotating count (default:1000)
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
-preshuffleb Enable preshuffle of tensor B (default:false)
-preshufflequant Enable preshuffle of quant tensor (default:false)
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
```
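These flags do not change kernel code at run time; each combination of `-prec`, `-quant_mode`, the preshuffle options, and `-group_size` selects a precompiled instance through a string-keyed lookup table (see `hash_multiple_strings` in the factory files below). The following is a standalone sketch of that dispatch pattern under simplifying assumptions: `make_key` is a stand-in for the repo's helper, and only the `fp8`/`aquant`/`non-preshufflequant`/`1x1x128` key is taken from the diff.

```cpp
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for hash_multiple_strings: combine the flag strings into one lookup key.
static std::size_t make_key(std::initializer_list<std::string> parts)
{
    std::string joined;
    for(const std::string& p : parts)
    {
        joined += p;
        joined += '|';
    }
    return std::hash<std::string>{}(joined);
}

int main()
{
    // Each entry corresponds to one precompiled kernel instantiation.
    std::unordered_map<std::size_t, std::function<int()>> lut;
    lut[make_key({"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = []() {
        std::cout << "run the fp8 / aquant / 1x1x128 instance\n";
        return 0;
    };

    // Values as they would arrive from -prec, -quant_mode, -preshufflequant, -group_size.
    const std::string prec = "fp8", mode = "aquant", shuffle = "non-preshufflequant",
                      group = "1x1x128";

    const auto it = lut.find(make_key({prec, mode, shuffle, group}));
    if(it == lut.end())
    {
        std::cerr << "unsupported flag combination\n";
        return 1;
    }
    return it->second();
}
```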

Users need to select the correct config for each quant mode:

| | quant_mode as runtime argument | Config in cpp file |
|:--------|:-----:|-------|
| For selecting AQuant | aquant | GemmConfigQuant |
| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
| For selecting BQuant | bquant | GemmConfigQuant |
| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant_decode (or) GemmConfigPreshuffleB_Bquant_prefill
| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |
| | quant_mode as runtime argument | Corresponding cpp file | GemmConfig at the top of cpp file |
|:--------|:-----:|:-----:|-------|
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp | GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill |
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp | GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill |
| For selecting RowCol quant | rowcolquant | gemm_quant_rowcol.cpp | GemmConfigRowColQuant |

example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp
@@ -4,31 +4,33 @@
#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigQuant<T>;
using GemmConfig = GemmConfigQuantDecode<T>;

void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
lut[hash_multiple_strings(
{"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
lut[hash_multiple_strings(
{"bf8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] =
lut[hash_multiple_strings({"fp8i4", "aquant", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
@@ -39,7 +41,7 @@ void aquant_quantgrouped_instance_factory(
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] =
lut[hash_multiple_strings({"bf8i4", "aquant", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped_preshufflequant.cpp
@@ -4,50 +4,52 @@
#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<T>;
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;

void bquant_quantgrouped_preshuffleb_instance_factory(
void aquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings(
{"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
{"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
{"bf8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] =
lut[hash_multiple_strings({"fp8i4", "aquant", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] =
lut[hash_multiple_strings({"bf8i4", "aquant", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
}