def test_multiply(bitnet_tensor):
    """Check that ``addmm`` with two BitnetTensor operands yields the expected shape.

    The weight is built from codes in ``{0, 1, 2}`` so that every value fits
    in the 2-bit fields written by ``BitnetTensor.from_unpacked``.
    """
    # Upper bound of randint is exclusive: codes are 0, 1 or 2.
    w_t = torch.randint(0, 3, (4, 16), dtype=torch.uint8)
    w = BitnetTensor.from_unpacked(w_t)
    y = torch.addmm(torch.Tensor([1]), bitnet_tensor, w)
    # Previously the result was computed and discarded; pin at least its shape.
    assert y.shape == (bitnet_tensor.shape[0], w.shape[1])
@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]])
def test_uint2_quant(input_shape):
    """Run a Linear layer before/after weight-only quantization and under torch.compile.

    Only output shapes are compared: the quantized weights are a lossy
    encoding, so numerical equality with the reference is not expected.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(*input_shape).to(device)
    m = nn.Sequential(nn.Linear(4, 16)).to(device)
    y_ref = m(x)
    _apply_weight_only_uint2_quant(m)
    y_wo = m(x)
    assert y_ref.shape == y_wo.shape
    y_compiled = torch.compile(m, fullgraph=True)(x)
    # The compiled result used to be discarded; make sure it at least
    # agrees in shape with the eager quantized output.
    assert y_compiled.shape == y_wo.shape
@pytest.mark.parametrize("pack_fn, unpack_fn, bit_count", [
    (pack_uint2, unpack_uint2, 2),
    (pack_uint3, unpack_uint3, 3),
    (pack_uint4, unpack_uint4, 4),
    (pack_uint5, unpack_uint5, 5),
    (pack_uint6, unpack_uint6, 6),
    (pack_uint7, unpack_uint7, 7),
])
def test_uint_packing(pack_fn, unpack_fn, bit_count):
    """Round-trip 0..255 through pack/unpack and check only the low bits survive.

    After the round trip every run of ``2 ** bit_count`` consecutive values
    must read ``0, 1, ..., 2 ** bit_count - 1`` because the packers keep only
    the lowest ``bit_count`` bits of each input byte.
    """
    x = torch.arange(0, 256, dtype=torch.uint8)
    y = pack_fn(x)
    z = unpack_fn(y)
    k = z.view(-1, 2 ** bit_count)
    check = torch.arange(0, 2 ** bit_count, dtype=torch.uint8).repeat(k.size(0), 1)
    assert torch.all(k == check), f"Failed for {bit_count}-bit packing"


if __name__ == "__main__":
    # pytest.main expects a list of arguments; a bare str raises TypeError.
    pytest.main([__file__])
class BitnetTensor(UInt2Tensor):
    """UInt2Tensor subclass whose aten ops are routed through ``BITNET_OPS_TABLE``.

    The packed payload lives in ``self.elem`` (four 2-bit codes per byte, as
    produced by ``pack_uint2``). Instances are normally created with
    ``from_unpacked`` (codes already in uint8) or ``from_float`` (which
    quantizes through ``_quantize_int2``).
    """

    def __new__(cls, input_tensor: torch.Tensor, **kwargs):
        return super(BitnetTensor, cls).__new__(cls, input_tensor, **kwargs)

    def __init__(self, input_tensor: torch.Tensor, **kwargs):
        super(BitnetTensor, self).__init__(input_tensor, **kwargs)

    @staticmethod
    def __tensor_unflatten__(flattened, *meta):
        """Rebuild the wrapper from the flattened ``{"elem": ...}`` mapping."""
        # TODO - meta is not None, is it ok?
        elem = flattened["elem"]
        return BitnetTensor(elem)

    @classmethod
    def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor":
        """Pack a tensor of one-code-per-element values and wrap the result."""
        return cls(pack_uint2(unpacked))

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """Dispatch ``func`` to the handler registered in ``BITNET_OPS_TABLE``.

        Returns ``NotImplemented`` when an unrelated tensor subclass takes
        part in the call, and raises ``NotImplementedError`` when the op has
        no registered handler.
        """
        def allowed_subclasses(tensor_type):
            return (
                issubclass(cls, tensor_type) or
                issubclass(torch._subclasses.fake_tensor.FakeTensor, tensor_type) or
                issubclass(torch._subclasses.functional_tensor.FunctionalTensor, tensor_type)
            )

        if not all(allowed_subclasses(t) for t in types):
            # NotImplemented is a singleton, not a callable: calling it (as the
            # previous code did) raises TypeError instead of deferring to the
            # next handler.
            return NotImplemented

        if func in BITNET_OPS_TABLE:
            return BITNET_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Bitnet dispatch: attempting to run {func}, this is not supported")

    @classmethod
    def from_float(cls, w: torch.Tensor):
        """Quantize ``w`` with ``_quantize_int2`` and keep it on ``w``'s device."""
        w_intq = _quantize_int2(w)
        w_int2 = w_intq.to(device=w.device)
        return w_int2

    def clone(self):
        """Return a new BitnetTensor backed by a copy of the packed payload."""
        return BitnetTensor(self.elem.clone())

    def copy_(self, src):
        """Copy ``src``'s packed payload into this tensor in place."""
        self.elem.copy_(src.elem)
        return self

    def tolist(self):
        """Return the unpacked 2-bit codes as nested Python lists."""
        data = unpack_uint2(self.elem).tolist()
        return data

    def __repr__(self):
        try:
            data = unpack_uint2(self.elem).tolist()
        except AssertionError:
            # Fall back to a summary when the payload cannot be unpacked.
            data = f"Tensor of shape {self.shape} and dtype {self.elem.dtype}"
        return f"BitnetTensor({data}, dtype={self.elem.dtype})"

    def to(self, *args, **kwargs):
        """Convert to a plain tensor of another dtype, or move to a device.

        A single positional ``torch.dtype`` unpacks the payload and returns a
        regular tensor; any other single dtype returns ``self`` unchanged.
        A ``device=`` keyword moves the packed payload and re-wraps it.
        Everything else is delegated to the base implementation.
        """
        if len(args) == 1 and isinstance(args[0], torch.dtype):
            dtype = args[0]
            if dtype == torch.int8:
                return unpack_uint2(self.elem).view(self.shape).view(torch.int8)
            elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64):
                return unpack_uint2(self.elem).to(torch.int8).to(dtype)
            elif dtype == torch.uint8:
                return unpack_uint2(self.elem).view(torch.uint8)
            elif isinstance(self, BitnetTensor):
                return self
        if 'device' in kwargs:
            device = kwargs['device']
            return BitnetTensor(self.elem.to(device=device))

        return super().to(*args, **kwargs)
