Skip to content

Commit 434dce2

Browse files
committed
refactor code according to comments
1 parent acd5756 commit 434dce2

File tree

2 files changed

+146
-366
lines changed

2 files changed

+146
-366
lines changed

torchao/csrc/cpu/int8_sdpa.cpp

Lines changed: 99 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,12 @@ inline void copy_value_with_pad(
725725
}
726726

727727
// UINT8 - one parallel loop with u8u8s32 GEMM
728-
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
728+
template <typename scalar_t, typename mask_t,
729+
int64_t q_split_size, int64_t kv_split_size,
730+
bool use_one_parallel_loop,
731+
typename std::enable_if_t<use_one_parallel_loop, int> = 0>
729732
inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
730-
sdpa_int8_kernel_one_loop_impl(
733+
sdpa_int8_fused_kernel_impl(
731734
const at::Tensor& output,
732735
const at::Tensor& q,
733736
const at::Tensor& k,
@@ -1150,9 +1153,12 @@ sdpa_int8_kernel_one_loop_impl(
11501153
}
11511154

11521155
// UINT8 - several parallel loops with u8u8s32 GEMM
1153-
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
1156+
template <typename scalar_t, typename mask_t,
1157+
int64_t q_split_size, int64_t kv_split_size,
1158+
bool use_one_parallel_loop,
1159+
typename std::enable_if_t<!use_one_parallel_loop, int> = 0>
11541160
inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
1155-
sdpa_int8_kernel_several_loops_impl(
1161+
sdpa_int8_fused_kernel_impl(
11561162
const at::Tensor& output,
11571163
const at::Tensor& q,
11581164
const at::Tensor& k,
@@ -1615,6 +1621,53 @@ sdpa_int8_kernel_several_loops_impl(
16151621
at::native::cpublas::brgemm_release();
16161622
}
16171623

1624+
1625+
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
1626+
inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
1627+
sdpa_int8_fused_kernel_impl(
1628+
bool use_one_parallel_loop,
1629+
const at::Tensor& output,
1630+
const at::Tensor& query,
1631+
const at::Tensor& key,
1632+
const at::Tensor& value,
1633+
double dropout_p,
1634+
bool is_causal,
1635+
std::optional<at::Tensor> attn_mask,
1636+
double scale,
1637+
int32_t q_zp,
1638+
float q_scale,
1639+
int32_t k_zp,
1640+
float k_scale,
1641+
int32_t v_zp,
1642+
float v_scale,
1643+
int32_t a_zp,
1644+
float a_scale,
1645+
int32_t o_zp,
1646+
float o_scale) {
1647+
if (use_one_parallel_loop) {
1648+
sdpa_int8_fused_kernel_impl<scalar_t, mask_t, q_split_size, kv_split_size,
1649+
/* use_one_parallel_loop */ true>(
1650+
output, query, key, value,
1651+
dropout_p, is_causal, attn_mask, scale,
1652+
q_zp, q_scale,
1653+
k_zp, k_scale,
1654+
v_zp, v_scale,
1655+
a_zp, a_scale,
1656+
o_zp, o_scale);
1657+
} else {
1658+
sdpa_int8_fused_kernel_impl<scalar_t, mask_t, q_split_size, kv_split_size,
1659+
/* use_one_parallel_loop */ false>(
1660+
output, query, key, value,
1661+
dropout_p, is_causal, attn_mask, scale,
1662+
q_zp, q_scale,
1663+
k_zp, k_scale,
1664+
v_zp, v_scale,
1665+
a_zp, a_scale,
1666+
o_zp, o_scale);
1667+
}
1668+
}
1669+
1670+
16181671
#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) \
16191672
AT_DISPATCH_SWITCH( \
16201673
TYPE, \
@@ -1661,77 +1714,50 @@ void sdpa_int8_fused_kernel(
16611714
q_split_size = 64;
16621715
}
16631716
// Heuristic to decide whether to use one parallel loop or not
1717+
// true: one parallel loop for sum+packing+core
1718+
// false: three parallel loops for sum, packing, core
16641719
uint32_t l2_cache_size = at::cpu::L2_cache_size();
16651720
int64_t num_thread = at::get_num_threads();
16661721
int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread;
16671722
bool use_one_parallel_loop = (batchSize * num_head > num_thread) &&
16681723
(attn_size > 1.5 * l2_cache_size);
1669-
if (use_one_parallel_loop) {
1670-
if (!attn_mask.has_value()) {
1671-
if (q_split_size == 256) {
1672-
sdpa_int8_kernel_one_loop_impl<unsigned char, float, 256, 64>(
1673-
output, query, key, value,
1674-
dropout_p, is_causal, attn_mask, scale,
1675-
q_zp, q_scale,
1676-
k_zp, k_scale,
1677-
v_zp, v_scale,
1678-
a_zp, a_scale,
1679-
o_zp, o_scale);
1680-
} else if (q_split_size == 64) {
1681-
sdpa_int8_kernel_one_loop_impl<unsigned char, float, 64, 64>(
1682-
output, query, key, value,
1683-
dropout_p, is_causal, attn_mask, scale,
1684-
q_zp, q_scale,
1685-
k_zp, k_scale,
1686-
v_zp, v_scale,
1687-
a_zp, a_scale,
1688-
o_zp, o_scale);
1689-
} else {
1690-
sdpa_int8_kernel_one_loop_impl<unsigned char, float, 32, 64>(
1691-
output, query, key, value,
1692-
dropout_p, is_causal, attn_mask, scale,
1693-
q_zp, q_scale,
1694-
k_zp, k_scale,
1695-
v_zp, v_scale,
1696-
a_zp, a_scale,
1697-
o_zp, o_scale);
1698-
}
1724+
if (!attn_mask.has_value()) {
1725+
if (q_split_size == 256) {
1726+
sdpa_int8_fused_kernel_impl<unsigned char, float, 256, 64>(
1727+
use_one_parallel_loop,
1728+
output, query, key, value,
1729+
dropout_p, is_causal, attn_mask, scale,
1730+
q_zp, q_scale,
1731+
k_zp, k_scale,
1732+
v_zp, v_scale,
1733+
a_zp, a_scale,
1734+
o_zp, o_scale);
1735+
} else if (q_split_size == 64) {
1736+
sdpa_int8_fused_kernel_impl<unsigned char, float, 64, 64>(
1737+
use_one_parallel_loop,
1738+
output, query, key, value,
1739+
dropout_p, is_causal, attn_mask, scale,
1740+
q_zp, q_scale,
1741+
k_zp, k_scale,
1742+
v_zp, v_scale,
1743+
a_zp, a_scale,
1744+
o_zp, o_scale);
16991745
} else {
1700-
AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() {
1701-
if (q_split_size == 256) {
1702-
sdpa_int8_kernel_one_loop_impl<unsigned char, mask_t, 256, 64>(
1703-
output, query, key, value,
1704-
dropout_p, is_causal, attn_mask, scale,
1705-
q_zp, q_scale,
1706-
k_zp, k_scale,
1707-
v_zp, v_scale,
1708-
a_zp, a_scale,
1709-
o_zp, o_scale);
1710-
} else if (q_split_size == 64) {
1711-
sdpa_int8_kernel_one_loop_impl<unsigned char, mask_t, 64, 64>(
1712-
output, query, key, value,
1713-
dropout_p, is_causal, attn_mask, scale,
1714-
q_zp, q_scale,
1715-
k_zp, k_scale,
1716-
v_zp, v_scale,
1717-
a_zp, a_scale,
1718-
o_zp, o_scale);
1719-
} else {
1720-
sdpa_int8_kernel_one_loop_impl<unsigned char, mask_t, 32, 64>(
1721-
output, query, key, value,
1722-
dropout_p, is_causal, attn_mask, scale,
1723-
q_zp, q_scale,
1724-
k_zp, k_scale,
1725-
v_zp, v_scale,
1726-
a_zp, a_scale,
1727-
o_zp, o_scale);
1728-
}
1729-
});
1746+
sdpa_int8_fused_kernel_impl<unsigned char, float, 32, 64>(
1747+
use_one_parallel_loop,
1748+
output, query, key, value,
1749+
dropout_p, is_causal, attn_mask, scale,
1750+
q_zp, q_scale,
1751+
k_zp, k_scale,
1752+
v_zp, v_scale,
1753+
a_zp, a_scale,
1754+
o_zp, o_scale);
17301755
}
17311756
} else {
1732-
if (!attn_mask.has_value()) {
1757+
AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() {
17331758
if (q_split_size == 256) {
1734-
sdpa_int8_kernel_several_loops_impl<unsigned char, float, 256, 64>(
1759+
sdpa_int8_fused_kernel_impl<unsigned char, mask_t, 256, 64>(
1760+
use_one_parallel_loop,
17351761
output, query, key, value,
17361762
dropout_p, is_causal, attn_mask, scale,
17371763
q_zp, q_scale,
@@ -1740,7 +1766,8 @@ void sdpa_int8_fused_kernel(
17401766
a_zp, a_scale,
17411767
o_zp, o_scale);
17421768
} else if (q_split_size == 64) {
1743-
sdpa_int8_kernel_several_loops_impl<unsigned char, float, 64, 64>(
1769+
sdpa_int8_fused_kernel_impl<unsigned char, mask_t, 64, 64>(
1770+
use_one_parallel_loop,
17441771
output, query, key, value,
17451772
dropout_p, is_causal, attn_mask, scale,
17461773
q_zp, q_scale,
@@ -1749,7 +1776,8 @@ void sdpa_int8_fused_kernel(
17491776
a_zp, a_scale,
17501777
o_zp, o_scale);
17511778
} else {
1752-
sdpa_int8_kernel_several_loops_impl<unsigned char, float, 32, 64>(
1779+
sdpa_int8_fused_kernel_impl<unsigned char, mask_t, 32, 64>(
1780+
use_one_parallel_loop,
17531781
output, query, key, value,
17541782
dropout_p, is_causal, attn_mask, scale,
17551783
q_zp, q_scale,
@@ -1758,38 +1786,7 @@ void sdpa_int8_fused_kernel(
17581786
a_zp, a_scale,
17591787
o_zp, o_scale);
17601788
}
1761-
} else {
1762-
AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() {
1763-
if (q_split_size == 256) {
1764-
sdpa_int8_kernel_several_loops_impl<unsigned char, mask_t, 256, 64>(
1765-
output, query, key, value,
1766-
dropout_p, is_causal, attn_mask, scale,
1767-
q_zp, q_scale,
1768-
k_zp, k_scale,
1769-
v_zp, v_scale,
1770-
a_zp, a_scale,
1771-
o_zp, o_scale);
1772-
} else if (q_split_size == 64) {
1773-
sdpa_int8_kernel_several_loops_impl<unsigned char, mask_t, 64, 64>(
1774-
output, query, key, value,
1775-
dropout_p, is_causal, attn_mask, scale,
1776-
q_zp, q_scale,
1777-
k_zp, k_scale,
1778-
v_zp, v_scale,
1779-
a_zp, a_scale,
1780-
o_zp, o_scale);
1781-
} else {
1782-
sdpa_int8_kernel_several_loops_impl<unsigned char, mask_t, 32, 64>(
1783-
output, query, key, value,
1784-
dropout_p, is_causal, attn_mask, scale,
1785-
q_zp, q_scale,
1786-
k_zp, k_scale,
1787-
v_zp, v_scale,
1788-
a_zp, a_scale,
1789-
o_zp, o_scale);
1790-
}
1791-
});
1792-
}
1789+
});
17931790
}
17941791
}
17951792
#endif // CPU_CAPABILITY_AVX512
@@ -1888,22 +1885,16 @@ at::Tensor _scaled_dot_product_int8_cpu(
18881885
o_zp, o_scale);
18891886
return output.transpose(1, 2);
18901887
} else {
1888+
#endif // CPU_CAPABILITY_AVX512
18911889
return sdpa_int8_math_kernel(query, key, value,
18921890
dropout_p, is_causal, attn_mask, scale,
18931891
q_zp, q_scale,
18941892
k_zp, k_scale,
18951893
v_zp, v_scale,
18961894
a_zp, a_scale,
18971895
o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2);
1896+
#ifdef CPU_CAPABILITY_AVX512
18981897
}
1899-
#else
1900-
return sdpa_int8_math_kernel(query, key, value,
1901-
dropout_p, is_causal, attn_mask, scale,
1902-
q_zp, q_scale,
1903-
k_zp, k_scale,
1904-
v_zp, v_scale,
1905-
a_zp, a_scale,
1906-
o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2);
19071898
#endif // CPU_CAPABILITY_AVX512
19081899
}
19091900

0 commit comments

Comments
 (0)