diff --git a/torchao/experimental/kernels/mps/metal.yaml b/torchao/experimental/kernels/mps/metal.yaml
index 93b3a23c56..eb837432c7 100644
--- a/torchao/experimental/kernels/mps/metal.yaml
+++ b/torchao/experimental/kernels/mps/metal.yaml
@@ -1,14 +1,17 @@
+- func: Vec4Type
+  file: common.metal
+
 - func: int1mm
-  file: divbit.metal
+  file: int1mm.metal
 
 - func: int2mm
-  file: divbit.metal
+  file: int2mm_opt.metal
 
 - func: int3mm
-  file: int3mm.metal
+  file: int3mm_opt.metal
 
 - func: int4mm
-  file: divbit.metal
+  file: int4mm_opt.metal
 
 - func: int5mm
   file: int5mm.metal
diff --git a/torchao/experimental/kernels/mps/metal/common.metal b/torchao/experimental/kernels/mps/metal/common.metal
new file mode 100644
index 0000000000..69fff8133e
--- /dev/null
+++ b/torchao/experimental/kernels/mps/metal/common.metal
@@ -0,0 +1,15 @@
+template <typename T> struct Vec4Type {};
+
+template <> struct Vec4Type<float> {
+  using type = float4;
+};
+
+template <> struct Vec4Type<half> {
+  using type = half4;
+};
+
+#if __METAL_VERSION__ >= 310
+template <> struct Vec4Type<bfloat> {
+  using type = bfloat4;
+};
+#endif
diff --git a/torchao/experimental/kernels/mps/metal/divbit.metal b/torchao/experimental/kernels/mps/metal/divbit.metal
deleted file mode 100644
index 5c8b146643..0000000000
--- a/torchao/experimental/kernels/mps/metal/divbit.metal
+++ /dev/null
@@ -1,109 +0,0 @@
-#include <metal_stdlib>
-using namespace metal;
-
-/**
- * LowBit Quantized Linear for bitwidths that are divisors of 8. Hence the name.
- *
- * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
- * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
- * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
- * @param[zeros] 2D tensor containg the zero points for each group.
Expected shape is #groups x N - * @param[outputData] M x N output tensor of floating point dtype (same as input) - * @param[sizes] The sizes involved in the order: M, K, N - * - * Dispatched threads: N x M x 1 - */ -template -kernel void divbit_mm( - constant T * A [[buffer(0)]], - constant uchar * B [[buffer(1)]], - constant T * scales [[buffer(2)]], - constant T * zeros [[buffer(3)]], - device T * outputData [[buffer(4)]], - constant uint3 & sizes [[buffer(5)]], // M, K, N - uint2 thread_index [[thread_position_in_grid]]) { - const uint K = sizes.y; - const uint N = sizes.z; - const uint m = thread_index.y; // 0..M-1 - const uint n = thread_index.x; // 0..N-1 - const uint32_t k_block = (K + groupSize - 1) / groupSize; - constant T *A_ptr = A + m * K; - constant uchar *B_ptr = B; - - constexpr uint8_t zero_shift = 1 << (nbit - 1); - constexpr uint8_t values_per_byte = 8 / nbit; - constexpr uint8_t minimask = (1 << nbit) - 1; - - float rc = 0.0; - uint k = 0; - for (uint32_t kb = 0; kb < k_block ; kb ++) { - const float scale = float(scales[kb * N + n]); - const float zero = float(zeros[kb * N + n]); - for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { - const auto a_val = float(A_ptr[k]); - uint8_t b_val = B_ptr[(n * K + k) / values_per_byte]; - uint8_t shift = nbit * (k % values_per_byte); - uint8_t mask = minimask << shift; - b_val = (b_val & mask) >> shift; - rc += a_val * (scale * float(b_val) + zero); - } - } - outputData[m * N + n] = T(rc); -} - -#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \ -template \ -[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \ -kernel void divbit_mm( \ - constant DTYPE * A [[buffer(0)]], \ - constant uchar * B [[buffer(1)]], \ - constant DTYPE * scales [[buffer(2)]], \ - constant DTYPE * zeros [[buffer(3)]], \ - device DTYPE * outputData [[buffer(4)]], \ - constant uint3 & sizes [[buffer(5)]], \ - uint2 thread_index [[thread_position_in_grid]]) - -INSTANTIATE_DIVBIT_MM(1, float, 32); -INSTANTIATE_DIVBIT_MM(1, half, 32); -INSTANTIATE_DIVBIT_MM(1, float, 64); -INSTANTIATE_DIVBIT_MM(1, half, 64); -INSTANTIATE_DIVBIT_MM(1, float, 128); -INSTANTIATE_DIVBIT_MM(1, half, 128); -INSTANTIATE_DIVBIT_MM(1, float, 256); -INSTANTIATE_DIVBIT_MM(1, half, 256); -#if __METAL_VERSION__ >= 310 -INSTANTIATE_DIVBIT_MM(1, bfloat, 32); -INSTANTIATE_DIVBIT_MM(1, bfloat, 64); -INSTANTIATE_DIVBIT_MM(1, bfloat, 128); -INSTANTIATE_DIVBIT_MM(1, bfloat, 256); -#endif - -INSTANTIATE_DIVBIT_MM(2, float, 32); -INSTANTIATE_DIVBIT_MM(2, half, 32); -INSTANTIATE_DIVBIT_MM(2, float, 64); -INSTANTIATE_DIVBIT_MM(2, half, 64); -INSTANTIATE_DIVBIT_MM(2, float, 128); -INSTANTIATE_DIVBIT_MM(2, half, 128); -INSTANTIATE_DIVBIT_MM(2, float, 256); -INSTANTIATE_DIVBIT_MM(2, half, 256); -#if __METAL_VERSION__ >= 310 -INSTANTIATE_DIVBIT_MM(2, bfloat, 32); -INSTANTIATE_DIVBIT_MM(2, bfloat, 64); -INSTANTIATE_DIVBIT_MM(2, bfloat, 128); -INSTANTIATE_DIVBIT_MM(2, bfloat, 256); -#endif - -INSTANTIATE_DIVBIT_MM(4, float, 32); -INSTANTIATE_DIVBIT_MM(4, half, 32); -INSTANTIATE_DIVBIT_MM(4, float, 64); -INSTANTIATE_DIVBIT_MM(4, half, 64); -INSTANTIATE_DIVBIT_MM(4, float, 128); -INSTANTIATE_DIVBIT_MM(4, half, 128); -INSTANTIATE_DIVBIT_MM(4, float, 256); -INSTANTIATE_DIVBIT_MM(4, half, 256); -#if __METAL_VERSION__ >= 310 -INSTANTIATE_DIVBIT_MM(4, bfloat, 32); -INSTANTIATE_DIVBIT_MM(4, bfloat, 64); -INSTANTIATE_DIVBIT_MM(4, bfloat, 128); -INSTANTIATE_DIVBIT_MM(4, bfloat, 256); -#endif diff --git a/torchao/experimental/kernels/mps/metal/int3mm.metal 
b/torchao/experimental/kernels/mps/metal/int1mm.metal similarity index 68% rename from torchao/experimental/kernels/mps/metal/int3mm.metal rename to torchao/experimental/kernels/mps/metal/int1mm.metal index 4a44345b83..15b7f19138 100644 --- a/torchao/experimental/kernels/mps/metal/int3mm.metal +++ b/torchao/experimental/kernels/mps/metal/int1mm.metal @@ -2,10 +2,10 @@ using namespace metal; /** - * 3-Bit Quantized Linear. + * 1-Bit Quantized Linear. * - * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) - * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8) + * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8) * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N * @param[outputData] M x N output tensor of floating point dtype (same as input) @@ -14,7 +14,7 @@ using namespace metal; * Dispatched threads: N x M x 1 */ template -kernel void int3pack_mm( +kernel void int1pack_mm( constant T * A [[buffer(0)]], constant uchar * B [[buffer(1)]], constant T * scales [[buffer(2)]], @@ -28,7 +28,7 @@ kernel void int3pack_mm( const uint n = thread_index.x; // 0..N-1 const uint32_t k_block = (K + groupSize - 1) / groupSize; constant T *A_ptr = A + m * K; - constant uchar *B_ptr = B + n * 3 * K / 8; + constant uchar *B_ptr = B + n * K / 8; float rc = 0.0; uint k = 0; @@ -45,19 +45,16 @@ kernel void int3pack_mm( const auto a_val6 = float(A_ptr[k + 6]); const auto a_val7 = float(A_ptr[k + 7]); - uchar b0 = B_ptr[3 * (k / 8) + 0]; - uchar b1 = B_ptr[3 * (k / 8) + 1]; - uchar b2 = B_ptr[3 * (k / 8) + 2]; + uchar b0 = B_ptr[(k / 8)]; - uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3); - uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2); - uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4); - uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6); - - uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3); - uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2); - uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4); - uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6); + uchar w_val0 = b0 & 0x01; + uchar w_val1 = (b0 & 0x02) >> 1; + uchar w_val2 = (b0 & 0x04) >> 2; + uchar w_val3 = (b0 & 0x08) >> 3; + uchar w_val4 = (b0 & 0x10) >> 4; + uchar w_val5 = (b0 & 0x20) >> 5; + uchar w_val6 = (b0 & 0x40) >> 6; + uchar w_val7 = (b0 & 0x80) >> 7; rc += a_val0 * (scale * float(w_val0) + zero); rc += a_val1 * (scale * float(w_val1) + zero); @@ -72,10 +69,10 @@ kernel void int3pack_mm( outputData[m * N + n] = T(rc); } -#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \ +#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \ template \ -[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \ -kernel void int3pack_mm( \ +[[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int1pack_mm( \ constant DTYPE * A [[buffer(0)]], \ constant uchar * B [[buffer(1)]], \ constant DTYPE * scales [[buffer(2)]], \ @@ -84,17 +81,17 @@ kernel void int3pack_mm( \ constant uint3 & sizes [[buffer(5)]], \ uint2 thread_index [[thread_position_in_grid]]) -INSTANTIATE_INT3MM(float, 32); -INSTANTIATE_INT3MM(half, 32); -INSTANTIATE_INT3MM(float, 64); -INSTANTIATE_INT3MM(half, 64); -INSTANTIATE_INT3MM(float, 128); -INSTANTIATE_INT3MM(half, 128); -INSTANTIATE_INT3MM(float, 256); -INSTANTIATE_INT3MM(half, 256); +INSTANTIATE_INT1MM(float, 32); +INSTANTIATE_INT1MM(half, 
32);
+INSTANTIATE_INT1MM(float, 64);
+INSTANTIATE_INT1MM(half, 64);
+INSTANTIATE_INT1MM(float, 128);
+INSTANTIATE_INT1MM(half, 128);
+INSTANTIATE_INT1MM(float, 256);
+INSTANTIATE_INT1MM(half, 256);
 #if __METAL_VERSION__ >= 310
-INSTANTIATE_INT3MM(bfloat, 32);
-INSTANTIATE_INT3MM(bfloat, 64);
-INSTANTIATE_INT3MM(bfloat, 128);
-INSTANTIATE_INT3MM(bfloat, 256);
+INSTANTIATE_INT1MM(bfloat, 32);
+INSTANTIATE_INT1MM(bfloat, 64);
+INSTANTIATE_INT1MM(bfloat, 128);
+INSTANTIATE_INT1MM(bfloat, 256);
 #endif
diff --git a/torchao/experimental/kernels/mps/metal/int2mm_opt.metal b/torchao/experimental/kernels/mps/metal/int2mm_opt.metal
new file mode 100644
index 0000000000..465b74ed2b
--- /dev/null
+++ b/torchao/experimental/kernels/mps/metal/int2mm_opt.metal
@@ -0,0 +1,138 @@
+#include <metal_simdgroup>
+#include <metal_stdlib>
+using namespace metal;
+
+/*
+  This code takes heavy inspiration from MLX:
+  https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h
+  Specifically:
+   - Multiplying activation by inverse scaling factor to reduce compute
+  boundedness
+   - Handling zero point by accumulating act in separate sum term. Needed with
+  optimization done above. MLX MIT License:
+  https://github.com/ml-explore/mlx/blob/main/LICENSE
+*/
+
+/*
+  @brief This shader implements 2-bit matrix-vector multiplication where A
+  matrix is fp16, bfloat or float and B matrix is a 2-bit groupwise-quantized weight
+  matrix.
+  @param [in] A is activation matrix of size M x K.
+  @param [in] B is weight matrix of size N x K. Each byte contains 4 2-bit
+  values, along K dim, packed together.
+  @param [in] scales_ptr is scales ptr corresponding each
+  output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+  channels.
+  @param [in] zeros_ptr is zero points corresponding each
+  output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+  channels.
+  @param [out] output_data is output matrix of size M x N.
+  @param [in] sizes array contains values of M, K and N.
+  @param [in] thread_index is global thread id.
+  @param [in] tid_in_simdgroup is thread id in simdgroup. e.g. in simdgroup of size 32 it can be in [0-31].
+*/
+template <typename T, unsigned group_size>
+kernel void int2pack_mm(constant T *A [[buffer(0)]],
+                        constant uchar *B [[buffer(1)]],
+                        constant T *scales_ptr [[buffer(2)]],
+                        constant T *zeros_ptr [[buffer(3)]],
+                        device T *output_data [[buffer(4)]],
+                        constant uint3 &sizes [[buffer(5)]], // M, K, N
+                        uint3 thread_index [[thread_position_in_grid]],
+                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+  constexpr uint threads_per_channel = 32;
+  constexpr uint ks_per_thread = 4;
+  constexpr uint k_pack_factor = 4;
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  uint n = thread_index.x; // 0..N/4-1
+  uint m = thread_index.z; // 0..M
+  n = n / threads_per_channel;
+  n = n * 4;
+  // This is starting k for each thread. In the example above, for thread 1 this
+  // value will be 4.
+  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+  constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+  using vecT = typename Vec4Type<T>::type;
+  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+  thread float4 result = float4(0.0);
+  // We multiply group of 4 channels with these scales.
+  // Because corresponding values from weight matrix are effectively left
+  // shifted. This is to avoid doing right shift on those values which ends up
+  // affecting performance. This is the trick applied in MLX kernels.
+  float4 act_div_scales = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};
+
+  for (; k < K; k += k_jump) {
+    // Find specific group to which channels handled by this thread
+    // belong.
+    uint k_block_index = k / group_size;
+    uint scales_group_offset = (k_block_index * N + n);
+
+    vecT scales =
+        (reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
+    // Adding zero point results in 10% perf penalty.
+    vecT zeros =
+        (reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
+    float4 zeros_float = float4(zeros);
+
+    float4 a_val = float4(A_ptr[k / 4]);
+    // We are going to skip right-shifts of the weights and hence divide by the corresponding factor.
+    float4 a_vec = a_val * act_div_scales;
+    float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];
+
+    float4x4 b_mat;
+    ushort b_val0 = (B_ptr + (k + 0 * K) / k_pack_factor)[0];
+    ushort b_val1 = (B_ptr + (k + 1 * K) / k_pack_factor)[0];
+    ushort b_val2 = (B_ptr + (k + 2 * K) / k_pack_factor)[0];
+    ushort b_val3 = (B_ptr + (k + 3 * K) / k_pack_factor)[0];
+    b_mat[0] = scales[0] * float4(float(b_val0 & 0x03), float(b_val0 & 0x0c),
+                                  float(b_val0 & 0x30), float(b_val0 & 0xc0));
+    b_mat[1] = scales[1] * float4(float(b_val1 & 0x03), float(b_val1 & 0x0c),
+                                  float(b_val1 & 0x30), float(b_val1 & 0xc0));
+    b_mat[2] = scales[2] * float4(float(b_val2 & 0x03), float(b_val2 & 0x0c),
+                                  float(b_val2 & 0x30), float(b_val2 & 0xc0));
+    b_mat[3] = scales[3] * float4(float(b_val3 & 0x03), float(b_val3 & 0x0c),
+                                  float(b_val3 & 0x30), float(b_val3 & 0xc0));
+
+    result += a_vec * b_mat;
+    result += a_val_sum * zeros_float;
+  }
+  result += simd_shuffle_down(result, 1);
+  result += simd_shuffle_down(result, 2);
+  result += simd_shuffle_down(result, 4);
+  result += simd_shuffle_down(result, 8);
+  result += simd_shuffle_down(result, 16);
+  if (tid_in_simdgroup % threads_per_channel == 0) {
+    reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
+  }
+}
+
+#define INSTANTIATE_INT2MM(DTYPE, GSIZE)                                   \
+  template [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] kernel void    \
+  int2pack_mm<DTYPE, GSIZE>(                                               \
+      constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]],  \
+      constant DTYPE * scales_ptr [[buffer(2)]],                           \
+      constant DTYPE * zeros_ptr [[buffer(3)]],                            \
+      device DTYPE * output_data [[buffer(4)]],                            \
+      constant uint3 & sizes [[buffer(5)]],                                \
+      uint3 thread_index [[thread_position_in_grid]],                      \
+      uint tid_in_simdgroup [[thread_index_in_simdgroup]])
+
+INSTANTIATE_INT2MM(float, 32);
+INSTANTIATE_INT2MM(half, 32);
+INSTANTIATE_INT2MM(float, 64);
+INSTANTIATE_INT2MM(half, 64);
+INSTANTIATE_INT2MM(float, 128);
+INSTANTIATE_INT2MM(half, 128);
+INSTANTIATE_INT2MM(float, 256);
+INSTANTIATE_INT2MM(half, 256);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_INT2MM(bfloat, 32);
+INSTANTIATE_INT2MM(bfloat, 64);
+INSTANTIATE_INT2MM(bfloat, 128);
+INSTANTIATE_INT2MM(bfloat, 256);
+#endif
diff --git a/torchao/experimental/kernels/mps/metal/int3mm_opt.metal b/torchao/experimental/kernels/mps/metal/int3mm_opt.metal
new file mode 100644
index 0000000000..713c05dfe8
--- /dev/null
+++ b/torchao/experimental/kernels/mps/metal/int3mm_opt.metal
@@ -0,0 +1,147 @@
+#include <metal_simdgroup>
+#include <metal_stdlib>
+using namespace metal;
+
+inline void unpack_3bit(const uchar3 b, thread float* w) {
+  w[0] = float(((b[0] & 1) << 2) | (b[1] & 3));
+  w[1] = float(((b[0] & 2) << 1) | ((b[1] & 12) >> 2));
+  w[2] = float((b[0] & 4) | ((b[1] & 48) >> 4));
+  w[3] = float(((b[0] & 8) >> 1) | ((b[1] & 192) >> 6));
+
w[4] = float(((b[0] & 16) >> 2) | (b[2] & 3)); + w[5] = float(((b[0] & 32) >> 3) | ((b[2] & 12) >> 2)); + w[6] = float(((b[0] & 64) >> 4) | ((b[2] & 48) >> 4)); + w[7] = float(((b[0] & 128) >> 5) | ((b[2] & 192) >> 6)); +} + +/** + * 3-Bit Quantized Linear. + * + * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) + * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8) + * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N + * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N + * @param[outputData] M x N output tensor of floating point dtype (same as input) + * @param[sizes] The sizes involved in the order: M, K, N + * + */ +template +kernel void int3pack_mm(constant T *A [[buffer(0)]], + constant uchar *B [[buffer(1)]], + constant T *scales_ptr [[buffer(2)]], + constant T *zeros_ptr [[buffer(3)]], + device T *output_data [[buffer(4)]], + constant uint3 &sizes [[buffer(5)]], // M, K, N + uint3 thread_index [[thread_position_in_grid]], + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) { + constexpr uint threads_per_channel = 32; + constexpr uint ks_per_thread = 8; + constexpr uint bytes_per_pack = 3; + constexpr uint k_pack_factor = 8; + const uint K = sizes.y; + const uint N = sizes.z; + uint n = thread_index.x; // 0..N/4-1 + uint m = thread_index.z; // 0..M + n = n / threads_per_channel; + n = n * 4; + + // This is starting k for each thread. + uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread; + constexpr int k_jump = threads_per_channel * ks_per_thread; + + using vecT = typename Vec4Type::type; + constant vecT *A_ptr = reinterpret_cast(A + m * K); + constant uchar *B_ptr = B + n * bytes_per_pack * K / k_pack_factor; + + thread float4 result = float4(0.0); + + for (; k < K; k += k_jump) { + // Find specific group to which channels handled by this thread + // belong. 
+ uint k_block_index = k / group_size; + uint scales_group_offset = (k_block_index * N + n); + + vecT scales = + (reinterpret_cast(scales_ptr + scales_group_offset))[0]; + vecT zeros = + (reinterpret_cast(zeros_ptr + scales_group_offset))[0]; + float4 zeros_float = float4(zeros); + + float4 a_val[2]; + a_val[0] = float4(A_ptr[k / 4]); + a_val[1] = float4(A_ptr[k / 4 + 1]); + + float a_val_sum = a_val[0][0] + a_val[0][1] + a_val[0][2] + a_val[0][3]; + a_val_sum += a_val[1][0] + a_val[1][1] + a_val[1][2] + a_val[1][3]; + + uchar3 b_val0 = (reinterpret_cast( + B_ptr + bytes_per_pack * (k + 0 * K) / k_pack_factor))[0]; + uchar3 b_val1 = (reinterpret_cast( + B_ptr + bytes_per_pack * (k + 1 * K) / k_pack_factor))[0]; + uchar3 b_val2 = (reinterpret_cast( + B_ptr + bytes_per_pack * (k + 2 * K) / k_pack_factor))[0]; + uchar3 b_val3 = (reinterpret_cast( + B_ptr + bytes_per_pack * (k + 3 * K) / k_pack_factor))[0]; + + float4x4 b_mat[2]; + + thread float w0[8]; + unpack_3bit(b_val0, w0); + + thread float w1[8]; + unpack_3bit(b_val1, w1); + + thread float w2[8]; + unpack_3bit(b_val2, w2); + + thread float w3[8]; + unpack_3bit(b_val3, w3); + + b_mat[0][0] = scales[0] * float4(w0[0], w0[1], w0[2], w0[3]), + b_mat[1][0] = scales[0] * float4(w0[4], w0[5], w0[6], w0[7]), + b_mat[0][1] = scales[1] * float4(w1[0], w1[1], w1[2], w1[3]), + b_mat[1][1] = scales[1] * float4(w1[4], w1[5], w1[6], w1[7]), + b_mat[0][2] = scales[2] * float4(w2[0], w2[1], w2[2], w2[3]), + b_mat[1][2] = scales[2] * float4(w2[4], w2[5], w2[6], w2[7]), + b_mat[0][3] = scales[3] * float4(w3[0], w3[1], w3[2], w3[3]), + b_mat[1][3] = scales[3] * float4(w3[4], w3[5], w3[6], w3[7]), + + result += a_val[0] * b_mat[0]; + result += a_val[1] * b_mat[1]; + result += a_val_sum * zeros_float; + } + result += simd_shuffle_down(result, 1); + result += simd_shuffle_down(result, 2); + result += simd_shuffle_down(result, 4); + result += simd_shuffle_down(result, 8); + result += simd_shuffle_down(result, 16); + if (tid_in_simdgroup % threads_per_channel == 0) { + reinterpret_cast(output_data + m * N)[n / 4] = vecT(result); + } +} + +#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \ + template [[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \ + int3pack_mm( \ + constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_ptr [[buffer(2)]], \ + constant DTYPE * zeros_ptr [[buffer(3)]], \ + device DTYPE * output_data [[buffer(4)]], \ + constant uint3 & sizes [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + +INSTANTIATE_INT3MM(float, 32); +INSTANTIATE_INT3MM(half, 32); +INSTANTIATE_INT3MM(float, 64); +INSTANTIATE_INT3MM(half, 64); +INSTANTIATE_INT3MM(float, 128); +INSTANTIATE_INT3MM(half, 128); +INSTANTIATE_INT3MM(float, 256); +INSTANTIATE_INT3MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT3MM(bfloat, 32); +INSTANTIATE_INT3MM(bfloat, 64); +INSTANTIATE_INT3MM(bfloat, 128); +INSTANTIATE_INT3MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int4mm_opt.metal b/torchao/experimental/kernels/mps/metal/int4mm_opt.metal new file mode 100644 index 0000000000..4d9f18f984 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/int4mm_opt.metal @@ -0,0 +1,180 @@ +#include +#include +using namespace metal; + +/* + This code takes heavy inspiration from MLX: + https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h + Specifically: + - Multiplying activation by inverse scaling factor to 
reduce compute
+  boundedness
+   - Handling zero point by accumulating act in separate sum term. Needed with
+  optimization done above. MLX MIT License:
+  https://github.com/ml-explore/mlx/blob/main/LICENSE
+*/
+
+/*
+  A matrix is [M x K] (right now this kernel does not support M > 1 but this is
+  a very easy fix that will follow right after) B matrix is [N x K]. For 4 bit
+  2 of the k values are packed in one byte so you can think of B as [N x K/2]
+  matrix from layout perspective.
+
+  Since this kernel is optimizing for gemv case, we split work, along reduction
+  dim k, among the threads of same simdgroup. Ex: if K = 4096 and simdgroup
+  size is 32 (current algorithm should work as long as simdgroup size is > 32).
+  Then each thread will accumulate 4096/32 = 128 k values. However these 128
+  values, handled by each thread are not laid out contiguously. Each thread
+  handles 4 contiguous k values and then jumps 128 elements, k_jump =
+  threads_per_channel (32) * ks_per_thread (4). Take a simpler example where
+  simdgroup is of size 4. In this case threads_per_channel = 4. Assume K = 32
+        k              thread
+  [ 0,  1,  2,  3,       0
+    4,  5,  6,  7,       1
+    8,  9, 10, 11,       2
+   12, 13, 14, 15,       3
+   16, 17, 18, 19,       0
+   20, 21, 22, 23,       1
+   24, 25, 26, 27,       2
+   28, 29, 30, 31]       3
+                  thread id in simd group that handle corresponding
+                  ks
+  Thread 0 here is handling (0, 1, 2, 3) and then (16, 17, 18, 19). They are
+  apart by k_jump = 4 * 4 = 16. This is done to improve memory access locality
+  among threads that are working co-operatively. Once each thread has their
+  partial sums accumulated, we use tree reduction (Metal offers simd_sum but
+  not used so that we support simdgroup size = 64). In the
+  example above we will have 4 partial sums.
+
+  Each thread also handles 4 different output rows. Thus each simdgroup will be
+  responsible for (1x4) tile of the output. We haven't evaluated whether a
+  different tile size is better or not. We probably will do some auto-tuning
+  once initial work is done.
+*/
+
+/*
+  @brief This shader implements 4-bit matrix-vector multiplication where A
+  matrix is fp16, bfloat or float and B matrix is a 4-bit groupwise-quantized weight
+  matrix.
+  @param [in] A is activation matrix of size M x K.
+  @param [in] B is weight matrix of size N x K. Each byte contains 2 4-bit
+  values, along K dim, packed together.
+  @param [in] scales_ptr is scales ptr corresponding each
+  output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+  channels.
+  @param [in] zeros_ptr is zero points corresponding each
+  output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+  channels.
+  @param [out] output_data is output matrix of size M x N.
+  @param [in] sizes array contains values of M, K and N.
+  @param [in] thread_index is global thread id.
+  @param [in] tid_in_simdgroup is thread id in simdgroup. e.g. in simdgroup of size 32 it can be in [0-31].
+*/
+template <typename T, unsigned group_size>
+kernel void int4pack_mm(constant T *A [[buffer(0)]],
+                        constant uchar *B [[buffer(1)]],
+                        constant T *scales_ptr [[buffer(2)]],
+                        constant T *zeros_ptr [[buffer(3)]],
+                        device T *output_data [[buffer(4)]],
+                        constant uint3 &sizes [[buffer(5)]], // M, K, N
+                        uint3 thread_index [[thread_position_in_grid]],
+                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+  constexpr uint threads_per_channel = 32;
+  constexpr uint ks_per_thread = 4;
+  constexpr uint k_pack_factor = 2;
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  uint n = thread_index.x; // 0..N/4-1
+  uint m = thread_index.z; // 0..M
+  n = n / threads_per_channel;
+  n = n * 4;
+  // This is starting k for each thread. In the example above, for thread 1 this
+  // value will be 4.
+  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+  constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+  using vecT = typename Vec4Type<T>::type;
+  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+  thread float4 result = float4(0.0);
+  // We multiply group of 4 channels with these scales.
+  // Because corresponding values from weight matrix are effectively left
+  // shifted. This is to avoid doing right shift on those values which ends up
+  // affecting performance. This is the trick applied in MLX kernels.
+  float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f};
+
+  for (; k < K; k += k_jump) {
+    // Find specific group to which channels handled by this thread
+    // belong.
+    uint k_block_index = k / group_size;
+    uint scales_group_offset = (k_block_index * N + n);
+
+    vecT scales =
+        (reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
+    // Adding zero point results in 10% perf penalty.
+    vecT zeros =
+        (reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
+    float4 zeros_float = float4(zeros);
+
+    float4 a_val = float4(A_ptr[k / 4]);
+    // We are going to skip right-shifts of the weights and hence divide by the corresponding factor.
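    // [Editor's note: illustration only, not part of the original patch.]
    // Why the inverse-scale trick works for the 4-bit packing used here:
    // the second nibble of b_val0 is (b_val0 & 0x00f0), i.e. w1 << 4, so
    // pre-multiplying the activation by 1/16 gives
    //   (a1 / 16) * float(b_val0 & 0x00f0) == a1 * float(w1)
    // with no per-weight right shift. The same reasoning yields the 1/256 and
    // 1/4096 factors for the upper nibbles, and the {1, 1/4, 1/16, 1/64}
    // factors in the 2-bit kernel, whose fields sit at bit offsets 0, 2, 4, 6.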
+ float4 a_vec = a_val * act_div_scales; + float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3]; + + float4x4 b_mat; + ushort b_val0 = (reinterpret_cast( + B_ptr + (k + 0 * K) / k_pack_factor))[0]; + ushort b_val1 = (reinterpret_cast( + B_ptr + (k + 1 * K) / k_pack_factor))[0]; + ushort b_val2 = (reinterpret_cast( + B_ptr + (k + 2 * K) / k_pack_factor))[0]; + ushort b_val3 = (reinterpret_cast( + B_ptr + (k + 3 * K) / k_pack_factor))[0]; + b_mat[0] = scales[0] * float4(float(b_val0 & 0x000f), float(b_val0 & 0x00f0), + float(b_val0 & 0x0f00), float(b_val0 & 0xf000)); + b_mat[1] = scales[1] * float4(float(b_val1 & 0x000f), float(b_val1 & 0x00f0), + float(b_val1 & 0x0f00), float(b_val1 & 0xf000)); + b_mat[2] = scales[2] * float4(float(b_val2 & 0x000f), float(b_val2 & 0x00f0), + float(b_val2 & 0x0f00), float(b_val2 & 0xf000)); + b_mat[3] = scales[3] * float4(float(b_val3 & 0x000f), float(b_val3 & 0x00f0), + float(b_val3 & 0x0f00), float(b_val3 & 0xf000)); + + result += a_vec * b_mat; + result += a_val_sum * zeros_float; + } + result += simd_shuffle_down(result, 1); + result += simd_shuffle_down(result, 2); + result += simd_shuffle_down(result, 4); + result += simd_shuffle_down(result, 8); + result += simd_shuffle_down(result, 16); + if (tid_in_simdgroup % threads_per_channel == 0) { + reinterpret_cast(output_data + m * N)[n / 4] = vecT(result); + } +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ + template [[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \ + int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_ptr [[buffer(2)]], \ + constant DTYPE * zeros_ptr [[buffer(3)]], \ + device DTYPE * output_data [[buffer(4)]], \ + constant uint3 & sizes [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/torchao/experimental/kernels/mps/metal/int5mm.metal b/torchao/experimental/kernels/mps/metal/int5mm.metal index d854b7f90e..0298ba5482 100644 --- a/torchao/experimental/kernels/mps/metal/int5mm.metal +++ b/torchao/experimental/kernels/mps/metal/int5mm.metal @@ -4,7 +4,7 @@ using namespace metal; /** * 5-Bit Quantized Linear. * - * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8) * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N diff --git a/torchao/experimental/kernels/mps/metal/int6mm.metal b/torchao/experimental/kernels/mps/metal/int6mm.metal index a43f5c0e0a..d8d874c46e 100644 --- a/torchao/experimental/kernels/mps/metal/int6mm.metal +++ b/torchao/experimental/kernels/mps/metal/int6mm.metal @@ -4,7 +4,7 @@ using namespace metal; /** * 6-Bit Quantized Linear. 
* - * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8) * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N @@ -35,25 +35,44 @@ kernel void int6pack_mm( for (uint32_t kb = 0; kb < k_block ; kb ++) { const float scale = float(scales[kb * N + n]); const float zero = float(zeros[kb * N + n]); - for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) { + for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { const auto a_val0 = float(A_ptr[k + 0]); const auto a_val1 = float(A_ptr[k + 1]); const auto a_val2 = float(A_ptr[k + 2]); const auto a_val3 = float(A_ptr[k + 3]); + const auto a_val4 = float(A_ptr[k + 4]); + const auto a_val5 = float(A_ptr[k + 5]); + const auto a_val6 = float(A_ptr[k + 6]); + const auto a_val7 = float(A_ptr[k + 7]); + uchar b0 = B_ptr[3 * (k / 4) + 0]; uchar b1 = B_ptr[3 * (k / 4) + 1]; uchar b2 = B_ptr[3 * (k / 4) + 2]; + uchar b3 = B_ptr[3 * (k / 4) + 3]; + uchar b4 = B_ptr[3 * (k / 4) + 4]; + uchar b5 = B_ptr[3 * (k / 4) + 5]; + uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15); uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4); uchar w_val2 = ((b0 & 48)) | (b2 & 15); uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4); + uchar w_val4 = ((b3 & 3) << 4) | (b4 & 15); + uchar w_val5 = ((b3 & 12) << 2) | ((b4 & 240) >> 4); + uchar w_val6 = ((b3 & 48)) | (b5 & 15); + uchar w_val7 = ((b3 & 192) >> 2) | ((b5 & 240) >> 4); + rc += a_val0 * (scale * float(w_val0) + zero); rc += a_val1 * (scale * float(w_val1) + zero); rc += a_val2 * (scale * float(w_val2) + zero); rc += a_val3 * (scale * float(w_val3) + zero); + + rc += a_val4 * (scale * float(w_val4) + zero); + rc += a_val5 * (scale * float(w_val5) + zero); + rc += a_val6 * (scale * float(w_val6) + zero); + rc += a_val7 * (scale * float(w_val7) + zero); } } outputData[m * N + n] = T(rc); diff --git a/torchao/experimental/kernels/mps/metal/int7mm.metal b/torchao/experimental/kernels/mps/metal/int7mm.metal index 57c74402d9..dbde6e4e95 100644 --- a/torchao/experimental/kernels/mps/metal/int7mm.metal +++ b/torchao/experimental/kernels/mps/metal/int7mm.metal @@ -4,7 +4,7 @@ using namespace metal; /** * 7-Bit Quantized Linear. * - * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16) + * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8) * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N * @param[zeros] 2D tensor containg the zero points for each group. 
Expected shape is #groups x N diff --git a/torchao/experimental/kernels/mps/src/dispatch.h b/torchao/experimental/kernels/mps/src/dispatch.h index 5a48506a82..39acd8d1f0 100644 --- a/torchao/experimental/kernels/mps/src/dispatch.h +++ b/torchao/experimental/kernels/mps/src/dispatch.h @@ -20,4 +20,18 @@ inline void dispatch_mm( threadsPerThreadgroup:MTLSizeMake(std::min(maxThreadsPerGroup, M), 1, 1)]; } +inline void dispatch_mm_Mr1xNr4_per_TG( + id encoder, + int32_t maxThreadsPerGroup, + int32_t M, + int32_t N, + int32_t K) { + (void)K; + if (maxThreadsPerGroup < 32) { + throw std::runtime_error("Can't dispatch!"); + } + [encoder dispatchThreads:MTLSizeMake(N / 4 * 32, 1, M) + threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; +} + } // namespace torchao::kernels::mps::lowbit::dispatch diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index ae3951e217..9b2d539761 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -31,21 +31,21 @@ template <> struct LowBitConfig<2> { static constexpr std::string_view func_prefix = "int2pack_mm_"; static constexpr auto packing_fn = packing::pack<2>; - static constexpr auto dispatch_fn = dispatch::dispatch_mm; + static constexpr auto dispatch_fn = dispatch::dispatch_mm_Mr1xNr4_per_TG; }; template <> struct LowBitConfig<3> { static constexpr std::string_view func_prefix = "int3pack_mm_"; static constexpr auto packing_fn = packing::pack<3>; - static constexpr auto dispatch_fn = dispatch::dispatch_mm; + static constexpr auto dispatch_fn = dispatch::dispatch_mm_Mr1xNr4_per_TG; }; template <> struct LowBitConfig<4> { static constexpr std::string_view func_prefix = "int4pack_mm_"; static constexpr auto packing_fn = packing::pack<4>; - static constexpr auto dispatch_fn = dispatch::dispatch_mm; + static constexpr auto dispatch_fn = dispatch::dispatch_mm_Mr1xNr4_per_TG; }; template <> @@ -125,6 +125,7 @@ void linear_lowbit_quant_weights_mps( int32_t N, const std::string_view type_str) { assert(K % 8 == 0); + assert(N % 4 == 0); assert( qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || qGroupSize == 256); diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 7fb20d254a..8a1e0fdb9e 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -200,22 +200,23 @@ void run_test(int32_t m, int32_t k, int32_t n, int32_t group_size) { template void run_test_battery() { - run_test(1, 8, 1, 32); - run_test(1, 32, 1, 32); - run_test(1, 32, 1, 64); - run_test(1, 56, 1, 64); - run_test(1, 64, 1, 64); - run_test(1, 72, 1, 64); - run_test(1, 1000, 1, 64); - run_test(3, 64, 5, 64); - run_test(7, 64, 23, 64); - run_test(17, 120, 23, 128); - run_test(17, 128, 23, 128); - run_test(41, 144, 23, 128); - run_test(41, 128, 23, 128); - run_test(81, 8, 1, 256); - run_test(19, 256, 17, 256); - run_test(1, 1000, 81, 256); + run_test(1, 8, 4, 32); + run_test(1, 32, 4, 32); + run_test(1, 32, 4, 64); + run_test(1, 56, 4, 64); + run_test(1, 64, 4, 64); + run_test(1, 72, 4, 64); + run_test(1, 1000, 4, 64); + run_test(3, 64, 8, 64); + run_test(7, 64, 20, 64); + run_test(17, 120, 20, 128); + run_test(17, 128, 20, 128); + run_test(41, 144, 20, 128); + run_test(41, 128, 20, 128); + run_test(81, 8, 4, 256); + run_test(19, 256, 28, 256); + run_test(1, 1000, 28, 256); + run_test(19, 8, 36, 256); } int main() { diff --git 
a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm index e11e55c5a0..162b5ab83c 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm @@ -45,6 +45,8 @@ void check_linear_mps_args( TORCH_CHECK(K % 8 == 0, __func__, ": expect K to be multiple of 8, got ", K); + TORCH_CHECK(N % 4 == 0, __func__, ": expect N to be multiple of 4, got ", N); + TORCH_CHECK( group_size == 32 || group_size == 64 || group_size == 128 || group_size == 256, diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm index 2892a67245..a6f417b17d 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm @@ -57,6 +57,8 @@ bool check_linear_mps_args( ET_LOG_MSG_AND_RETURN_IF_FALSE(K % 8 == 0, "Expect K to be multiple of 8"); + ET_LOG_MSG_AND_RETURN_IF_FALSE(N % 4 == 0, "Expect N to be multiple of 4"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( group_size == 32 || group_size == 64 || group_size == 128 || group_size == 256, diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index acff5624c8..27158fd7e3 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -39,22 +39,23 @@ class TestLowBitQuantWeightsLinear(unittest.TestCase): (nbit, *param) for nbit in range(1, 8) for param in [ - (1, 8, 1, 32), - (1, 32, 1, 32), - (1, 32, 1, 64), - (1, 56, 1, 64), - (1, 64, 1, 64), - (1, 72, 1, 64), - (1, 1000, 1, 64), - (3, 64, 5, 64), - (7, 64, 23, 64), - (17, 120, 23, 128), - (17, 128, 23, 128), - (41, 144, 23, 128), - (41, 128, 23, 128), - (81, 8, 1, 256), - (19, 256, 17, 256), - (1, 1000, 81, 256), + (1, 8, 4, 32), + (1, 32, 4, 32), + (1, 32, 4, 64), + (1, 56, 4, 64), + (1, 64, 4, 64), + (1, 72, 4, 64), + (1, 1000, 4, 64), + (3, 64, 8, 64), + (7, 64, 20, 64), + (17, 120, 20, 128), + (17, 128, 20, 128), + (41, 144, 20, 128), + (41, 128, 20, 128), + (81, 8, 4, 256), + (19, 256, 28, 256), + (1, 1000, 28, 256), + (19, 8, 36, 256), ] ] diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 5b3331c6a8..72a6b76fad 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -53,7 +53,7 @@ def _model_setup(self): k0 = 96 k1 = 224 k2 = 160 - n = 47 + n = 44 layers = [ torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), @@ -115,7 +115,7 @@ def test_3d_output_device_and_shape(self, nbit): def test_valid_groupsizes(self, nbit, group_size): k0 = 3 * group_size k1 = 7 * group_size - n = 47 + n = 44 layers = [ torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, n, bias=False), @@ -134,7 +134,7 @@ def test_invalid_groupsizes(self, nbit): group_size = 16 k0 = 3 * group_size k1 = 7 * group_size - n = 47 + n = 44 layers = [ torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, n, bias=False), @@ -158,7 +158,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z): def test_accuracy(self, nbit): group_size = 32 m = 3 - n = 7 + n = 12 k = 64 with torch.no_grad(): activations = torch.rand(m, k, dtype=torch.float32, device="mps")
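Editor's note (illustrative commentary, not part of the patch): the optimized kernels above fold the zero point out of the inner product, accumulating scale * sum_k(a_k * w_k) + zero * sum_k(a_k) per quantization group (the `result += a_vec * b_mat; result += a_val_sum * zeros_float;` lines) instead of sum_k a_k * (scale * w_k + zero) as in the simpler shaders. The CPU-side C++ sketch below uses a hypothetical helper name and assumes the LSB-first 2-bit packing implied by the 0x03/0x0c/0x30/0xc0 masks in int2mm_opt.metal; it is meant only as a mental reference for one output element, not as code from this PR.

    #include <cstdint>
    #include <vector>

    // Reference for one output element of the 2-bit kernel (hypothetical
    // helper, not part of the patch). A: M x K activations, B: N x (K/4)
    // packed 2-bit weights, scales/zeros: ceil(K/group_size) x N.
    float ref_int2mm_element(const std::vector<float>& A,
                             const std::vector<uint8_t>& B,
                             const std::vector<float>& scales,
                             const std::vector<float>& zeros,
                             uint32_t m, uint32_t n,
                             uint32_t K, uint32_t N, uint32_t group_size) {
      float acc = 0.0f;
      for (uint32_t g = 0; g * group_size < K; ++g) {
        const float scale = scales[g * N + n];
        const float zero = zeros[g * N + n];
        float dot = 0.0f;    // sum_k a_k * w_k within the group
        float a_sum = 0.0f;  // sum_k a_k within the group, scaled by zero below
        for (uint32_t k = g * group_size; k < K && k < (g + 1) * group_size; ++k) {
          const float a = A[m * K + k];
          const uint8_t byte = B[n * (K / 4) + k / 4];
          const uint8_t w = (byte >> (2 * (k % 4))) & 0x3; // 4 2-bit values per byte
          dot += a * float(w);
          a_sum += a;
        }
        acc += scale * dot + zero * a_sum; // same algebra as the Metal kernels
      }
      return acc;
    }

Grouped this way, the per-group arithmetic matches what each simdgroup accumulates before the simd_shuffle_down tree reduction, and the `N % 4 == 0` checks added to the ATen and ExecuTorch bindings correspond to the four output channels each simdgroup writes per (1x4) output tile.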