@implements([torch.ops.aten.detach.default])
def detach(func, args, kwargs):
    """Handle ``aten.detach`` by wrapping the detached packed payload.

    A fresh BitnetTensor is returned instead of the input object itself, so
    the result is a distinct tensor, as it is for the UInt2Tensor ``detach``
    handler.
    """
    (tensor,) = args
    return BitnetTensor(tensor.elem.detach())
def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor:
    """Pack the lowest 2 bits of every 4 consecutive values into one byte.

    The first value of each group lands in the two most significant bits.
    The last dimension must be divisible by 4 and shrinks by a factor of 4.

    Raises:
        AssertionError: if the last dimension is not a multiple of 4.
    """
    shape = uint8_data.shape
    assert shape[-1] % 4 == 0, f"{shape}, last dim not divisible by 4"
    flat = uint8_data.contiguous().view(-1)
    # Mask each value to 2 bits first: without the mask any input above 3
    # spills into the neighbouring fields when the shifted values are OR-ed.
    packed_data = (
        (flat[::4] & 0b11) << 6
        | (flat[1::4] & 0b11) << 4
        | (flat[2::4] & 0b11) << 2
        | (flat[3::4] & 0b11)
    ).view((*shape[:-1], shape[-1] // 4))
    return packed_data
@implements([torch.ops.aten._to_copy.default])
def to_copy(func, args, kwargs):
    """Handle ``aten._to_copy`` for UInt2Tensor.

    For a supported target dtype the packed payload is unpacked and returned
    as a plain tensor of that dtype; for any other dtype the UInt2Tensor is
    returned unchanged.

    Raises:
        NotImplementedError: if the dtype is unsupported and the argument is
            not a UInt2Tensor.
    """
    (tensor,) = args
    dtype = kwargs["dtype"]
    if dtype == torch.uint8:
        return unpack_uint2(tensor.elem).view(tensor.shape).view(torch.uint8)
    elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64):
        # Convert the unpacked payload directly. Calling ``.to`` on the
        # wrapper itself goes back through tensor dispatch rather than
        # touching ``elem``; this form matches the ``aten.to.dtype`` handler.
        return unpack_uint2(tensor.elem).to(torch.uint8).to(dtype)
    elif isinstance(tensor, UInt2Tensor):
        return tensor
    raise NotImplementedError(f"to_copy {dtype} not supported")
@implements([torch.ops.aten.slice.Tensor])
def slice_tensor(func, args, kwargs):
    """Handle ``aten.slice.Tensor`` for UInt2Tensor.

    Slicing any dimension other than the last is forwarded to the packed
    payload unchanged. Slicing the last (packed) dimension is only supported
    with ``step == 1`` and bounds that are multiples of 4, since four logical
    elements share one stored byte.

    Raises:
        NotImplementedError: for a non-unit step on the last dimension.
        AssertionError: if the bounds on the last dimension are not
            multiples of 4.
    """
    tensor, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    ndim = tensor.dim()
    if dim < 0:
        dim += ndim

    if dim != ndim - 1:
        # Unpacked dimension: slice the payload along the *requested* dim
        # (indexing with ``[..., start:end:step]`` would hit the last dim).
        return UInt2Tensor(
            torch.ops.aten.slice.Tensor(tensor.elem, dim, start, end, step)
        )

    if step != 1:
        raise NotImplementedError(f"slice step={step}")
    length = tensor.shape[dim]
    # ``start``/``end`` may be omitted (None), and an open-ended slice may
    # carry an ``end`` beyond the dimension; normalise both before the
    # divisibility checks.
    start = 0 if start is None else start
    end = length if end is None else min(end, length)
    assert start % 4 == 0, start
    assert end % 4 == 0, end
    sliced_elem = tensor.elem[..., start // 4 : end // 4]
    return UInt2Tensor(sliced_elem)
def unpack_uint3(uint8_data: torch.Tensor) -> torch.Tensor:
    """Unpack every 3 bytes into eight 3-bit values.

    3 -> 8
    01234567|01234567|01234567
    AAABBBCC|CDDDEEEF|FFGGGHHH

    The last dimension must be divisible by 3; it grows by a factor of 8/3.
    Leading dimensions are preserved.

    Raises:
        AssertionError: if the last dimension is not a multiple of 3.
    """
    shape = uint8_data.shape
    assert shape[-1] % 3 == 0, f"{shape} last dim not divisible by three"
    data = uint8_data.to(torch.uint8).contiguous()
    # Group the bytes three at a time over the flattened data. Striding the
    # unflattened tensor (``data[::3]``) walks dim 0 and is only correct for
    # 1-D input.
    first, second, third = data.view(data.numel() // 3, 3).unbind(dim=-1)

    return torch.stack(
        (
            (first >> 5) & 0b111,
            (first >> 2) & 0b111,
            ((first & 0b11) << 1) | ((second >> 7) & 0b1),
            (second >> 4) & 0b111,
            (second >> 1) & 0b111,
            ((second & 0b1) << 2) | ((third >> 6) & 0b11),
            (third >> 3) & 0b111,
            third & 0b111,
        ),
        dim=-1,
    ).view((*shape[:-1], shape[-1] // 3 * 8))
def pack_uint5(uint8_data: torch.Tensor) -> torch.Tensor:
    """Squeeze the low 5 bits of each group of 8 values into 5 bytes.

    8 -> 5
    01234567|01234567|01234567|01234567|01234567
    AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH

    Each output byte is assembled from the tail of one value and the head of
    the next, most significant bits first. The last dimension must be a
    multiple of 8 and becomes 5/8 of its original length.
    """
    shape = uint8_data.shape
    assert (
        shape[-1] % 8 == 0
    ), f"Input last dimension should be divisible by 8, but got {shape[-1]}"

    # One column per position inside a group of eight values.
    a, b, c, d, e, f, g, h = uint8_data.contiguous().view(-1, 8).unbind(dim=1)

    byte0 = ((a & 0b00011111) << 3) | ((b & 0b00011100) >> 2)
    byte1 = ((b & 0b00000011) << 6) | ((c & 0b00011111) << 1) | ((d & 0b10000) >> 4)
    byte2 = ((d & 0b00001111) << 4) | ((e & 0b00011110) >> 1)
    byte3 = ((e & 0b00000001) << 7) | ((f & 0b00011111) << 2) | ((g & 0b0011000) >> 3)
    byte4 = ((g & 0b00000111) << 5) | (h & 0b00011111)

    packed_shape = (*shape[:-1], shape[-1] // 8 * 5)
    return torch.stack((byte0, byte1, byte2, byte3, byte4), dim=-1).view(packed_shape)
def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor:
    """Expand each group of 3 packed bytes into four 6-bit values.

    01234567|01234567|01234567
    AAAAAABB|BBBBCCCC|CCDDDDDD

    The last dimension must be a multiple of 3 and becomes 4/3 of its
    original length.
    """
    shape = packed_data.shape
    assert (
        shape[-1] % 3 == 0
    ), f"Input last dimension should be divisible by 3, but got {shape[-1]}"

    # One column per byte position inside a packed triple.
    lo, mid, hi = packed_data.contiguous().view(-1, 3).unbind(dim=1)

    val_a = (lo >> 2) & 0b00111111
    val_b = ((lo & 0b00000011) << 4) | ((mid >> 4) & 0b00001111)
    val_c = ((mid & 0b00001111) << 2) | ((hi >> 6) & 0b00000011)
    val_d = hi & 0b00111111

    unpacked_shape = (*shape[:-1], shape[-1] // 3 * 4)
    return torch.stack((val_a, val_b, val_c, val_d), dim=-1).view(unpacked_shape)
def unpack_uint7(packed_data: torch.Tensor) -> torch.Tensor:
    """Expand each group of 7 packed bytes into eight 7-bit values.

    01234567|01234567|01234567|01234567|01234567|01234567|01234567
    AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH

    The last dimension must be a multiple of 7 and becomes 8/7 of its
    original length.
    """
    shape = packed_data.shape
    assert (
        shape[-1] % 7 == 0
    ), f"Input last dimension should be divisible by 7, but got {shape[-1]}"

    # One column per byte position inside a packed group of seven.
    p0, p1, p2, p3, p4, p5, p6 = packed_data.contiguous().view(-1, 7).unbind(dim=1)

    values = [
        (p0 >> 1) & 0b01111111,
        ((p0 & 0b00000001) << 6) | ((p1 >> 2) & 0b01111111),
        ((p1 & 0b00000011) << 5) | ((p2 >> 3) & 0b01111111),
        ((p2 & 0b00000111) << 4) | ((p3 >> 4) & 0b01111111),
        ((p3 & 0b00001111) << 3) | ((p4 >> 5) & 0b01111111),
        ((p4 & 0b00011111) << 2) | ((p5 >> 6) & 0b01111111),
        ((p5 & 0b00111111) << 1) | ((p6 >> 7) & 0b01111111),
        p6 & 0b01111111,
    ]

    unpacked_shape = (*shape[:-1], shape[-1] // 7 * 8)
    return torch.stack(values, dim=-1).view(unpacked_shape)