diff --git a/test/test_ops.py b/test/test_ops.py index 28e7437b66..45a10abe3a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,3 +1,7 @@ +import itertools + +import torchao + import torch from torch.testing._internal.common_utils import ( TestCase, @@ -6,7 +10,7 @@ run_tests, ) from torch.testing._internal.optests import opcheck -from torchao.utils import is_fbcode +from torchao.utils import is_fbcode, TORCH_VERSION_AFTER_2_5 from torchao.prototype.quant_llm import from_scaled_tc_fpx import pytest @@ -18,6 +22,14 @@ except RuntimeError: pytest.skip("torchao.ops not available") +from torchao.quantization.utils import ( + get_groupwise_affine_qparams, + groupwise_affine_dequantize_tensor_from_qparams, + groupwise_affine_quantize_tensor_from_qparams, + pack_tinygemm_scales_and_zeros, + unpack_tinygemm_scales_and_zeros, +) + class TestOps(TestCase): def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): @@ -61,9 +73,218 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): relative_error = error / gt assert relative_error < 1e-3 - instantiate_parametrized_tests(TestOps) +## Tests for `tensor_core_layout` +kTileSizeN = 8 +kTileSizeK = 16 + +SHAPES = [ + (4096, 4096), + # Llama 2 GEMM shapes + (4096, 11008), + (11008, 4096), + # Llama 3 GEMM shapes + (4096, 14336), + (14336, 4096), +] +INNERKTILES = [2, 4, 8] +QGROUP_SIZES = [32, 64, 128, 256] +TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) +TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) +def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): + N, K = shape + assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 + + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) + assert torch.equal(t, unpacked) + +# TODO: Fix "test_aot_dispatch_dynamic" test failure +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) +def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AFTER_2_5: + test_utils.append("test_aot_dispatch_dynamic") + + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) + + opcheck( + torch.ops.torchao.unpack_tensor_core_tiled_layout, + (packed_w, inner_k_tiles), + test_utils=test_utils, + ) + +def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): + n, k = q.shape + assert q.dtype == torch.int + + n_groups = k // group_size + assert scales.shape[0] == n and scales.shape[1] == n_groups + assert scales.shape == zeros.shape + + midpoint = 2 ** (nbits - 1) + + #Convert fron u4 -> s4 and upcast to bfloat16 + q = q.sub(midpoint).to(dtype) + + # Dequantize + q = q.reshape(-1, group_size) + dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1) + + return dq.reshape(n, k) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") 
+@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): + n, k = shape + dtype = torch.bfloat16 + + device = "cuda" + + t = torch.randn(n, k, dtype=dtype, device=device) + scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) + + # Quantize + q = groupwise_affine_quantize_tensor_from_qparams( + t, scales, zeros, n_bit=4, groupsize=group_size + ) + + # Pack to tensor core layout + packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + q_groups = k // group_size + assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) + + # Dequantize 'ao' ref + dq_ao = groupwise_affine_dequantize_tensor_from_qparams( + q, scales, zeros, n_bit=4, groupsize=group_size + ) + + # Dequantize by passing in an identity matrix as the activation + a_eye = torch.eye(k, device=device, dtype=dtype) + dq_id = torch.ops.aten._weight_int4pack_mm( + a_eye, + packed, + group_size, + scales_and_zeros, + ).t() + + # Actual operation to test + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + + # Compare results + diff_ao_id = (dq_id - dq_ao).abs().max() + diff_op_id = (dq_op - dq_id).abs().max() + diff_op_ao = (dq_op - dq_ao).abs().max() + + # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` + # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast + # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are + # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. 
+ + # Test that the `dequant` kernel gives same results as identity matrix-based dequant + assert diff_op_id == 0 + + # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix + assert diff_op_ao == diff_ao_id + + assert diff_op_ao < 1e-1 + +# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): + n, k = shape + dtype = torch.bfloat16 + device = "cuda" + + # Quantize and pack + t = torch.randn(n, k, dtype=dtype, device=device) + scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) + q = groupwise_affine_quantize_tensor_from_qparams( + t, scales, zeros, n_bit=4, groupsize=group_size + ) + + packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + + # Unpack and dequantize + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) + dq_ao = groupwise_affine_dequantize_tensor_from_qparams( + unpacked, scales, zeros, n_bit=4, groupsize=group_size + ) + + # Dequantize by passing in an identity matrix as the activation + a_eye = torch.eye(k, device=device, dtype=dtype) + dq_id = torch.ops.aten._weight_int4pack_mm( + a_eye, + packed, + group_size, + scales_and_zeros, + ).t() + + # Actual operation to test + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + + # Compare results + diff_ao_id = (dq_id - dq_ao).abs().max() + diff_op_id = (dq_op - dq_id).abs().max() + diff_op_ao = (dq_op - dq_ao).abs().max() + + # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` + # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast + # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are + # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. 
+ + # Test that the `dequant` kernel gives same results as identity matrix-based dequant + assert diff_op_id == 0 + + # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix + assert diff_op_ao == diff_ao_id + + assert diff_op_ao < 1e-1 + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): + n, k = shape + device = "cuda" + + q = torch.randint(0, 16, shape, dtype=torch.int, device=device) + packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) + q_groups = k // group_size + scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) + zeros = torch.randn_like(scales) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AFTER_2_5: + test_utils.append("test_aot_dispatch_dynamic") + opcheck( + torch.ops.torchao.dequantize_tensor_core_tiled_layout, + (packed_w, scales_and_zeros, group_size, inner_k_tiles), + test_utils=test_utils, + ) + if __name__ == "__main__": - run_tests() + run_tests() \ No newline at end of file diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu new file mode 100644 index 0000000000..652bba5ca6 --- /dev/null +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -0,0 +1,312 @@ +#include +#include +#include +#include +#include +#include + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + const uint64_t blocks = a / b + (a % b != 0); + return blocks; +} +constexpr int32_t kWarpSize = 32; + +//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization +//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 +struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; +}; + +//Copied from https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L195C1-L241C1 +inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? 
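+    // Note: the 4-bit shift is correct. MASK (0x000f000f) keeps the low nibble
+    // of each 16-bit half, so iteration ii pulls the nibbles at bit offsets
+    // 4*ii and 16 + 4*ii -- the interleaved order _convert_weight_to_int4pack
+    // writes (see the int32 unpack branch below: 0, 16, 4, 20, 8, 24, 12, 28).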
+ // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + +// Finally, we construct the output numbers. +#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + + return result; +} +// in size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] +// scales_and_zeros size [numQGroups][n][2] +// out size [n][k] +template +__global__ void _dequantize_int4_kernel( + const at::PackedTensorAccessor32 in, + at::PackedTensorAccessor32 out, + at::optional> scales_and_zeros = c10::nullopt) +{ + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // n dimension that this lane loads from + auto n0 = nTile * kNTileSize + (t / 4); + + // 8 k-tile values, 4 per m16n8k16 mma.sync operand B + // int32_t ks[8]; + //Only need 4 offsets since TC layout for single tile is 2x2 (2 pairs of 2 contiguous values) + int32_t ks[4]; + + // Store address base offset + auto pOut = &out[n0][0]; + +// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2 +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + //Tensor-core layout for m16n8k16 is such that each tile has 2 pairs of 2 contiguous values + //Hence, we only need 4 offsets + // Offsets of innerTile0 + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + ks[0] = kBase0 + (t % 4) * 2; + ks[1] = ks[0] + 8; + + // Offsets of innerTile1 + auto kBase1 = kBase0 + kKTileSize; + ks[2] = kBase1 + (t % 4) * 2; + ks[3] = ks[2] + 8; + + // inner k-tiles unpack two at a time + int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2]; + + if constexpr(kDequant) { + // static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing"); + static_assert(std::is_same::value, "Out must be BFloat16 when dequantizing"); + // __nv_bfloat16 v[8]; + + // // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16 + bf16x2x4 v_bf16x2x4 = convert_i4x8_to_bf16x2x4(pack); + + // All b values within a 16x16 tile should fall within the same q group + // Hence we load 1 scale and zero per loop + int qgroup = ks[0] / groupSize; + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + + // Vectorize scales and zeros + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + + #pragma unroll + for (int i = 0; i < 4; i++) { + reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2); + } + } + else { + static_assert(std::is_same::value, "Out must be int32_t when unpacking to int"); + int32_t v[8]; + + v[0] = pack & 0x0000000f; + v[2] = (pack >> 4) & 0x0000000f; + v[4] = (pack >> 8) & 0x0000000f; + v[6] = (pack >> 12) & 0x0000000f; + v[1] = (pack >> 16) & 0x0000000f; + v[3] = (pack >> 20) & 0x0000000f; + v[5] = (pack >> 24) & 0x0000000f; + v[7] = (pack >> 28) & 0x0000000f; + int2* v_i32x2 = reinterpret_cast(v); + + #pragma unroll + for (int i = 0; i < 4; ++i) 
{ + reinterpret_cast(&pOut[ks[i]])[0] = v_i32x2[i]; + } + } + } +} + +// output is [n][k] (int32 dtype) +// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] +// scales_and_zeros is [numQGroups][n][2] +// qGroupSize is 32, 64, 128 or 256 +at::Tensor _dequantize_tensor_core_tiled_layout( + const at::Tensor& packed_w, + const at::Tensor& scales_and_zeros, + int64_t group_size, + int64_t innerKTiles) +{ + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + c10::cuda::CUDAGuard g(packed_w.device()); + + // packed_w preconditions + TORCH_CHECK(packed_w.dim() == 4); + TORCH_CHECK(packed_w.dtype() == at::kInt); + TORCH_CHECK(packed_w.is_contiguous()); + TORCH_CHECK(packed_w.size(2) == 32); + TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); + TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); + + auto numQGroups = scales_and_zeros.size(0); + int N = packed_w.size(0) * kNTileSize; + int K = packed_w.size(1) * innerKTiles * kKTileSize; + + // scales_and_zeros preconditions + TORCH_CHECK( + group_size == 32 || group_size == 64 || group_size == 128 || + group_size == 256); + TORCH_CHECK(numQGroups == K / group_size); + TORCH_CHECK(scales_and_zeros.dim() == 3); + TORCH_CHECK(scales_and_zeros.size(1) == N); + TORCH_CHECK(scales_and_zeros.size(2) == 2); + + auto nTiles = divUp(N, kNTileSize); + auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); + auto out = at::empty( + {N, K}, + at::TensorOptions().dtype(at::kBFloat16).device(packed_w.device())); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 grid(kSuperTiles, nTiles); + +#define RUN_DEQUANT(QGROUPSIZE) \ + do { \ + switch(innerKTiles) { \ + case 2: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + case 4: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + case 8: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + default: \ + break; \ + } \ + } while(false) + +#define DISPATCH_Q_GROUP() \ + do { \ + switch (group_size) { \ + case 32: \ + RUN_DEQUANT(32); \ + break; \ + case 64: \ + RUN_DEQUANT(64); \ + break; \ + case 128: \ + RUN_DEQUANT(128); \ + break; \ + case 256: \ + RUN_DEQUANT(256); \ + break; \ + default: \ + break; \ + } \ + } while(false) + + DISPATCH_Q_GROUP(); + #undef DISPATCH_Q_GROUP + #undef RUN_DEQUANT + + return out; +} + +// output is [n][k] (int32 dtype) +// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] +at::Tensor _unpack_tensor_core_tiled_layout( + const at::Tensor& packed_w, + int64_t innerKTiles) +{ + + c10::cuda::CUDAGuard g(packed_w.device()); + + TORCH_CHECK(packed_w.dim() == 4); + TORCH_CHECK(packed_w.dtype() == at::kInt); + TORCH_CHECK(packed_w.is_contiguous()); + + TORCH_CHECK(packed_w.size(2) == 32); + TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); + TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); + + int N = packed_w.size(0) * 8; + int K = packed_w.size(1) * innerKTiles * 16; + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto nTiles = divUp(N, kNTileSize); + + auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); + + auto out = at::empty( + {N, K}, + at::TensorOptions().dtype(at::kInt).device(packed_w.device())); + + auto stream = at::cuda::getCurrentCUDAStream(); + 
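// Same launch shape as the dequant path: one 32-thread block (a single warp)
+  // per (k super-tile, n-tile); each lane handles one row of the 8-row n-tile
+  // (t / 4) and one pair of contiguous k positions (t % 4) per inner k-tile. +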
dim3 grid(kSuperTiles, nTiles); + + if (innerKTiles == 2) { + _dequantize_int4_kernel<<>>( + packed_w.packed_accessor32(), + out.packed_accessor32()); + } + else if (innerKTiles == 4) { + _dequantize_int4_kernel<<>>( + packed_w.packed_accessor32(), + out.packed_accessor32()); + } else if (innerKTiles == 8) { + _dequantize_int4_kernel<<>>( + packed_w.packed_accessor32(), + out.packed_accessor32()); + } + + return out; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout); + m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout); + +} diff --git a/torchao/csrc/tensor_core_tiled_layout.cpp b/torchao/csrc/tensor_core_tiled_layout.cpp new file mode 100644 index 0000000000..203d5d50c0 --- /dev/null +++ b/torchao/csrc/tensor_core_tiled_layout.cpp @@ -0,0 +1,10 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor"); + m.def("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor"); + +} diff --git a/torchao/ops.py b/torchao/ops.py index 3145812a2f..6c7cf03782 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,5 +1,6 @@ import torch from torch import Tensor + from torchao.utils import TORCH_VERSION_AFTER_2_4 @@ -53,3 +54,108 @@ def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1): torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) + + + +def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: + """ + Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. 
+ + Assumes that the packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `inner_k_tiles = 2 | 4 | 8` + + Args: + packed_w: torch.tensor: 4D tensor with shape (N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2, dtype is torch.int32 + inner_k_tiles: int + + Returns: + torch.tensor of shape N x K, dtype is torch.int32 + + """ + return torch.ops.torchao.unpack_tensor_core_tiled_layout.default( + packed_w=packed_w, inner_k_tiles=inner_k_tiles + ) + + +@register_custom_op(f"torchao::unpack_tensor_core_tiled_layout") +def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor: + torch._check( + packed_w.dim() == 4, + lambda: f"packed weight should be a 4d tensor, got {packed_w.dim()}D", + ) + torch._check( + packed_w.dtype is torch.int32, + lambda: f"weight must be INT32, got {packed_w.dtype}", + ) + torch._check( + inner_k_tiles == 2 or inner_k_tiles == 4 or inner_k_tiles == 8, + lambda: "inner_k_tiles must be 2, 4, or 8", + ) + torch._check(packed_w.size(2) == 32, lambda: "packed weight must have 32 at dim 2") + torch._check( + packed_w.size(3) == inner_k_tiles / 2, + lambda: "packed weight must have inner_k_tiles/2 at dim 3", + ) + N = packed_w.size(0) * 8 + K = packed_w.size(1) * inner_k_tiles * 16 + + return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) + +def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: + """ + Dequantizes by: + - Unpacking weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` + - Upcasting to bfloat16 + - Dequantizing with the scales_and_zeros that were packed with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` + + Assumes: + - packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `inner_k_tiles = 2 | 4 | 8` + - packed scales_and_zeros were generated with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` + - qGroupSize is 32 | 64 | 128 | 256 + + Args: + packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2`, dtype is torch.int32 + scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize + group_size: int + inner_k_tiles: int + + Returns: + torch.tensor of shape N x K, dtype is torch.bfloat16 + + """ + return torch.ops.torchao.dequantize_tensor_core_tiled_layout.default( + packed_w, scales_and_zeros, group_size, inner_k_tiles + ) + + +@register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout") +def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: + # packed_w preconditions + torch._check( + packed_w.dim() == 4, + lambda: f"packed weight should be a 4d tensor, got {packed_w.dim()}D", + ) + torch._check( + packed_w.dtype is torch.int32, + lambda: f"weight must be INT32, got {packed_w.dtype}", + ) + torch._check( + inner_k_tiles == 2 or inner_k_tiles == 4 or inner_k_tiles == 8, + lambda: "inner_k_tiles must be 2, 4, or 8", + ) + torch._check(packed_w.size(2) == 32, lambda: "packed weight must have 32 at dim 2") + torch._check( + packed_w.size(3) == inner_k_tiles / 2, + lambda: "packed weight must have inner_k_tiles/2 at dim 3", + ) + N = packed_w.size(0) * 8 + K = packed_w.size(1) * inner_k_tiles * 16 + + # scales_and_zeros preconditions + torch._check(scales_and_zeros.dtype is torch.bfloat16, lambda: "scales_and_zeros must be bfloat16") +
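# Layout note: pack_tinygemm_scales_and_zeros produces shape
+    # [K // group_size, N, 2] with the scale at [..., 0] and the zero at [..., 1],
+    # which is what the checks below and the CUDA kernel assume.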
+ torch._check(scales_and_zeros.dim() == 3, lambda: f"scales_and_zeros must be 3D, got {scales_and_zeros.dim()}") + torch._check(group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, lambda: "qGroupSize must be 32, 64, 128, or 256") + torch._check(scales_and_zeros.size(0) == K // group_size, lambda: "scales_and_zeros must have K // qGroupSize at dim 0") + torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") + torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") + + return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) \ No newline at end of file
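
For reference, a minimal usage sketch of the two new ops that mirrors the tests above. It assumes a CUDA device and a torch build that exposes torch.ops.aten._convert_weight_to_int4pack; the shapes and group size are illustrative only.

    import torch
    import torchao
    from torchao.quantization.utils import (
        get_groupwise_affine_qparams,
        groupwise_affine_quantize_tensor_from_qparams,
        groupwise_affine_dequantize_tensor_from_qparams,
        pack_tinygemm_scales_and_zeros,
    )

    n, k, group_size, inner_k_tiles = 4096, 4096, 128, 8
    w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

    # Quantize to 4 bits and pack into the tensor-core tiled layout
    scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=group_size, dtype=torch.bfloat16)
    q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=group_size)
    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    # Recover the raw int4 values (as int32, shape (n, k)) from the packed layout ...
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)

    # ... or dequantize directly back to a bf16 weight of shape (n, k)
    w_dq = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

    # Same tolerance the correctness tests use against the 'ao' reference dequant
    w_ref = groupwise_affine_dequantize_tensor_from_qparams(q, scales, zeros, n_bit=4, groupsize=group_size)
    assert (w_dq - w_ref).abs().max() < 1e-1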