metal lowbit kernels: optimized 2-bit, 3-bit and 4-bit shaders #1422

Merged: 1 commit, merged on Dec 18, 2024

11 changes: 7 additions & 4 deletions torchao/experimental/kernels/mps/metal.yaml
@@ -1,14 +1,17 @@
+- func: Vec4Type
+  file: common.metal
+
 - func: int1mm
-  file: divbit.metal
+  file: int1mm.metal
 
 - func: int2mm
-  file: divbit.metal
+  file: int2mm_opt.metal
 
 - func: int3mm
-  file: int3mm.metal
+  file: int3mm_opt.metal
 
 - func: int4mm
-  file: divbit.metal
+  file: int4mm_opt.metal
 
 - func: int5mm
   file: int5mm.metal

15 changes: 15 additions & 0 deletions torchao/experimental/kernels/mps/metal/common.metal
@@ -0,0 +1,15 @@
template <typename T> struct Vec4Type {};

template <> struct Vec4Type<float> {
using type = float4;
};

template <> struct Vec4Type<half> {
using type = half4;
};

#if __METAL_VERSION__ >= 310
template <> struct Vec4Type<bfloat> {
using type = bfloat4;
};
#endif
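
For readers unfamiliar with the pattern, a small plain-C++ analog of this trait (illustrative only, not part of the PR; Metal's built-in float4 is stood in for by a local struct) shows how a kernel templated on a scalar type T picks the matching 4-wide vector type:

#include <type_traits>

struct float4 { float x, y, z, w; };  // stand-in for Metal's built-in float4

// Primary template intentionally empty: only supported scalar types get a mapping.
template <typename T> struct Vec4Type {};
template <> struct Vec4Type<float> { using type = float4; };

// A kernel templated on T can then name the matching 4-wide type, e.g. to
// reinterpret a row of activations as packed 4-element vectors.
template <typename T> using vec4_t = typename Vec4Type<T>::type;

static_assert(std::is_same_v<vec4_t<float>, float4>, "float maps to float4");
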
109 changes: 0 additions & 109 deletions torchao/experimental/kernels/mps/metal/divbit.metal

This file was deleted.

torchao/experimental/kernels/mps/metal/int1mm.metal
@@ -2,10 +2,10 @@
 using namespace metal;
 
 /**
- * 3-Bit Quantized Linear.
+ * 1-Bit Quantized Linear.
  *
- * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
- * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
+ * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
+ * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8)
  * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
  * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
@@ -14,7 +14,7 @@ using namespace metal;
  * Dispatched threads: N x M x 1
  */
 template<typename T, unsigned groupSize>
-kernel void int3pack_mm(
+kernel void int1pack_mm(
 constant T * A [[buffer(0)]],
 constant uchar * B [[buffer(1)]],
 constant T * scales [[buffer(2)]],
@@ -28,7 +28,7 @@ kernel void int3pack_mm(
 const uint n = thread_index.x; // 0..N-1
 const uint32_t k_block = (K + groupSize - 1) / groupSize;
 constant T *A_ptr = A + m * K;
-constant uchar *B_ptr = B + n * 3 * K / 8;
+constant uchar *B_ptr = B + n * K / 8;
 
 float rc = 0.0;
 uint k = 0;
@@ -45,19 +45,16 @@
 const auto a_val6 = float(A_ptr[k + 6]);
 const auto a_val7 = float(A_ptr[k + 7]);
 
-uchar b0 = B_ptr[3 * (k / 8) + 0];
-uchar b1 = B_ptr[3 * (k / 8) + 1];
-uchar b2 = B_ptr[3 * (k / 8) + 2];
+uchar b0 = B_ptr[(k / 8)];
 
-uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
-uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
-uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
-uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);
-
-uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
-uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
-uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
-uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);
+uchar w_val0 = b0 & 0x01;
+uchar w_val1 = (b0 & 0x02) >> 1;
+uchar w_val2 = (b0 & 0x04) >> 2;
+uchar w_val3 = (b0 & 0x08) >> 3;
+uchar w_val4 = (b0 & 0x10) >> 4;
+uchar w_val5 = (b0 & 0x20) >> 5;
+uchar w_val6 = (b0 & 0x40) >> 6;
+uchar w_val7 = (b0 & 0x80) >> 7;
 
 rc += a_val0 * (scale * float(w_val0) + zero);
 rc += a_val1 * (scale * float(w_val1) + zero);
@@ -72,10 +69,10 @@
 outputData[m * N + n] = T(rc);
 }
 
-#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
+#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \
 template \
-[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
-kernel void int3pack_mm<DTYPE, GSIZE>( \
+[[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \
+kernel void int1pack_mm<DTYPE, GSIZE>( \
 constant DTYPE * A [[buffer(0)]], \
 constant uchar * B [[buffer(1)]], \
 constant DTYPE * scales [[buffer(2)]], \
@@ -84,17 +81,17 @@
 constant uint3 & sizes [[buffer(5)]], \
 uint2 thread_index [[thread_position_in_grid]])
 
-INSTANTIATE_INT3MM(float, 32);
-INSTANTIATE_INT3MM(half, 32);
-INSTANTIATE_INT3MM(float, 64);
-INSTANTIATE_INT3MM(half, 64);
-INSTANTIATE_INT3MM(float, 128);
-INSTANTIATE_INT3MM(half, 128);
-INSTANTIATE_INT3MM(float, 256);
-INSTANTIATE_INT3MM(half, 256);
+INSTANTIATE_INT1MM(float, 32);
+INSTANTIATE_INT1MM(half, 32);
+INSTANTIATE_INT1MM(float, 64);
+INSTANTIATE_INT1MM(half, 64);
+INSTANTIATE_INT1MM(float, 128);
+INSTANTIATE_INT1MM(half, 128);
+INSTANTIATE_INT1MM(float, 256);
+INSTANTIATE_INT1MM(half, 256);
 #if __METAL_VERSION__ >= 310
-INSTANTIATE_INT3MM(bfloat, 32);
-INSTANTIATE_INT3MM(bfloat, 64);
-INSTANTIATE_INT3MM(bfloat, 128);
-INSTANTIATE_INT3MM(bfloat, 256);
+INSTANTIATE_INT1MM(bfloat, 32);
+INSTANTIATE_INT1MM(bfloat, 64);
+INSTANTIATE_INT1MM(bfloat, 128);
+INSTANTIATE_INT1MM(bfloat, 256);
 #endif
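
For orientation, here is a hedged host-side C++ sketch of a weight-packing routine consistent with the unpacking order above; this is illustrative only and not the repository's actual packing code. Bit i of byte k / 8 in row n holds the 1-bit weight for column k + i, which is exactly what w_val_i = (b0 >> i) & 1 reads back in the shader.

#include <cstdint>
#include <vector>

// Pack an N x K matrix of 1-bit weights (values 0 or 1), row-major along K,
// eight weights per byte. Assumes K is a multiple of 8, as the kernel does.
std::vector<uint8_t> pack_1bit(const std::vector<uint8_t>& w, int N, int K) {
  std::vector<uint8_t> packed(static_cast<size_t>(N) * (K / 8), 0);
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      // Weight for column k of row n lands in bit (k % 8) of byte k / 8.
      packed[n * (K / 8) + k / 8] |= static_cast<uint8_t>(w[n * K + k] & 1) << (k % 8);
    }
  }
  return packed;
}
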
138 changes: 138 additions & 0 deletions torchao/experimental/kernels/mps/metal/int2mm_opt.metal
@@ -0,0 +1,138 @@
#include <metal_simdgroup>
#include <metal_stdlib>
using namespace metal;

/*
This code takes heavy inspiration from MLX:
https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h
Specifically:
- Multiplying activation by inverse scaling factor to reduce compute
boundedness
- Handling zero point by accumulating act in separate sum term. Needed with
optimization done above. MLX MIT License:
https://github.com/ml-explore/mlx/blob/main/LICENSE
*/

/*
@brief This shader implements 2-bit matrix-vector multiplication where A
matrix is fp16, bfloat or float and B matrix is a 2-bit groupwise-quantized weight
matrix.
@param [in] A is activation matrix of size M x K.
@param [in] B is weight matrix of size M x K. Each byte contains 4 2-bit
values, along K dim, packed together.
@param [in] scales_ptr is scales ptr corresponding each
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
channels.
@param [in] zeros_ptr is zero points corresponding each
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
channels.
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output
@param [out] output_data is output matrix of size M x N.
@param [in] sizes array contains values of M, K and N.
@param [in] thread_index is global thread id.
@param [in] tid_in_simdgruop is thread id in simdgroup. e.g. in simdgroup of size 32 it can be in [0-31].
*/
template <typename T, unsigned group_size>
kernel void int2pack_mm(constant T *A [[buffer(0)]],
constant uchar *B [[buffer(1)]],
constant T *scales_ptr [[buffer(2)]],
constant T *zeros_ptr [[buffer(3)]],
device T *output_data [[buffer(4)]],
constant uint3 &sizes [[buffer(5)]], // M, K, N
uint3 thread_index [[thread_position_in_grid]],
uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
constexpr uint threads_per_channel = 32;
constexpr uint ks_per_thread = 4;
constexpr uint k_pack_factor = 4;
const uint K = sizes.y;
const uint N = sizes.z;
uint n = thread_index.x; // 0..N/4-1
uint m = thread_index.z; // 0..M
n = n / threads_per_channel;
n = n * 4;
// This is starting k for each thread. In the example above, for thread 1 this
// value will be 4.
uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
constexpr int k_jump = threads_per_channel * ks_per_thread;

using vecT = typename Vec4Type<T>::type;
constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
constant uchar *B_ptr = B + ((n * K) / k_pack_factor);

thread float4 result = float4(0.0);
// We multipy group of 4 channels with these scales.
// Because corresponding values from weight matrix are effectively left
// shifted. This is to avoid doing right shift on those values which ends up
// affecting performance. This is the trick applied in MLX kernels.
float4 act_div_scales = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};

for (; k < K; k += k_jump) {
// Find specific group to which channels handled by this thread
// belong.
uint k_block_index = k / group_size;
uint scales_group_offset = (k_block_index * N + n);

vecT scales =
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
// Adding zero point results in 10% perf penalty.
vecT zeros =
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
float4 zeros_float = float4(zeros);

float4 a_val = float4(A_ptr[k / 4]);
// We are gonna skip right-shifts of the weights and hence divide by corresponding factor.
float4 a_vec = a_val * act_div_scales;
float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];

float4x4 b_mat;
ushort b_val0 = (B_ptr + (k + 0 * K) / k_pack_factor)[0];
ushort b_val1 = (B_ptr + (k + 1 * K) / k_pack_factor)[0];
ushort b_val2 = (B_ptr + (k + 2 * K) / k_pack_factor)[0];
ushort b_val3 = (B_ptr + (k + 3 * K) / k_pack_factor)[0];
b_mat[0] = scales[0] * float4(float(b_val0 & 0x03), float(b_val0 & 0x0c),
float(b_val0 & 0x30), float(b_val0 & 0xc0));
b_mat[1] = scales[1] * float4(float(b_val1 & 0x03), float(b_val1 & 0x0c),
float(b_val1 & 0x30), float(b_val1 & 0xc0));
b_mat[2] = scales[2] * float4(float(b_val2 & 0x03), float(b_val2 & 0x0c),
float(b_val2 & 0x30), float(b_val2 & 0xc0));
b_mat[3] = scales[3] * float4(float(b_val3 & 0x03), float(b_val3 & 0x0c),
float(b_val3 & 0x30), float(b_val3 & 0xc0));

result += a_vec * b_mat;
result += a_val_sum * zeros_float;
}
result += simd_shuffle_down(result, 1);
result += simd_shuffle_down(result, 2);
result += simd_shuffle_down(result, 4);
result += simd_shuffle_down(result, 8);
result += simd_shuffle_down(result, 16);
if (tid_in_simdgroup % threads_per_channel == 0) {
reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
}
}

#define INSTANTIATE_INT2MM(DTYPE, GSIZE) \
template [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \
int2pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \
constant DTYPE * scales_ptr [[buffer(2)]], \
constant DTYPE * zeros_ptr [[buffer(3)]], \
device DTYPE * output_data [[buffer(4)]], \
constant uint3 & sizes [[buffer(5)]], \
uint3 thread_index [[thread_position_in_grid]], \
uint tid_in_simdgroup [[thread_index_in_simdgroup]])

INSTANTIATE_INT2MM(float, 32);
INSTANTIATE_INT2MM(half, 32);
INSTANTIATE_INT2MM(float, 64);
INSTANTIATE_INT2MM(half, 64);
INSTANTIATE_INT2MM(float, 128);
INSTANTIATE_INT2MM(half, 128);
INSTANTIATE_INT2MM(float, 256);
INSTANTIATE_INT2MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT2MM(bfloat, 32);
INSTANTIATE_INT2MM(bfloat, 64);
INSTANTIATE_INT2MM(bfloat, 128);
INSTANTIATE_INT2MM(bfloat, 256);
#endif
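
To make the two MLX-inspired tricks from the header comment concrete, here is a scalar C++ sketch (illustrative only, not part of the PR) of one 4-wide step of the inner loop for a single output channel. Masking a packed byte without shifting yields w_j * 4^j, so pre-scaling the activation by 1 / 4^j (the act_div_scales vector) recovers a_j * w_j without any right shifts, and the zero point is applied once through the activation sum.

#include <cstdint>

// One packed byte b holds four 2-bit weights: b = w0 | (w1 << 2) | (w2 << 4) | (w3 << 6).
// Computes sum_j a[j] * (scale * w_j + zero) for one output channel.
float dot4_2bit(const float a[4], uint8_t b, float scale, float zero) {
  const float inv_shift[4] = {1.0f, 1.0f / 4, 1.0f / 16, 1.0f / 64};  // act_div_scales
  const uint8_t mask[4] = {0x03, 0x0c, 0x30, 0xc0};
  float acc = 0.0f;
  float a_sum = 0.0f;
  for (int j = 0; j < 4; ++j) {
    acc += (a[j] * inv_shift[j]) * scale * float(b & mask[j]);  // == a[j] * scale * w_j
    a_sum += a[j];
  }
  return acc + a_sum * zero;  // == sum_j a[j] * (scale * w_j + zero)
}

In the shader the same arithmetic runs four output channels at once through the float4x4 b_mat, and each thread's partial sums are then combined across the 32-thread simdgroup with simd_shuffle_down before the first thread of the group writes one 4-wide slice of the output row.
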