@@ -725,9 +725,12 @@ inline void copy_value_with_pad(
 }
 
 // UINT8 - one parallel loop with u8u8s32 GEMM
-template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
+template <typename scalar_t, typename mask_t,
+          int64_t q_split_size, int64_t kv_split_size,
+          bool use_one_parallel_loop,
+          typename std::enable_if_t<use_one_parallel_loop, int> = 0>
 inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
-sdpa_int8_kernel_one_loop_impl(
+sdpa_int8_fused_kernel_impl(
     const at::Tensor& output,
     const at::Tensor& q,
     const at::Tensor& k,
@@ -1150,9 +1153,12 @@ sdpa_int8_kernel_one_loop_impl(
 }
 
 // UINT8 - several parallel loops with u8u8s32 GEMM
-template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
+template <typename scalar_t, typename mask_t,
+          int64_t q_split_size, int64_t kv_split_size,
+          bool use_one_parallel_loop,
+          typename std::enable_if_t<!use_one_parallel_loop, int> = 0>
 inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
-sdpa_int8_kernel_several_loops_impl(
+sdpa_int8_fused_kernel_impl(
     const at::Tensor& output,
     const at::Tensor& q,
     const at::Tensor& k,
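Both variants now share the name `sdpa_int8_fused_kernel_impl`; the new `bool use_one_parallel_loop` template parameter together with the complementary `std::enable_if_t` constraints decides which overload participates in overload resolution. A minimal standalone sketch of that selection idiom, with toy names (`impl`, `UseOneLoop`) that are not taken from this PR:

```cpp
#include <iostream>
#include <type_traits>

// Enabled only when UseOneLoop is true.
template <typename T, bool UseOneLoop,
          typename std::enable_if_t<UseOneLoop, int> = 0>
void impl(const T& x) {
  std::cout << "one parallel loop path: " << x << "\n";
}

// Enabled only when UseOneLoop is false (complementary constraint).
template <typename T, bool UseOneLoop,
          typename std::enable_if_t<!UseOneLoop, int> = 0>
void impl(const T& x) {
  std::cout << "several parallel loops path: " << x << "\n";
}

int main() {
  impl<int, true>(1);   // instantiates the first overload
  impl<int, false>(2);  // instantiates the second overload
}
```

Because exactly one of the two constraints can hold for a given flag, the two templates never collide even though they share a name and parameter list.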
@@ -1615,6 +1621,53 @@ sdpa_int8_kernel_several_loops_impl(
   at::native::cpublas::brgemm_release();
 }
 
+
+template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
+inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
+sdpa_int8_fused_kernel_impl(
+    bool use_one_parallel_loop,
+    const at::Tensor& output,
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    std::optional<at::Tensor> attn_mask,
+    double scale,
+    int32_t q_zp,
+    float q_scale,
+    int32_t k_zp,
+    float k_scale,
+    int32_t v_zp,
+    float v_scale,
+    int32_t a_zp,
+    float a_scale,
+    int32_t o_zp,
+    float o_scale) {
+  if (use_one_parallel_loop) {
+    sdpa_int8_fused_kernel_impl<scalar_t, mask_t, q_split_size, kv_split_size,
+        /* use_one_parallel_loop */ true>(
+      output, query, key, value,
+      dropout_p, is_causal, attn_mask, scale,
+      q_zp, q_scale,
+      k_zp, k_scale,
+      v_zp, v_scale,
+      a_zp, a_scale,
+      o_zp, o_scale);
+  } else {
+    sdpa_int8_fused_kernel_impl<scalar_t, mask_t, q_split_size, kv_split_size,
+        /* use_one_parallel_loop */ false>(
+      output, query, key, value,
+      dropout_p, is_causal, attn_mask, scale,
+      q_zp, q_scale,
+      k_zp, k_scale,
+      v_zp, v_scale,
+      a_zp, a_scale,
+      o_zp, o_scale);
+  }
+}
+
+
 #define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH( \
     TYPE, \
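The wrapper added in this hunk is the glue between the runtime heuristic and the compile-time flag: it branches once on the `bool` argument and forwards to the matching specialization, so the caller no longer needs two differently named entry points. A stripped-down sketch of the same forwarding pattern, again with hypothetical names rather than the actual kernel signature:

```cpp
#include <type_traits>

// Compile-time selected implementations, reduced to a toy payload.
template <bool UseOneLoop, typename std::enable_if_t<UseOneLoop, int> = 0>
int run(int x) { return x + 1; }   // "one parallel loop" flavor

template <bool UseOneLoop, typename std::enable_if_t<!UseOneLoop, int> = 0>
int run(int x) { return x + 2; }   // "several parallel loops" flavor

// Runtime wrapper: lifts a runtime bool into the template argument, the way
// the new sdpa_int8_fused_kernel_impl overload taking `bool use_one_parallel_loop`
// does for the real kernels.
int run(bool use_one_loop, int x) {
  if (use_one_loop) {
    return run</*UseOneLoop=*/true>(x);
  } else {
    return run</*UseOneLoop=*/false>(x);
  }
}

int main() {
  // Each call instantiates only the overload selected by the flag.
  return (run(/*use_one_loop=*/true, 1) == 2 && run(false, 1) == 3) ? 0 : 1;
}
```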
@@ -1661,77 +1714,50 @@ void sdpa_int8_fused_kernel(
     q_split_size = 64;
   }
   // Heuristic to decide whether to use one parallel loop or not
+  // true: one parallel loop for sum+packing+core
+  // false: three parallel loops for sum, packing, core
   uint32_t l2_cache_size = at::cpu::L2_cache_size();
   int64_t num_thread = at::get_num_threads();
   int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread;
   bool use_one_parallel_loop = (batchSize * num_head > num_thread) &&
       (attn_size > 1.5 * l2_cache_size);
-  if (use_one_parallel_loop) {
-    if (!attn_mask.has_value()) {
-      if (q_split_size == 256) {
-        sdpa_int8_kernel_one_loop_impl<unsigned char, float, 256, 64>(
-          output, query, key, value,
-          dropout_p, is_causal, attn_mask, scale,
-          q_zp, q_scale,
-          k_zp, k_scale,
-          v_zp, v_scale,
-          a_zp, a_scale,
-          o_zp, o_scale);
-      } else if (q_split_size == 64) {
-        sdpa_int8_kernel_one_loop_impl<unsigned char, float, 64, 64>(
-          output, query, key, value,
-          dropout_p, is_causal, attn_mask, scale,
-          q_zp, q_scale,
-          k_zp, k_scale,
-          v_zp, v_scale,
-          a_zp, a_scale,
-          o_zp, o_scale);
-      } else {
-        sdpa_int8_kernel_one_loop_impl<unsigned char, float, 32, 64>(
-          output, query, key, value,
-          dropout_p, is_causal, attn_mask, scale,
-          q_zp, q_scale,
-          k_zp, k_scale,
-          v_zp, v_scale,
-          a_zp, a_scale,
-          o_zp, o_scale);
-      }
+  if (!attn_mask.has_value()) {
+    if (q_split_size == 256) {
+      sdpa_int8_fused_kernel_impl<unsigned char, float, 256, 64>(
+        use_one_parallel_loop,
+        output, query, key, value,
+        dropout_p, is_causal, attn_mask, scale,
+        q_zp, q_scale,
+        k_zp, k_scale,
+        v_zp, v_scale,
+        a_zp, a_scale,
+        o_zp, o_scale);
+    } else if (q_split_size == 64) {
+      sdpa_int8_fused_kernel_impl<unsigned char, float, 64, 64>(
+        use_one_parallel_loop,
+        output, query, key, value,
+        dropout_p, is_causal, attn_mask, scale,
+        q_zp, q_scale,
+        k_zp, k_scale,
+        v_zp, v_scale,
+        a_zp, a_scale,
+        o_zp, o_scale);
     } else {
-      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() {
-        if (q_split_size == 256) {
-          sdpa_int8_kernel_one_loop_impl<unsigned char, mask_t, 256, 64>(
-            output, query, key, value,
-            dropout_p, is_causal, attn_mask, scale,
-            q_zp, q_scale,
-            k_zp, k_scale,
-            v_zp, v_scale,
-            a_zp, a_scale,
-            o_zp, o_scale);
-        } else if (q_split_size == 64) {
-          sdpa_int8_kernel_one_loop_impl<unsigned char, mask_t, 64, 64>(
-            output, query, key, value,
-            dropout_p, is_causal, attn_mask, scale,
-            q_zp, q_scale,
-            k_zp, k_scale,
-            v_zp, v_scale,
-            a_zp, a_scale,
-            o_zp, o_scale);
-        } else {
-          sdpa_int8_kernel_one_loop_impl<unsigned char, mask_t, 32, 64>(
-            output, query, key, value,
-            dropout_p, is_causal, attn_mask, scale,
-            q_zp, q_scale,
-            k_zp, k_scale,
-            v_zp, v_scale,
-            a_zp, a_scale,
-            o_zp, o_scale);
-        }
-      });
+      sdpa_int8_fused_kernel_impl<unsigned char, float, 32, 64>(
+        use_one_parallel_loop,
+        output, query, key, value,
+        dropout_p, is_causal, attn_mask, scale,
+        q_zp, q_scale,
+        k_zp, k_scale,
+        v_zp, v_scale,
+        a_zp, a_scale,
+        o_zp, o_scale);
     }
   } else {
-    if (!attn_mask.has_value()) {
+    AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() {
       if (q_split_size == 256) {
-        sdpa_int8_kernel_several_loops_impl<unsigned char, float, 256, 64>(
+        sdpa_int8_fused_kernel_impl<unsigned char, mask_t, 256, 64>(
+          use_one_parallel_loop,
          output, query, key, value,
          dropout_p, is_causal, attn_mask, scale,
          q_zp, q_scale,
@@ -1740,7 +1766,8 @@ void sdpa_int8_fused_kernel(
          a_zp, a_scale,
          o_zp, o_scale);
       } else if (q_split_size == 64) {
-        sdpa_int8_kernel_several_loops_impl<unsigned char, float, 64, 64>(
+        sdpa_int8_fused_kernel_impl<unsigned char, mask_t, 64, 64>(
+          use_one_parallel_loop,
          output, query, key, value,
          dropout_p, is_causal, attn_mask, scale,
          q_zp, q_scale,
@@ -1749,7 +1776,8 @@ void sdpa_int8_fused_kernel(
          a_zp, a_scale,
          o_zp, o_scale);
       } else {
-        sdpa_int8_kernel_several_loops_impl<unsigned char, float, 32, 64>(
+        sdpa_int8_fused_kernel_impl<unsigned char, mask_t, 32, 64>(
+          use_one_parallel_loop,
          output, query, key, value,
          dropout_p, is_causal, attn_mask, scale,
          q_zp, q_scale,
@@ -1758,38 +1786,7 @@ void sdpa_int8_fused_kernel(
          a_zp, a_scale,
          o_zp, o_scale);
       }
-    } else {
-      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() {
-        if (q_split_size == 256) {
-          sdpa_int8_kernel_several_loops_impl<unsigned char, mask_t, 256, 64>(
-            output, query, key, value,
-            dropout_p, is_causal, attn_mask, scale,
-            q_zp, q_scale,
-            k_zp, k_scale,
-            v_zp, v_scale,
-            a_zp, a_scale,
-            o_zp, o_scale);
-        } else if (q_split_size == 64) {
-          sdpa_int8_kernel_several_loops_impl<unsigned char, mask_t, 64, 64>(
-            output, query, key, value,
-            dropout_p, is_causal, attn_mask, scale,
-            q_zp, q_scale,
-            k_zp, k_scale,
-            v_zp, v_scale,
-            a_zp, a_scale,
-            o_zp, o_scale);
-        } else {
-          sdpa_int8_kernel_several_loops_impl<unsigned char, mask_t, 32, 64>(
-            output, query, key, value,
-            dropout_p, is_causal, attn_mask, scale,
-            q_zp, q_scale,
-            k_zp, k_scale,
-            v_zp, v_scale,
-            a_zp, a_scale,
-            o_zp, o_scale);
-        }
-      });
-    }
+    });
   }
 }
 #endif // CPU_CAPABILITY_AVX512
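For reference, the dispatcher's heuristic (annotated with the new inline comment earlier in this diff) picks the single-loop path when there are more `batch * head` work items than threads and the combined int32 attention buffers would exceed roughly 1.5x the L2 cache size. Below is a standalone sketch of that arithmetic with made-up sizes; in the kernel the inputs come from the tensor shapes, `at::cpu::L2_cache_size()`, and `at::get_num_threads()`:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values only; the real ones are queried at runtime.
  const int64_t batch_size   = 8;
  const int64_t num_head     = 16;
  const int64_t q_split_size = 64;
  const int64_t kv_seq_len   = 2048;
  const int64_t num_thread   = 32;
  const uint32_t l2_cache_size = 2u * 1024u * 1024u;  // assume 2 MiB per core

  // Same formula as in sdpa_int8_fused_kernel: total footprint of the
  // int32 attention blocks kept live across all threads.
  const int64_t attn_size =
      q_split_size * kv_seq_len * static_cast<int64_t>(sizeof(int32_t)) * num_thread;

  const bool use_one_parallel_loop =
      (batch_size * num_head > num_thread) && (attn_size > 1.5 * l2_cache_size);

  std::printf("attn_size = %lld bytes, one-loop path: %s\n",
              static_cast<long long>(attn_size),
              use_one_parallel_loop ? "yes" : "no");
}
```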
@@ -1888,22 +1885,16 @@ at::Tensor _scaled_dot_product_int8_cpu(
        o_zp, o_scale);
    return output.transpose(1, 2);
  } else {
+#endif // CPU_CAPABILITY_AVX512
    return sdpa_int8_math_kernel(query, key, value,
        dropout_p, is_causal, attn_mask, scale,
        q_zp, q_scale,
        k_zp, k_scale,
        v_zp, v_scale,
        a_zp, a_scale,
        o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2);
+#ifdef CPU_CAPABILITY_AVX512
  }
-#else
-  return sdpa_int8_math_kernel(query, key, value,
-      dropout_p, is_causal, attn_mask, scale,
-      q_zp, q_scale,
-      k_zp, k_scale,
-      v_zp, v_scale,
-      a_zp, a_scale,
-      o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2);
 #endif // CPU_CAPABILITY_AVX512
 }
 
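The last hunk removes the previously duplicated non-AVX512 fallback: instead of repeating the `sdpa_int8_math_kernel` call under `#else`, the `#endif`/`#ifdef` pair is moved so the same `else` body is compiled into both builds. A toy sketch of that preprocessor layout, using a hypothetical `FAST_PATH` macro in place of `CPU_CAPABILITY_AVX512`:

```cpp
#include <cstdio>

int compute(bool can_use_fast_path) {
  (void)can_use_fast_path;  // only examined when FAST_PATH is compiled in
#ifdef FAST_PATH
  if (can_use_fast_path) {
    return 1;  // stands in for the vectorized kernel
  } else {
#endif
    // Shared fallback: a single call site serves both builds.
    return 0;
#ifdef FAST_PATH
  }
#endif
}

int main() {
  std::printf("%d\n", compute(true));
}
```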