
Commit 11c72f9

Support qscale for dynamic quant, remove static quant
1 parent 9f33b7c commit 11c72f9

File tree

14 files changed: +341 -228 lines


example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

Lines changed: 9 additions & 0 deletions
@@ -62,6 +62,15 @@ def get_mask_check_map(mask: str):
         assert False
     return None
 
+QSCALE_MAP = {
+    "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
+    "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
+}
+
+QSCALE_CHECK_MAP = {
+    "no": "quant_scale_enum::no_scale",
+    "pertensor": "quant_scale_enum::pertensor",
+}
 
 BIAS_MAP = {
     "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 48 additions & 44 deletions
Large diffs are not rendered by default.

example/ck_tile/01_fmha/example_fmha_fwd.cpp

Lines changed: 9 additions & 22 deletions
@@ -47,16 +47,12 @@ auto create_args(int argc, char* argv[])
         .insert("d_v", "-1", "head dim for v, -1 means equal to d")
         .insert("scale_s",
                 "0",
-                "scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
-                "note when squant=1, this value will be modified")
+                "scale factor of S. 0 means equal to 1/sqrt(hdim)")
+        .insert("qscale",
+                "n",
+                "n or 0, no scale\n"
+                "pt or 1, per-tensor scale\n")
         .insert("logits_soft_cap", "0", "attention logits soft capping value.")
-        .insert("squant",
-                "auto",
-                "if using static quantization fusion or not. auto: fp8 will default use squant, "
-                "other will not\n"
-                "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
-                "P and O.\n"
-                "calculate scale_s, scale_p, scale_o auto")
         .insert("iperm",
                 "1",
                 "permute input\n"
@@ -87,7 +83,8 @@ auto create_args(int argc, char* argv[])
                 "uf",
                 "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
                 "\n uf or 1 - uniform random float\n nf - normalized random float"
-                "\n tf or 2 - trig float\n")
+                "\n tf or 2 - trig float"
+                "\n tf or 3 - uniform random float, min max is the max of the type\n")
         .insert("seed",
                 "11939",
                 "random seed used for initializing input tensors. 0 for "
@@ -152,6 +149,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
     bool use_cache_batch_idx         = arg_parser.get_bool("cache_batch_idx");
     std::string bias_str             = arg_parser.get_str("bias");
+    std::string qscale_str           = arg_parser.get_str("qscale");
     float p_drop                     = arg_parser.get_float("p_drop");
     uint64_t drop_seed               = arg_parser.get_uint64("drop_seed");
     uint64_t drop_offset             = arg_parser.get_uint64("drop_offset");
@@ -162,13 +160,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
     std::string init_method = arg_parser.get_str("init");
     uint32_t seed           = arg_parser.get_uint32("seed");
 
-    bool squant = [&]() {
-        if(arg_parser.get_str("squant") == "auto")
-            return std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
-        else
-            return arg_parser.get_bool("squant");
-    }();
-
     ck_tile::stream_config stream_config{nullptr,
                                          true,
                                          /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
@@ -208,7 +199,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
                        drop_offset,
                        drop_prefs,
                        mask_str,
-                       squant,
+                       qscale_str,
                        is_rotary_interleaved,
                        num_splits,
                        init_method,
@@ -239,10 +230,6 @@ int main(int argc, char* argv[])
    {
        return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
    }
-   else if(data_type == "fp8")
-   {
-       return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
-   }
    else if(data_type == "fp8bf16")
    {
        return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
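
Note: with static quant removed, the example now forwards the raw qscale string ("n"/"0" or "pt"/"1") instead of the old squant bool. Below is a hedged sketch of how that string could be mapped to the host-side enum before dispatch; the helper name and parsing details are illustrative and not taken from this commit.

#include <stdexcept>
#include <string>

// Illustrative stand-in for the enum declared in quant.hpp.
enum class quant_scale_enum { no_scale = 0, pertensor = 1 };

// Hypothetical helper: map the command-line "qscale" string onto the enum.
// Accepted spellings mirror the help text added in this commit.
inline quant_scale_enum parse_qscale(const std::string& s)
{
    if(s == "n" || s == "0")
        return quant_scale_enum::no_scale;
    if(s == "pt" || s == "1")
        return quant_scale_enum::pertensor;
    throw std::invalid_argument("unrecognized qscale value: " + s);
}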

example/ck_tile/01_fmha/fmha_fwd.hpp

Lines changed: 13 additions & 10 deletions
@@ -11,6 +11,7 @@
 
 #include "bias.hpp"
 #include "mask.hpp"
+#include "quant.hpp"
 #include "rotary.hpp"
 
 #include <type_traits>
@@ -178,6 +179,9 @@ struct fmha_fwd_args
     const void* k_ptr;
     const void* v_ptr;
     const void* bias_ptr; // bias or alibi_slope pointer
+    const void* q_descale_ptr;
+    const void* k_descale_ptr;
+    const void* v_descale_ptr;
     void* rand_val_ptr;
     void* lse_ptr;
     void* o_ptr;
@@ -237,9 +241,6 @@ struct fmha_fwd_args
     ck_tile::index_t nhead_k;
 
     float scale_s;
-    float scale_p;
-    float scale_o;
-
     float logits_soft_cap;
 
     ck_tile::index_t stride_q;
@@ -581,6 +582,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
         args.k_ptr,
         args.v_ptr,
         args.bias_ptr,
+        args.q_descale_ptr,
+        args.k_descale_ptr,
+        args.v_descale_ptr,
         args.rand_val_ptr,
         args.lse_ptr,
         args.o_ptr,
@@ -593,8 +597,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
         args.nhead_q,
         args.nhead_q / args.nhead_k,
         args.scale_s,
-        args.scale_p,
-        args.scale_o,
         args.logits_soft_cap,
         args.stride_q,
         args.stride_k,
@@ -625,6 +627,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
         args.k_ptr,
         args.v_ptr,
         args.bias_ptr,
+        args.q_descale_ptr,
+        args.k_descale_ptr,
+        args.v_descale_ptr,
         args.rand_val_ptr,
         args.lse_ptr,
         args.o_ptr,
@@ -635,8 +640,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
         args.nhead_q,
         args.nhead_q / args.nhead_k,
         args.scale_s,
-        args.scale_p,
-        args.scale_o,
         args.logits_soft_cap,
         args.stride_q,
         args.stride_k,
@@ -1125,7 +1128,7 @@ template <ck_tile::index_t HDim_,
          ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLse_,
          bool kHasDropout_,
-          bool kDoFp8StaticQuant_,
+          ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
@@ -1150,7 +1153,7 @@ struct fmha_fwd_traits_
     static constexpr auto BiasEnum = BiasEnum_;
     static constexpr bool kStoreLse = kStoreLse_;
     static constexpr bool kHasDropout = kHasDropout_;
-    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr auto QScaleEnum = QScaleEnum_;
     static constexpr bool kPadS = kPadS_;
     static constexpr bool kPadSK = kPadSK_;
     static constexpr bool kPadD = kPadD_;
@@ -1341,7 +1344,7 @@ struct fmha_fwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_lse;
     bool has_dropout;
-    bool do_fp8_static_quant;
+    quant_scale_enum qscale_type;
     bool skip_min_seqlen_q = false;
     // TODO: padding check is inside this api
 };
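
Note: for callers, the net effect of this header change is that the fused output scaling (scale_p/scale_o) is gone and per-tensor de-scale factors are instead passed as device pointers. A hedged sketch of filling only the quantization-related fields follows; the helper and its pointer parameters are illustrative, and all unrelated members of fmha_fwd_args / fmha_fwd_traits are left untouched.

#include "fmha_fwd.hpp"

// Illustrative helper (not part of this commit): set the new per-tensor
// de-scale pointers and the qscale traits field for a dynamically quantized run.
inline void set_pertensor_descale(fmha_fwd_args& args,
                                  fmha_fwd_traits& traits,
                                  const void* q_descale, // device ptr, assumed one float per tensor
                                  const void* k_descale,
                                  const void* v_descale,
                                  float softmax_scale)
{
    args.q_descale_ptr = q_descale;
    args.k_descale_ptr = k_descale;
    args.v_descale_ptr = v_descale;
    args.scale_s       = softmax_scale;               // scale_p / scale_o no longer exist
    traits.qscale_type = quant_scale_enum::pertensor; // replaces do_fp8_static_quant
}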
