From 6400a0ecffb6f14e9de20139aa5e4cace5acabdb Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Sun, 2 Jun 2024 23:44:40 -0400 Subject: [PATCH 01/11] untested unified pack/unpack --- torchao/prototype/common/bitpacking.py | 142 +++++++++++++------------ 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 35e471c347..6a8a92f636 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -1,15 +1,20 @@ import torch -from functools import reduce +from typing import Optional, Union +def mod_shape(shape, mod, dim): + """changes a select dimension of the input shape to mod""" + return (*shape[:dim], mod, *shape[dim+1:]) - -def unpack(data, data_size, by_rows = True, device="cuda"): +def unpack(data: torch.Tensor, + element_dtype: torch.dtype, + dimension: Optional[int] = 0, + device: Optional[str] ="cuda") -> torch.Tensor: """ Unpacks small dtype elements from a larger dtype. Inputs: - data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. - data_size: int - the size of the small dtype in bits. + data: - a tensor of packed elements of a small dtype within a larger dtype. + data_size: - the size of the small dtype in bits. optional: by_rows: bool - specifies whether to unpack... @@ -21,23 +26,46 @@ def unpack(data, data_size, by_rows = True, device="cuda"): Returns: torch.Tensor - a tensor of the unpacked elements. """ - if by_rows: - return _unpack_by_rows(data, data_size, device) - else: - return _unpack_by_cols(data, data_size) - -def pack(data, container_size, data_size, by_rows = True, device="cuda"): + element_size = torch.iinfo(element_dtype).bits + container_size = torch.iinfo(data.dtype).bits + scale = container_size // element_size + unpacked = _unpack(data, data_size, scale, dim, device) + if element_dtype == "trinary": + unpacked = unpacked.to(torch.int8) - 1 + return unpacked + +def _unpack(data, data_size, container_size, scale ,dim, device): + shape = data.shape + unpacked_data = torch.zeros(mod_shape(shape, shape[dim]*scale, dim), dtype=data.dtype).to(device) + nbits = (1 << data_size) - 1 # mask for the last dtype_size bits + unpacked_data = [] + for i in range(scale): + # add the next nbits to the unpacked data + shift_amt = container_size - data_size * (i + 1) + unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) + + # stack the unpacked data and reshape to the original shape + torch.stack(unpacked_data,dim=dim).view(mod_shape(shape,scale*shape[dim], dim)) + +def pack(data: torch.Tensor, + element_dtype: Union[torch.dtype, str], # accepting strings for trinary until thats added to torch + dimension: Optional[int] = 0, + container_dtype: Optional[torch.dtype] = None, + device: Optional[str] = "cuda") -> torch.Tensor: """ - Packs small dtype elements into a larger dtype. - Pads rows to be divisible by the scale. + Packs small dtype elements into a container of a larger dtype. + **Pads rows to be divisible by the scale** + TODO: support something like packing 8 uint 3s into 3 uint8s Inputs: - data: torch.Tensor - a tensor of unpacked elements of a small dtype. - container_size: int - the size of the large dtype in bits. - data_size: int - the size of the small dtype in bits. + data: - a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. 
+ dimension: - the dimension to pack along + element_dtype: specify the dtype of the elements to pack optional: - by_rows: bool - specifies whether to pack values... + container_dtype: specify the dtype of the container if the data is not already inside a tensor of that dtype + + by_rows: specifies whether to pack values... by rows: tensor(n,m) -> tensor(n//scale, m) or by columns: tensor(n,m) -> tensor(n,m//scale) @@ -46,56 +74,38 @@ def pack(data, container_size, data_size, by_rows = True, device="cuda"): Returns: torch.Tensor - a tensor of packed elements. """ - if by_rows: - return _pack_by_rows(data, container_size, data_size, device) - else: - return _pack_by_cols(data, container_size, data_size, device) - -def _unpack_by_rows(data, data_size, device) -> torch.Tensor: - shape = data.shape - scale = data.element_size() * 8 // data_size - - unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).to(device) - nbits = (1 << data_size) - 1 # mask for the last dtype_size bits - for i in range(scale): - shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint - unpacked_data[i::scale] = ((data >> shift_amt) & (nbits)) - return unpacked_data + if container_dtype is not None: + data = data.to(container_dtype) + + if type(element_dtype) == str: + if element_dtype == "trinary": + data = data+1 + else: + raise ValueError(f"element_dtype {element_dtype} not supported") + + element_size = torch.iinfo(element_dtype).bits + container_size = torch.iinfo(data.dtype).bits + scale = container_size // element_size + assert data.shape[dimension] >= scale, f"not enough values to pack along dimension {dimension} ({data.shape[dimension]}) < scale ({scale})" + return _pack_uints(data, container_size, element_dtype, scale, dimension, device) -def _unpack_by_cols(data, data_size) -> torch.Tensor: - shape = data.shape - scale = data.element_size() * 8 // data_size - unpacked_data = [] - nbits = (1 << data_size) - 1 # mask for the last dtype_size bits + + +def _pack(data, container_size, data_size, scale, dim, device) -> torch.Tensor: + #pad dimension to be divisible by scale + if data.shape[dimension] % scale != 0: + padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device) + data = torch.cat([data, padding], dim=dim).cuda() + + packed = torch.zeros(mod_shape(data.shape, data.shape[dim] // scale, dim), dtype=data.dtype).to(device) for i in range(scale): - shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint - unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) - return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape + torch.arange(start=i, stop=data.shape[k], step=scale) + packed |= torch.index_select(data, dim=k, index=indices) << container_size-data_size*(i+1) + return packed -def _pack_by_rows(data, container_size, data_size, device) -> torch.Tensor: - - scale = container_size // data_size - assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})" - assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})" - # pad the data to be divisible by scale - if data.shape[0] % scale != 0: - padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).to(device) - data = torch.cat([data, padding], dim=0).cuda() - - shape = data.shape - ret = 
reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)]) - return ret.view(shape[0] // scale, *shape[1:]).to(device) -def _pack_by_cols(data, container_size, data_size, device) -> torch.Tensor: - scale = container_size // data_size - assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" - # pad the data to be divisible by scale - if data.shape[-1] % scale != 0: - padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).to(device) - data = torch.cat([data, padding], dim=-1).cuda() - - shape = data.shape - data = data.contiguous().view(-1) - #shift the data to the different indexes within the larger dtype and then union them together - ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]) - return ret.view(*shape[:-1],shape[-1] // scale).to(device) \ No newline at end of file +test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() +packed = pack(test_tensor, torch.uint4) +unpacked = unpack(packed, torch.uint4) +unpadded = unpacked[:test_tensor.shape[0], ...] +assert(unpadded.allclose(test_tensor)) \ No newline at end of file From 889bd8f80e976ab69577c4b738e5af79bbb3b647 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 3 Jun 2024 13:28:29 -0400 Subject: [PATCH 02/11] tests written, issues fixed --- test/prototype/test_bitpacking.py | 87 +++++++------------ torchao/prototype/common/bitpacking.py | 114 ++++++++++++++----------- 2 files changed, 93 insertions(+), 108 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index c1b60e07f8..c58e9fb772 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -1,5 +1,5 @@ import torch -from torchao.prototype.common.bitpacking import pack, unpack +from torchao.prototype.common.bitpacking import pack, unpack, dtype_to_bits import pytest from torch.utils._triton import has_triton from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 @@ -7,64 +7,35 @@ if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -def test_uint4_to_uint8_CPU(): - test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) - packed = pack(test_tensor, 8, 4, device='cpu') - unpacked = unpack(packed, 4, device='cpu') - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) +def test_trinary_to_uint8_CPU(): + test_tensor = torch.randint(-1, 1, (4, 4, 4), dtype=torch.int32) + for i in range(len(test_tensor.shape)): + packed = pack(test_tensor, "trinary", dimension = i, container_dtype = torch.uint8, device='cpu') + unpacked = unpack(packed, "trinary", dimension = i, device='cpu') + assert(unpacked.to(torch.int32).allclose(test_tensor)) -def test_uint3_to_int16_col_wise_cpu(): - test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16) - packed = pack(test_tensor,16, 3, False, device='cpu') - unpacked = unpack(packed, 3, False, device='cpu') - unpadded = unpacked[:test_tensor.shape[0], ...] 
- assert(unpadded.allclose(test_tensor)) - +def test_to_uint8_CPU(): + for dtype in {torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7}: + test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), (4, 4, 4), dtype=torch.uint8) + for i in range(len(test_tensor.shape)): + packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8, device='cpu') + unpacked = unpack(packed, dtype, dimension = i, device='cpu') + assert unpacked.to(dtype).allclose(test_tensor), f"Failed for {dtype} on dim {i}" + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_uint4_to_uint8(): - test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 4) - unpacked = unpack(packed, 4) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -def test_uint4_to_uint8_compile(): - torch._dynamo.config.specialize_int = True - pack_compiled = torch.compile(pack, fullgraph=True) - unpack_compiled = torch.compile(unpack, fullgraph=True) - test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() - packed = pack_compiled(test_tensor, 8, 4) - unpacked = unpack_compiled(packed, 4) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_uint3_to_int16(): - test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() - packed = pack(test_tensor,16, 3) - unpacked = unpack(packed, 3) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -def test_uint2_to_uint8_col_wise_compile(): - torch._dynamo.config.specialize_int = True - pack_compiled = torch.compile(pack, fullgraph=True) - unpack_compiled = torch.compile(unpack, fullgraph=True) - test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() - packed = pack_compiled(test_tensor, 8, 2, False) - unpacked = unpack_compiled(packed,2, False) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) +def test_trinary_to_uint8(): + test_tensor = torch.randint(-1, 1, (4, 4, 4), dtype=torch.int32).cuda() + for i in range(len(test_tensor.shape)): + packed = pack(test_tensor, "trinary", dimension = i, container_dtype = torch.uint8) + unpacked = unpack(packed, "trinary", dimension = i) + assert(unpacked.to(torch.int32).allclose(test_tensor)) + print('trinary passed') @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_uint3_to_int16_col_wise(): - test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() - packed = pack(test_tensor,16, 3, False) - unpacked = unpack(packed, 3, False) - unpadded = unpacked[:test_tensor.shape[0], ...] 
- assert(unpadded.allclose(test_tensor)) \ No newline at end of file +def test_to_uint8(): + for dtype in {torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7}: + test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), (4, 4, 4), dtype=torch.uint8).cuda() + for i in range(len(test_tensor.shape)): + packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8) + unpacked = unpack(packed, dtype, dimension = i) + assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}" \ No newline at end of file diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 6a8a92f636..205af9bbdf 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -5,47 +5,72 @@ def mod_shape(shape, mod, dim): """changes a select dimension of the input shape to mod""" return (*shape[:dim], mod, *shape[dim+1:]) + +def dtype_to_bits(dtype): + '''returns the number of bits in a dtype''' + if dtype in {torch.uint2, 'trinary'}: + return 2 + elif dtype == torch.uint3: + return 3 + elif dtype == torch.uint4: + return 4 + elif dtype == torch.uint5: + return 5 + elif dtype == torch.uint6: + return 6 + elif dtype == torch.uint7: + return 7 + elif dtype in {torch.uint8, torch.int8}: + return 8 + elif dtype in {torch.uint16, torch.int16, torch.float16}: + return 16 + elif dtype in {torch.uint32, torch.int32, torch.float32}: + return 32 + elif dtype == {torch.uint64, torch.int64, torch.float64}: + return 64 + else: + raise ValueError(f"dtype {dtype} not supported (yet)") + def unpack(data: torch.Tensor, - element_dtype: torch.dtype, + element_dtype: Union[torch.dtype, str], # accepting strings for trinary until thats added to torch dimension: Optional[int] = 0, device: Optional[str] ="cuda") -> torch.Tensor: """ Unpacks small dtype elements from a larger dtype. Inputs: - data: - a tensor of packed elements of a small dtype within a larger dtype. - data_size: - the size of the small dtype in bits. + data: - a tensor of packed elements + element_dtype: - the dtype of the elements to unpack optional: - by_rows: bool - specifies whether to unpack... - by rows: tensor(n,m) -> tensor(n*scale, m) - or by columns: tensor(n,m) -> tensor(n,m*scale) - - defaults to rows because quantization is typically done by rows - but choose the version which matches how you quantize as this improves memory accesses/performance + dimension: - the dimension to unpack along + Returns: torch.Tensor - a tensor of the unpacked elements. 
""" - element_size = torch.iinfo(element_dtype).bits - container_size = torch.iinfo(data.dtype).bits + container_size = dtype_to_bits(data.dtype) + element_size = dtype_to_bits(element_dtype) scale = container_size // element_size - unpacked = _unpack(data, data_size, scale, dim, device) + + unpacked = _unpack(data, element_size, container_size, scale, dimension, device) + if element_dtype == "trinary": unpacked = unpacked.to(torch.int8) - 1 return unpacked -def _unpack(data, data_size, container_size, scale ,dim, device): +def _unpack(data, element_size, container_size, scale ,dim, device): shape = data.shape unpacked_data = torch.zeros(mod_shape(shape, shape[dim]*scale, dim), dtype=data.dtype).to(device) - nbits = (1 << data_size) - 1 # mask for the last dtype_size bits - unpacked_data = [] + nbits = (1 << element_size) - 1 # mask for the last dtype_size bits for i in range(scale): - # add the next nbits to the unpacked data - shift_amt = container_size - data_size * (i + 1) - unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) - + shift_amt = container_size - element_size * (i + 1) + slices = [slice(None)] * unpacked_data.ndim + slices[dim] = slice(i, None, scale) + unpacked_data[slices] = ((data >> shift_amt) & (nbits)).to(data.dtype) + # stack the unpacked data and reshape to the original shape - torch.stack(unpacked_data,dim=dim).view(mod_shape(shape,scale*shape[dim], dim)) + return unpacked_data.view(mod_shape(shape,scale*shape[dim], dim)) + def pack(data: torch.Tensor, element_dtype: Union[torch.dtype, str], # accepting strings for trinary until thats added to torch @@ -58,54 +83,43 @@ def pack(data: torch.Tensor, TODO: support something like packing 8 uint 3s into 3 uint8s Inputs: - data: - a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. - dimension: - the dimension to pack along - element_dtype: specify the dtype of the elements to pack + data: a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. + dimension: the dimension to pack along + element_dtype: the dtype of the elements to pack optional: container_dtype: specify the dtype of the container if the data is not already inside a tensor of that dtype - - by_rows: specifies whether to pack values... - by rows: tensor(n,m) -> tensor(n//scale, m) - or by columns: tensor(n,m) -> tensor(n,m//scale) + defaults to rows because quantization is typically done by rows but choose the version which matches how you quantize as this improves memory accesses/performance Returns: torch.Tensor - a tensor of packed elements. 
""" + if element_dtype == "trinary": + data = data + 1 + if container_dtype is not None: data = data.to(container_dtype) - - if type(element_dtype) == str: - if element_dtype == "trinary": - data = data+1 - else: - raise ValueError(f"element_dtype {element_dtype} not supported") - - element_size = torch.iinfo(element_dtype).bits - container_size = torch.iinfo(data.dtype).bits + + container_size = dtype_to_bits(data.dtype) + element_size = dtype_to_bits(element_dtype) scale = container_size // element_size + assert data.shape[dimension] >= scale, f"not enough values to pack along dimension {dimension} ({data.shape[dimension]}) < scale ({scale})" - return _pack_uints(data, container_size, element_dtype, scale, dimension, device) + return _pack(data, container_size, element_size, scale, dimension, device) -def _pack(data, container_size, data_size, scale, dim, device) -> torch.Tensor: +def _pack(data, container_size, element_size, scale, dim, device) -> torch.Tensor: #pad dimension to be divisible by scale - if data.shape[dimension] % scale != 0: + if data.shape[dim] % scale != 0: padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device) - data = torch.cat([data, padding], dim=dim).cuda() + data = torch.cat([data, padding], dim=dim).to(device) packed = torch.zeros(mod_shape(data.shape, data.shape[dim] // scale, dim), dtype=data.dtype).to(device) for i in range(scale): - torch.arange(start=i, stop=data.shape[k], step=scale) - packed |= torch.index_select(data, dim=k, index=indices) << container_size-data_size*(i+1) - return packed - - -test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() -packed = pack(test_tensor, torch.uint4) -unpacked = unpack(packed, torch.uint4) -unpadded = unpacked[:test_tensor.shape[0], ...] 
-assert(unpadded.allclose(test_tensor)) \ No newline at end of file + slices = [slice(None)] * packed.ndim + slices[dim] = slice(i, None, scale) + packed |= data[slices] << container_size-element_size*(i+1) + return packed \ No newline at end of file From 55d0db80640d324d508634a18c2f96854651cdfa Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 3 Jun 2024 13:41:12 -0400 Subject: [PATCH 03/11] removed conversion --- test/prototype/test_bitpacking.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index c58e9fb772..8c37b2a0fe 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -20,7 +20,7 @@ def test_to_uint8_CPU(): for i in range(len(test_tensor.shape)): packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8, device='cpu') unpacked = unpack(packed, dtype, dimension = i, device='cpu') - assert unpacked.to(dtype).allclose(test_tensor), f"Failed for {dtype} on dim {i}" + assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_trinary_to_uint8(): @@ -29,7 +29,6 @@ def test_trinary_to_uint8(): packed = pack(test_tensor, "trinary", dimension = i, container_dtype = torch.uint8) unpacked = unpack(packed, "trinary", dimension = i) assert(unpacked.to(torch.int32).allclose(test_tensor)) - print('trinary passed') @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_to_uint8(): @@ -38,4 +37,9 @@ def test_to_uint8(): for i in range(len(test_tensor.shape)): packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8) unpacked = unpack(packed, dtype, dimension = i) - assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}" \ No newline at end of file + assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}" + +test_trinary_to_uint8_CPU() +test_to_uint8_CPU() +test_trinary_to_uint8() +test_to_uint8() \ No newline at end of file From 1c0ca9d8b773d0acc7b7c1c315e86dcba5eb97e6 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 3 Jun 2024 16:24:33 -0400 Subject: [PATCH 04/11] works with compile + use pytest params --- test/prototype/test_bitpacking.py | 101 +++++++++++++++++-------- torchao/prototype/common/bitpacking.py | 16 +++- 2 files changed, 85 insertions(+), 32 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 8c37b2a0fe..a1e8594a07 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -7,39 +7,78 @@ if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -def test_trinary_to_uint8_CPU(): - test_tensor = torch.randint(-1, 1, (4, 4, 4), dtype=torch.int32) - for i in range(len(test_tensor.shape)): - packed = pack(test_tensor, "trinary", dimension = i, container_dtype = torch.uint8, device='cpu') - unpacked = unpack(packed, "trinary", dimension = i, device='cpu') - assert(unpacked.to(torch.int32).allclose(test_tensor)) +dtypes = (torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, "trinary") +expected_pack_size = {torch.uint2: 1, torch.uint3: 2, torch.uint4: 2, torch.uint5: 4, torch.uint6: 4, torch.uint7: 4, "trinary": 1} +dimensions = (0, 1, 2) + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("dim", dimensions) +def test_CPU(dtype, 
dim): + shape = [4, 4, 4] + if dtype == "trinary": + test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu') + else: + test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8, device='cpu') + + packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8, device='cpu') + assert(packed.shape[dim] == expected_pack_size[dtype]) + unpacked = unpack(packed, dtype, dimension = dim, device='cpu') + assert(unpacked.allclose(test_tensor)) -def test_to_uint8_CPU(): - for dtype in {torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7}: - test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), (4, 4, 4), dtype=torch.uint8) - for i in range(len(test_tensor.shape)): - packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8, device='cpu') - unpacked = unpack(packed, dtype, dimension = i, device='cpu') - assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_trinary_to_uint8(): - test_tensor = torch.randint(-1, 1, (4, 4, 4), dtype=torch.int32).cuda() - for i in range(len(test_tensor.shape)): - packed = pack(test_tensor, "trinary", dimension = i, container_dtype = torch.uint8) - unpacked = unpack(packed, "trinary", dimension = i) - assert(unpacked.to(torch.int32).allclose(test_tensor)) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("dim", dimensions) +def test_GPU(dtype, dim): + shape = [4, 4, 4] + if dtype == "trinary": + test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() + else: + test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda() + + packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8) + assert(packed.shape[dim] == expected_pack_size[dtype]) + unpacked = unpack(packed, dtype, dimension = dim) + assert(unpacked.allclose(test_tensor)) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_to_uint8(): - for dtype in {torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7}: - test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), (4, 4, 4), dtype=torch.uint8).cuda() - for i in range(len(test_tensor.shape)): - packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8) - unpacked = unpack(packed, dtype, dimension = i) - assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}" - -test_trinary_to_uint8_CPU() -test_to_uint8_CPU() -test_trinary_to_uint8() -test_to_uint8() \ No newline at end of file +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("dim", dimensions) +def test_compile(dtype, dim): + pack_compile = torch.compile(pack, fullgraph=True) + unpack_compile = torch.compile(unpack, fullgraph=True) + + shape = [4, 4, 4] + if dtype == "trinary": + test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() + else: + test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda() + + packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8) + assert(packed.shape[dim] == expected_pack_size[dtype]) + unpacked = unpack(packed, dtype, dimension = dim) + assert(unpacked.allclose(test_tensor)) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") 
+@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("dim", dimensions) +def test_padding(dtype, dim): + pack_compile = torch.compile(pack, fullgraph=True) + unpack_compile = torch.compile(unpack, fullgraph=True) + + shape =[4, 4, 4] + shape[dim] = 5 + + if dtype == "trinary": + test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() + else: + test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda() + + packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8) + assert(packed.shape[dim] == expected_pack_size[dtype]+1) # +1 for this scenario + unpacked = unpack(packed, dtype, dimension = dim) + slices = [slice(None)] * packed.ndim + slices[dim] = slice(None, 5) + assert(unpacked[slices].allclose(test_tensor)) \ No newline at end of file diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 205af9bbdf..85f4662e56 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -122,4 +122,18 @@ def _pack(data, container_size, element_size, scale, dim, device) -> torch.Tenso slices = [slice(None)] * packed.ndim slices[dim] = slice(i, None, scale) packed |= data[slices] << container_size-element_size*(i+1) - return packed \ No newline at end of file + return packed + + +# shape = [5, 1] +# dtype= torch.uint2 +# test_tensor = torch.randint(0, 2, shape, dtype=torch.uint8).cuda() +# print(test_tensor) +# packed = pack(test_tensor, dtype, dimension = 0, container_dtype = torch.uint8) +# print(packed) +# unpacked = unpack(packed, dtype, dimension = 0) + +# slices = [slice(None)] * packed.ndim +# slices[0] = slice(None, 4+1) +# print(unpacked[slices]) +# assert(unpacked[slices].allclose(test_tensor)) \ No newline at end of file From adb3b34d0de2dbf041ef19a7fce3e19d2776914f Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 7 Jun 2024 18:03:57 -0400 Subject: [PATCH 05/11] added hqq int4 fp16 mixed matmul benchmark for pack --- benchmarks/benchmark_bitpacking.py | 275 +++++++++++++++++++++++++ test/prototype/test_bitpacking.py | 90 ++++---- torchao/prototype/common/bitpacking.py | 140 ++++++------- 3 files changed, 391 insertions(+), 114 deletions(-) create mode 100644 benchmarks/benchmark_bitpacking.py diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py new file mode 100644 index 0000000000..50b46ca291 --- /dev/null +++ b/benchmarks/benchmark_bitpacking.py @@ -0,0 +1,275 @@ +# from torchao.quantization.quant_primitives import dynamically_quantize_per_channel +from torchao.prototype.common.bitpacking import pack, unpack +from math import log +# from torchao.utils import benchmark_utils +import torch +pack = torch.compile(pack, fullgraph=True) + +def benchmark(function, num_runs, *args, **kwargs): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + for _ in range(num_runs): + function(*args, **kwargs) + + end_event.record() + torch.cuda.synchronize() + return start_event.elapsed_time(end_event) / num_runs + + + +def load4x(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda() + +def load2x(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda() + +def loadx(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() + +def unpack8to2(scale=1024): + fake_tensor 
= torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() + unpacked_tensor = unpack_c(fake_tensor, 2, dim=1) + +def unpack8to4(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() + unpacked_tensor = unpack_c(fake_tensor, 4, dim=1) + +def t8to4wmm(scale=1024): + fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda() + unpacked_tensor = unpack_c(fake_tensor, 4, dim=1) + +def test_iso_bitpack(): + torch._dynamo.config.specialize_int = True + # _unpack_c = torch.compile(_unpack, fullgraph=True) + unpack_c = torch.compile(unpack, fullgraph=True) + + scale = [16,64,256,1024,4096] + load4x_times = [] + unpack8to2_times = [] + load2x_times = [] + unpack8to4_times = [] + for s in scale: + res = benchmark(load4x, 50, scale=s) + load4x_times.append(res) + print(f"load(1, {4*s},{s}) time: {res} ms") + + res=benchmark(unpack8to2, 50, scale=s) + unpack8to2_times.append(res) + print(f"load(1, {s},{s}) unpack uint2 time: {res} ms") + + res = benchmark(load2x, 50, scale=s) + load2x_times.append(res) + print(f"load(1, {2*s},{s}) time: {res} ms") + + res = benchmark(unpack8to4, 50, scale=s) + unpack8to4_times.append(res) + print(f"load(1, {s},{s}) unpack uint4 time: {res} ms") + print() + + # import matplotlib.pyplot as plt + # plt.plot(scale, load4x_times, label="load(1, 4x, x)") + # plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2") + # plt.plot(scale, load2x_times, label="load(1, 2x, x)") + # plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4") + # plt.xlabel("scale") + # plt.ylabel("time (ms)") + # plt.yscale("log") + # plt.legend() + # plt.savefig("benchmark_bitpacking.png") + +import hqq +import hqq.core.quantize as hqq_quantize +HQQLinear = hqq_quantize.HQQLinear +BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig + +import itertools +from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm + +# Test configs +SHAPES = [ + [16, 128, 128], + [16, 4096, 4096], +] + +DTYPES = [torch.bfloat16, torch.float16] +GROUP_SIZES = [64, 128] +AXES = [1] # Only axis = 1 supported +TRANSPOSED = [False, True] +TRITON_KERNEL_TYPE = ["compute_bound"] # ["max_autotune", "compute_bound"] + +TEST_CONFIGS = list( + itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE) +) + +BASE_QUANT_CONFIG = { + "optimize": True, + "view_as_float": False, + "nbits": 4, + "bitpack": False, + "axis": 1, +} + + +def check(expected, actual, msg="", max_diff=1e-3, verbose=False): + passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff) + if verbose: + max_err = (expected - actual).abs().max() + if not passed: + print_msg = f"{msg}:\nFailed! Max error: {max_err}" + try: + from termcolor import colored + except ImportError: + print(print_msg) + else: + print(colored(print_msg, "red", attrs=["bold"])) + + else: + print_msg = f"{msg}:\nPassed! 
Max error: {max_err}" + try: + from termcolor import colored + except ImportError: + print(print_msg) + else: + print(colored(print_msg, "green", attrs=["bold"])) + + return passed + + +def test_mixed_mm( + shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True +): + qcfg = { + **BASE_QUANT_CONFIG, + **dict(group_size=group_size, axis=axis), + } + M, N, K = shape + + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + + quant_config = BaseQuantizeConfig( + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False + ) + quant_config.update({"weight_quant_params": qcfg}) + hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) + W_q, meta = hqq_linear.W_q, hqq_linear.meta + W_q = W_q.to(dtype=quant_dtype) + W_q = ( + W_q.reshape(meta["shape"]) + if quant_config["weight_quant_params"]["bitpack"] == False + else W_q + ) + W_dq = hqq_linear.dequantize() + + scales, zeros = meta["scale"], meta["zero"] + scales = scales.reshape(N, -1) + zeros = zeros.reshape(N, -1) + if pack_fn: + packed_w = pack(W_q.T,4,dim=0,order=False) + else: + packed_w = pack_2xint4(W_q.T) + # print(W_q.T[0:5,0:5], W_q.T.shape) + # print(packed_w[0:5,0:5], W_q.T.shape) + # print(packed_w2[0:5,0:5], W_q.T.shape) + if transposed: + x = torch.randn(M, N, dtype=dtype, device="cuda") + hqq_out = x @ W_dq + + tt_out = triton_mixed_mm( + x, + packed_w, + scales.T, + zeros.T, + transposed=True, + group_size=group_size, + fp8_fast_accum=False, + kernel_type=kernel_type, + ) + + else: + x = torch.randn(M, K, dtype=dtype, device="cuda") + hqq_out = x @ W_dq.T + + tt_out = triton_mixed_mm( + x, + packed_w, + scales.T, + zeros.T, + transposed=False, + group_size=group_size, + fp8_fast_accum=False, + kernel_type=kernel_type, + ) + # assert check( + # hqq_out, + # tt_out, + # max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3, + # verbose=True, + # ) + + +if __name__ == "__main__": + # _test_mixed_mm(transposed=False) + shapes = [ + [16, 128, 128], + [16, 4096, 4096], + ] + group_sizes = [64, 128] + shape = [16, 128, 128] + group_size = 64 + + for i in range(2): + shape = shapes[i] + group_size = group_sizes[i] + print("linear layer size: ", shape) + print("group size: ", group_size) + # run once to compile + test_mixed_mm( + shape, + group_size, + 1, + torch.float16, + True, + "compute_bound", + torch.uint8, + ) + # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8 + print("pack time (ms): ", benchmark(test_mixed_mm, 10, + shape, + group_size, + 1, + torch.float16, + True, + "compute_bound", + torch.uint8)) + + print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 10, + shape, + group_size, + 1, + torch.float16, + True, + "compute_bound", #max autotune doesnt work? + torch.uint8, + pack_fn=False)) + # print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 10, + # shape, + # group_size, + # 1, + # torch.float16, + # True, + # "compute_bound", #max autotune doesnt work? 
+ # torch.uint8, + # pack_fn=False)) + # print("pack time (ms): ", benchmark(test_mixed_mm, 10, + # shape, + # group_size, + # 1, + # torch.float16, + # True, + # "compute_bound", + # torch.uint8)) + print("") + diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index a1e8594a07..0434a2e202 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -1,5 +1,5 @@ import torch -from torchao.prototype.common.bitpacking import pack, unpack, dtype_to_bits +from torchao.prototype.common.bitpacking import pack, unpack import pytest from torch.utils._triton import has_triton from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 @@ -7,22 +7,23 @@ if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -dtypes = (torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, "trinary") -expected_pack_size = {torch.uint2: 1, torch.uint3: 2, torch.uint4: 2, torch.uint5: 4, torch.uint6: 4, torch.uint7: 4, "trinary": 1} -dimensions = (0, 1, 2) +dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4)) +dimensions = (2, 1, 0) + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) def test_CPU(dtype, dim): + element_bit_width, element_type,expected_pack_size = dtype shape = [4, 4, 4] - if dtype == "trinary": + if element_type == "trinary": test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu') else: - test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8, device='cpu') + test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu') - packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8, device='cpu') - assert(packed.shape[dim] == expected_pack_size[dtype]) - unpacked = unpack(packed, dtype, dimension = dim, device='cpu') + packed = pack(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.uint8, device='cpu') + assert(packed.shape[dim] == expected_pack_size) + unpacked = unpack(packed, element_bit_width, element_type=element_type, dim = dim, device='cpu') assert(unpacked.allclose(test_tensor)) @@ -30,55 +31,64 @@ def test_CPU(dtype, dim): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) def test_GPU(dtype, dim): + element_bit_width, element_type,expected_pack_size = dtype shape = [4, 4, 4] - if dtype == "trinary": + if element_type == "trinary": test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() else: - test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() - packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8) - assert(packed.shape[dim] == expected_pack_size[dtype]) - unpacked = unpack(packed, dtype, dimension = dim) + packed = pack(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.uint8) + assert(packed.shape[dim] == expected_pack_size) + unpacked = unpack(packed, element_bit_width, element_type=element_type, dim = dim) assert(unpacked.allclose(test_tensor)) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) -def 
test_compile(dtype, dim): - pack_compile = torch.compile(pack, fullgraph=True) - unpack_compile = torch.compile(unpack, fullgraph=True) +def test_padding(dtype, dim): + element_bit_width, element_type,expected_pack_size = dtype + torch._dynamo.config.specialize_int = True + shape =[4, 4, 4] + shape[dim] = 5 - shape = [4, 4, 4] - if dtype == "trinary": + if element_type == "trinary": test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() else: - test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() - packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8) - assert(packed.shape[dim] == expected_pack_size[dtype]) - unpacked = unpack(packed, dtype, dimension = dim) - assert(unpacked.allclose(test_tensor)) - + packed = pack(test_tensor, + element_bit_width, + element_type=element_type, + dim = dim, + container_dtype = torch.uint8, + pad= True) + assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}" # +1 for this scenario + unpacked = unpack(packed, element_bit_width, element_type=element_type, dim = dim) + slices = [slice(None)] * packed.ndim + slices[dim] = slice(None, 5) + assert unpacked[slices].allclose(test_tensor) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) -def test_padding(dtype, dim): - pack_compile = torch.compile(pack, fullgraph=True) - unpack_compile = torch.compile(unpack, fullgraph=True) - - shape =[4, 4, 4] - shape[dim] = 5 - - if dtype == "trinary": +def test_compile(dtype, dim): + pack_compile = torch.compile(pack, fullgraph=True, dynamic=True) + unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True) + element_bit_width, element_type,expected_pack_size = dtype + torch._dynamo.config.specialize_int = True + shape = [4, 4, 4] + if element_type == "trinary": test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() else: - test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda() - packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8) - assert(packed.shape[dim] == expected_pack_size[dtype]+1) # +1 for this scenario - unpacked = unpack(packed, dtype, dimension = dim) - slices = [slice(None)] * packed.ndim - slices[dim] = slice(None, 5) - assert(unpacked[slices].allclose(test_tensor)) \ No newline at end of file + packed = pack_compile(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.int8) + assert(packed.shape[dim] == expected_pack_size) + unpacked = unpack_compile(packed, element_bit_width, element_type=element_type, dim = dim) + assert(unpacked.allclose(test_tensor)) + diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 85f4662e56..2278e6d27a 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -4,58 +4,36 @@ def mod_shape(shape, mod, dim): """changes a select dimension of the input shape to mod""" return (*shape[:dim], mod, *shape[dim+1:]) - - -def dtype_to_bits(dtype): - '''returns the number of bits in a dtype''' - if dtype in {torch.uint2, 'trinary'}: - return 2 - elif dtype == 
torch.uint3: - return 3 - elif dtype == torch.uint4: - return 4 - elif dtype == torch.uint5: - return 5 - elif dtype == torch.uint6: - return 6 - elif dtype == torch.uint7: - return 7 - elif dtype in {torch.uint8, torch.int8}: - return 8 - elif dtype in {torch.uint16, torch.int16, torch.float16}: - return 16 - elif dtype in {torch.uint32, torch.int32, torch.float32}: - return 32 - elif dtype == {torch.uint64, torch.int64, torch.float64}: - return 64 - else: - raise ValueError(f"dtype {dtype} not supported (yet)") def unpack(data: torch.Tensor, - element_dtype: Union[torch.dtype, str], # accepting strings for trinary until thats added to torch - dimension: Optional[int] = 0, + element_bit_width: int, + element_type: Optional[str] = None, + dim: Optional[int] = 0, + output_dtype: Optional[torch.dtype] = None, device: Optional[str] ="cuda") -> torch.Tensor: """ Unpacks small dtype elements from a larger dtype. Inputs: data: - a tensor of packed elements - element_dtype: - the dtype of the elements to unpack + element_bit_width: the size in bits of the elements to unpack optional: - dimension: - the dimension to unpack along - - + element_type: the dtype of the elements to unpack (uint,trinary,float, etc) + dimension: the dimension to unpack along + output_dtype: specify the dtype of the output tensor if it is not the same as the input tensor + Returns: torch.Tensor - a tensor of the unpacked elements. """ - container_size = dtype_to_bits(data.dtype) - element_size = dtype_to_bits(element_dtype) - scale = container_size // element_size - - unpacked = _unpack(data, element_size, container_size, scale, dimension, device) + container_size = torch.iinfo(data.dtype).bits + scale = container_size // element_bit_width - if element_dtype == "trinary": + unpacked = _unpack(data, element_bit_width, container_size, scale, dim, device) + if element_type == "trinary": unpacked = unpacked.to(torch.int8) - 1 + elif output_dtype is not None: + unpacked = unpacked.to(output_dtype) + return unpacked def _unpack(data, element_size, container_size, scale ,dim, device): @@ -73,9 +51,12 @@ def _unpack(data, element_size, container_size, scale ,dim, device): def pack(data: torch.Tensor, - element_dtype: Union[torch.dtype, str], # accepting strings for trinary until thats added to torch - dimension: Optional[int] = 0, + element_bit_width: int, + element_type: Optional[str] = None, + dim: Optional[int] = 0, container_dtype: Optional[torch.dtype] = None, + pad: Optional[bool] = False, + order: Optional[bool] = True, device: Optional[str] = "cuda") -> torch.Tensor: """ Packs small dtype elements into a container of a larger dtype. @@ -84,56 +65,67 @@ def pack(data: torch.Tensor, Inputs: data: a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. - dimension: the dimension to pack along + dim: the dimension to pack along element_dtype: the dtype of the elements to pack - - optional: container_dtype: specify the dtype of the container if the data is not already inside a tensor of that dtype - - - defaults to rows because quantization is typically done by rows - but choose the version which matches how you quantize as this improves memory accesses/performance + pad: if set to true, pads the dimension to be divisible by the scale + order: if set to true, packs elements such that the lower index elements occupy the most significant bits Returns: torch.Tensor - a tensor of packed elements. 
""" - if element_dtype == "trinary": + + if element_type == "trinary": data = data + 1 if container_dtype is not None: data = data.to(container_dtype) - container_size = dtype_to_bits(data.dtype) - element_size = dtype_to_bits(element_dtype) - scale = container_size // element_size + container_size = torch.iinfo(data.dtype).bits + scale = container_size // element_bit_width + + if pad and data.shape[dim] % scale != 0: + padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device) + data = torch.cat([data, padding], dim=dim).to(device) - assert data.shape[dimension] >= scale, f"not enough values to pack along dimension {dimension} ({data.shape[dimension]}) < scale ({scale})" - return _pack(data, container_size, element_size, scale, dimension, device) + + torch._assert(data.shape[dim] >= scale, f"not enough values to pack along dimension {dim}") + torch._assert(data.shape[dim] % scale == 0, "size of pack dimension not divisble by scale") + return _pack(data, container_size, element_bit_width, scale, dim, order, device) -def _pack(data, container_size, element_size, scale, dim, device) -> torch.Tensor: - #pad dimension to be divisible by scale - if data.shape[dim] % scale != 0: - padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device) - data = torch.cat([data, padding], dim=dim).to(device) - +def _pack(data, container_size, element_bit_width, scale, dim, order, device) -> torch.Tensor: packed = torch.zeros(mod_shape(data.shape, data.shape[dim] // scale, dim), dtype=data.dtype).to(device) + slices = [slice(None)] * packed.ndim for i in range(scale): - slices = [slice(None)] * packed.ndim slices[dim] = slice(i, None, scale) - packed |= data[slices] << container_size-element_size*(i+1) + if order: + packed |= data[slices] << container_size-element_bit_width*(i+1) + else: + packed |= data[slices] << element_bit_width*i return packed - -# shape = [5, 1] -# dtype= torch.uint2 -# test_tensor = torch.randint(0, 2, shape, dtype=torch.uint8).cuda() -# print(test_tensor) -# packed = pack(test_tensor, dtype, dimension = 0, container_dtype = torch.uint8) -# print(packed) -# unpacked = unpack(packed, dtype, dimension = 0) - -# slices = [slice(None)] * packed.ndim -# slices[0] = slice(None, 4+1) -# print(unpacked[slices]) -# assert(unpacked[slices].allclose(test_tensor)) \ No newline at end of file +if __name__ == '__main__': + pack_compile = torch.compile(pack, fullgraph=True) + unpack_compile = torch.compile(unpack, fullgraph=True) + torch._dynamo.config.specialize_int = True + element_bit_width = 2 + element_type = "trinary" + dim = 0 + shape =[4, 4, 4] + shape[dim] = 5 + + if element_type == "trinary": + test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() + else: + test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() + + packed = pack_compile(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.uint8, pad= True) + print(packed.shape) + assert(packed.shape[dim] == 2) # +1 for this scenario + unpacked = unpack_compile(packed, element_bit_width, element_type=element_type, dim = dim) + slices = [slice(None)] * packed.ndim + slices[dim] = slice(None, 5) + print(test_tensor, "\n", packed,"\n",unpacked[slices]) + assert(unpacked[slices].allclose(test_tensor)) + From cb14444de0af979a30bc6af819f77b4cd95364f4 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Fri, 7 Jun 2024 18:09:02 -0400 Subject: 
[PATCH 06/11] added more repeats for benchmark and removed unused vars --- benchmarks/benchmark_bitpacking.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py index 50b46ca291..eb9dcb590a 100644 --- a/benchmarks/benchmark_bitpacking.py +++ b/benchmarks/benchmark_bitpacking.py @@ -88,21 +88,6 @@ def test_iso_bitpack(): import itertools from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm -# Test configs -SHAPES = [ - [16, 128, 128], - [16, 4096, 4096], -] - -DTYPES = [torch.bfloat16, torch.float16] -GROUP_SIZES = [64, 128] -AXES = [1] # Only axis = 1 supported -TRANSPOSED = [False, True] -TRITON_KERNEL_TYPE = ["compute_bound"] # ["max_autotune", "compute_bound"] - -TEST_CONFIGS = list( - itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE) -) BASE_QUANT_CONFIG = { "optimize": True, @@ -236,7 +221,7 @@ def test_mixed_mm( torch.uint8, ) # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8 - print("pack time (ms): ", benchmark(test_mixed_mm, 10, + print("pack time (ms): ", benchmark(test_mixed_mm, 100, shape, group_size, 1, @@ -245,7 +230,7 @@ def test_mixed_mm( "compute_bound", torch.uint8)) - print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 10, + print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 100, shape, group_size, 1, From 496d83af4b9683d538794d440cb00cd80cb1a35a Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 10 Jun 2024 00:37:41 -0400 Subject: [PATCH 07/11] added 1 more benchmark and now tests pass --- benchmarks/benchmark_bitpacking.py | 50 ++++++++++++++++-------------- test/prototype/test_bitpacking.py | 12 +++++++ 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py index eb9dcb590a..cd4d8ab228 100644 --- a/benchmarks/benchmark_bitpacking.py +++ b/benchmarks/benchmark_bitpacking.py @@ -1,9 +1,10 @@ # from torchao.quantization.quant_primitives import dynamically_quantize_per_channel from torchao.prototype.common.bitpacking import pack, unpack +from torchao.dtypes.uint4 import unpack_uint4, pack_uint4 from math import log # from torchao.utils import benchmark_utils import torch -pack = torch.compile(pack, fullgraph=True) + def benchmark(function, num_runs, *args, **kwargs): torch.cuda.synchronize() @@ -18,8 +19,24 @@ def benchmark(function, num_runs, *args, **kwargs): torch.cuda.synchronize() return start_event.elapsed_time(end_event) / num_runs +def test_existing(): + def new_(): + fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda() + packed = pack(fake_tensor, 4, dim=1) + unpacked = unpack(packed, 4, dim=1) + def old_(): + fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda() + packed = pack_uint4(fake_tensor) + unpacked = unpack_uint4(packed) + new_ = torch.compile(new_, fullgraph=True) + old_ = torch.compile(old_, fullgraph=True) + new_() + old_() + print(f"new: {benchmark(new_, 1000)} ms ") + print(f"old: {benchmark(old_, 1000)} ms") - + + def load4x(scale=1024): fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda() @@ -123,7 +140,7 @@ def check(expected, actual, msg="", max_diff=1e-3, verbose=False): return passed -def test_mixed_mm( +def mixed_mm( shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True ): qcfg = { @@ -194,9 +211,8 @@ def 
test_mixed_mm( # verbose=True, # ) - -if __name__ == "__main__": - # _test_mixed_mm(transposed=False) + +def test_vs_hqqpack(): shapes = [ [16, 128, 128], [16, 4096, 4096], @@ -204,7 +220,7 @@ def test_mixed_mm( group_sizes = [64, 128] shape = [16, 128, 128] group_size = 64 - + pack = torch.compile(pack, fullgraph=True) for i in range(2): shape = shapes[i] group_size = group_sizes[i] @@ -239,22 +255,8 @@ def test_mixed_mm( "compute_bound", #max autotune doesnt work? torch.uint8, pack_fn=False)) - # print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 10, - # shape, - # group_size, - # 1, - # torch.float16, - # True, - # "compute_bound", #max autotune doesnt work? - # torch.uint8, - # pack_fn=False)) - # print("pack time (ms): ", benchmark(test_mixed_mm, 10, - # shape, - # group_size, - # 1, - # torch.float16, - # True, - # "compute_bound", - # torch.uint8)) print("") + +if __name__ == "__main__": + test_existing() diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 0434a2e202..144a767582 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -10,6 +10,18 @@ dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4)) dimensions = (2, 1, 0) +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 + + # setup (currently do nothing) + + # tests will run here + yield + + # teardown + # avoid dynamo cache limit issues + torch._dynamo.reset() @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) From c566d139aeff88e943c000143d9e7603a8723010 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 10 Jun 2024 00:54:45 -0400 Subject: [PATCH 08/11] added order to unpack and updated tests --- test/prototype/test_bitpacking.py | 61 +++++++++++++++++++++----- torchao/prototype/common/bitpacking.py | 16 ++++--- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 144a767582..21e462627a 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -9,6 +9,7 @@ dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4)) dimensions = (2, 1, 0) +orders = (True, False) @pytest.fixture(autouse=True) def run_before_and_after_tests(): @@ -25,7 +26,8 @@ def run_before_and_after_tests(): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) -def test_CPU(dtype, dim): +@pytest.mark.parametrize("order", orders) +def test_CPU(dtype, dim, order): element_bit_width, element_type,expected_pack_size = dtype shape = [4, 4, 4] if element_type == "trinary": @@ -33,16 +35,28 @@ def test_CPU(dtype, dim): else: test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu') - packed = pack(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.uint8, device='cpu') + packed = pack(test_tensor, + element_bit_width, + element_type=element_type, + dim = dim, + order = order, + container_dtype = torch.uint8, + device='cpu') assert(packed.shape[dim] == expected_pack_size) - unpacked = unpack(packed, element_bit_width, element_type=element_type, dim = dim, device='cpu') + unpacked = unpack(packed, + element_bit_width, + element_type=element_type, + dim = dim, + order 
= order, + device='cpu') assert(unpacked.allclose(test_tensor)) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) -def test_GPU(dtype, dim): +@pytest.mark.parametrize("order", orders) +def test_GPU(dtype, dim, order): element_bit_width, element_type,expected_pack_size = dtype shape = [4, 4, 4] if element_type == "trinary": @@ -50,9 +64,18 @@ def test_GPU(dtype, dim): else: test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() - packed = pack(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.uint8) + packed = pack(test_tensor, + element_bit_width, + element_type=element_type, + dim = dim, + order = order, + container_dtype = torch.uint8) assert(packed.shape[dim] == expected_pack_size) - unpacked = unpack(packed, element_bit_width, element_type=element_type, dim = dim) + unpacked = unpack(packed, + element_bit_width, + element_type=element_type, + order = order, + dim = dim) assert(unpacked.allclose(test_tensor)) @@ -60,7 +83,8 @@ def test_GPU(dtype, dim): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) -def test_padding(dtype, dim): +@pytest.mark.parametrize("order", orders) +def test_padding(dtype, dim, order): element_bit_width, element_type,expected_pack_size = dtype torch._dynamo.config.specialize_int = True shape =[4, 4, 4] @@ -76,9 +100,14 @@ def test_padding(dtype, dim): element_type=element_type, dim = dim, container_dtype = torch.uint8, + order = order, pad= True) assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}" # +1 for this scenario - unpacked = unpack(packed, element_bit_width, element_type=element_type, dim = dim) + unpacked = unpack(packed, + element_bit_width, + element_type=element_type, + dim = dim, + order = order) slices = [slice(None)] * packed.ndim slices[dim] = slice(None, 5) assert unpacked[slices].allclose(test_tensor) @@ -88,7 +117,8 @@ def test_padding(dtype, dim): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dim", dimensions) -def test_compile(dtype, dim): +@pytest.mark.parametrize("order", orders) +def test_compile(dtype, dim, order): pack_compile = torch.compile(pack, fullgraph=True, dynamic=True) unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True) element_bit_width, element_type,expected_pack_size = dtype @@ -99,8 +129,17 @@ def test_compile(dtype, dim): else: test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda() - packed = pack_compile(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.int8) + packed = pack_compile(test_tensor, element_bit_width, + element_type=element_type, + dim = dim, + container_dtype = torch.int8, + order = order) assert(packed.shape[dim] == expected_pack_size) - unpacked = unpack_compile(packed, element_bit_width, element_type=element_type, dim = dim) + unpacked = unpack_compile(packed, + element_bit_width, + element_type=element_type, + dim = dim, + order = order) assert(unpacked.allclose(test_tensor)) + diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 2278e6d27a..8f2c3d0d60 100644 --- a/torchao/prototype/common/bitpacking.py +++ 
b/torchao/prototype/common/bitpacking.py @@ -9,6 +9,7 @@ def unpack(data: torch.Tensor, element_bit_width: int, element_type: Optional[str] = None, dim: Optional[int] = 0, + order: Optional[bool] = True, output_dtype: Optional[torch.dtype] = None, device: Optional[str] ="cuda") -> torch.Tensor: """ @@ -17,18 +18,16 @@ def unpack(data: torch.Tensor, Inputs: data: - a tensor of packed elements element_bit_width: the size in bits of the elements to unpack - - optional: element_type: the dtype of the elements to unpack (uint,trinary,float, etc) - dimension: the dimension to unpack along + dim: the dimension to unpack along output_dtype: specify the dtype of the output tensor if it is not the same as the input tensor - + order: make sure it matches the value set in the pack function Returns: torch.Tensor - a tensor of the unpacked elements. """ container_size = torch.iinfo(data.dtype).bits scale = container_size // element_bit_width - unpacked = _unpack(data, element_bit_width, container_size, scale, dim, device) + unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device) if element_type == "trinary": unpacked = unpacked.to(torch.int8) - 1 elif output_dtype is not None: @@ -36,12 +35,15 @@ def unpack(data: torch.Tensor, return unpacked -def _unpack(data, element_size, container_size, scale ,dim, device): +def _unpack(data, element_size, container_size, scale, order, dim, device): shape = data.shape unpacked_data = torch.zeros(mod_shape(shape, shape[dim]*scale, dim), dtype=data.dtype).to(device) nbits = (1 << element_size) - 1 # mask for the last dtype_size bits for i in range(scale): - shift_amt = container_size - element_size * (i + 1) + if order: + shift_amt = container_size - element_size * (i + 1) + else: + shift_amt = element_size * i slices = [slice(None)] * unpacked_data.ndim slices[dim] = slice(i, None, scale) unpacked_data[slices] = ((data >> shift_amt) & (nbits)).to(data.dtype) From 3363b77e85d6a5bc2a10337fbd528bfae7793851 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 10 Jun 2024 12:52:31 -0400 Subject: [PATCH 09/11] removed main code and added a text example --- benchmarks/benchmark_bitpacking.py | 5 ++-- torchao/prototype/common/bitpacking.py | 37 ++++++++------------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py index cd4d8ab228..7d8491656c 100644 --- a/benchmarks/benchmark_bitpacking.py +++ b/benchmarks/benchmark_bitpacking.py @@ -6,14 +6,15 @@ import torch -def benchmark(function, num_runs, *args, **kwargs): +def benchmark(setup, function, num_runs): + args = setup() torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_runs): - function(*args, **kwargs) + function(*args) end_event.record() torch.cuda.synchronize() diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 8f2c3d0d60..db05f57b52 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -22,6 +22,7 @@ def unpack(data: torch.Tensor, dim: the dimension to unpack along output_dtype: specify the dtype of the output tensor if it is not the same as the input tensor order: make sure it matches the value set in the pack function + Returns: torch.Tensor - a tensor of the unpacked elements. 
""" container_size = torch.iinfo(data.dtype).bits @@ -62,8 +63,16 @@ def pack(data: torch.Tensor, device: Optional[str] = "cuda") -> torch.Tensor: """ Packs small dtype elements into a container of a larger dtype. - **Pads rows to be divisible by the scale** - TODO: support something like packing 8 uint 3s into 3 uint8s + For example, packing 4-bit elements into 8-bit containers. + along dimension 0: along dimension 1: + (0, 9, B, 4) --> ( 9, B4) + (3, 8, F, C) --> (38, FC) + | | | | + v v v v + (3, 98, BF, 4C) + + if order was set to false: + (30, 89, FB, C4) Inputs: data: a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. @@ -106,28 +115,4 @@ def _pack(data, container_size, element_bit_width, scale, dim, order, device) -> else: packed |= data[slices] << element_bit_width*i return packed - -if __name__ == '__main__': - pack_compile = torch.compile(pack, fullgraph=True) - unpack_compile = torch.compile(unpack, fullgraph=True) - torch._dynamo.config.specialize_int = True - element_bit_width = 2 - element_type = "trinary" - dim = 0 - shape =[4, 4, 4] - shape[dim] = 5 - - if element_type == "trinary": - test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() - else: - test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() - - packed = pack_compile(test_tensor, element_bit_width, element_type=element_type, dim = dim, container_dtype = torch.uint8, pad= True) - print(packed.shape) - assert(packed.shape[dim] == 2) # +1 for this scenario - unpacked = unpack_compile(packed, element_bit_width, element_type=element_type, dim = dim) - slices = [slice(None)] * packed.ndim - slices[dim] = slice(None, 5) - print(test_tensor, "\n", packed,"\n",unpacked[slices]) - assert(unpacked[slices].allclose(test_tensor)) From 9ef4c6c69719641f9e200fdd4404b8be38f3c604 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 10 Jun 2024 12:55:10 -0400 Subject: [PATCH 10/11] added example --- torchao/prototype/common/bitpacking.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index db05f57b52..60009d0e63 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -63,16 +63,6 @@ def pack(data: torch.Tensor, device: Optional[str] = "cuda") -> torch.Tensor: """ Packs small dtype elements into a container of a larger dtype. - For example, packing 4-bit elements into 8-bit containers. - along dimension 0: along dimension 1: - (0, 9, B, 4) --> ( 9, B4) - (3, 8, F, C) --> (38, FC) - | | | | - v v v v - (3, 98, BF, 4C) - - if order was set to false: - (30, 89, FB, C4) Inputs: data: a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. @@ -83,6 +73,18 @@ def pack(data: torch.Tensor, order: if set to true, packs elements such that the lower index elements occupy the most significant bits Returns: torch.Tensor - a tensor of packed elements. + + + For example, packing 4-bit elements into 8-bit containers. 
+ along dimension 0: along dimension 1: + (0, 9, B, 4) --> ( 9, B4) + (3, 8, F, C) --> (38, FC) + | | | | + v v v v + (3, 98, BF, 4C) + + if order was set to false: + (30, 89, FB, C4) """ if element_type == "trinary": From d9a94c8aca0c3c219ea75cca1ab141c23913ad1f Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Mon, 10 Jun 2024 13:07:04 -0400 Subject: [PATCH 11/11] organized benchmarks --- benchmarks/benchmark_bitpacking.py | 238 ++++++++++++----------------- 1 file changed, 101 insertions(+), 137 deletions(-) diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py index 7d8491656c..e974efca58 100644 --- a/benchmarks/benchmark_bitpacking.py +++ b/benchmarks/benchmark_bitpacking.py @@ -1,12 +1,11 @@ -# from torchao.quantization.quant_primitives import dynamically_quantize_per_channel -from torchao.prototype.common.bitpacking import pack, unpack -from torchao.dtypes.uint4 import unpack_uint4, pack_uint4 from math import log -# from torchao.utils import benchmark_utils import torch +from torchao.prototype.common.bitpacking import pack, unpack +from torchao.dtypes.uint4 import unpack_uint4, pack_uint4 + -def benchmark(setup, function, num_runs): +def benchmark(function, num_runs, setup =None): args = setup() torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) @@ -20,7 +19,8 @@ def benchmark(setup, function, num_runs): torch.cuda.synchronize() return start_event.elapsed_time(end_event) / num_runs -def test_existing(): + +def test_vs_existing(): def new_(): fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda() packed = pack(fake_tensor, 4, dim=1) @@ -37,29 +37,28 @@ def old_(): print(f"old: {benchmark(old_, 1000)} ms") - -def load4x(scale=1024): +def test_iso_bitpack(): + def load4x(scale=1024): fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda() -def load2x(scale=1024): - fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda() + def load2x(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda() -def loadx(scale=1024): - fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() - -def unpack8to2(scale=1024): - fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() - unpacked_tensor = unpack_c(fake_tensor, 2, dim=1) - -def unpack8to4(scale=1024): - fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() - unpacked_tensor = unpack_c(fake_tensor, 4, dim=1) + def loadx(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() -def t8to4wmm(scale=1024): - fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda() - unpacked_tensor = unpack_c(fake_tensor, 4, dim=1) - -def test_iso_bitpack(): + def unpack8to2(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() + unpacked_tensor = unpack_c(fake_tensor, 2, dim=1) + + def unpack8to4(scale=1024): + fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda() + unpacked_tensor = unpack_c(fake_tensor, 4, dim=1) + + def t8to4wmm(scale=1024): + fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda() + unpacked_tensor = unpack_c(fake_tensor, 4, dim=1) + torch._dynamo.config.specialize_int = True # _unpack_c = torch.compile(_unpack, fullgraph=True) unpack_c = torch.compile(unpack, fullgraph=True) @@ -98,122 +97,86 @@ def test_iso_bitpack(): # 
plt.legend() # plt.savefig("benchmark_bitpacking.png") -import hqq -import hqq.core.quantize as hqq_quantize -HQQLinear = hqq_quantize.HQQLinear -BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig - -import itertools -from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm - - -BASE_QUANT_CONFIG = { - "optimize": True, - "view_as_float": False, - "nbits": 4, - "bitpack": False, - "axis": 1, -} - - -def check(expected, actual, msg="", max_diff=1e-3, verbose=False): - passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff) - if verbose: - max_err = (expected - actual).abs().max() - if not passed: - print_msg = f"{msg}:\nFailed! Max error: {max_err}" - try: - from termcolor import colored - except ImportError: - print(print_msg) - else: - print(colored(print_msg, "red", attrs=["bold"])) - - else: - print_msg = f"{msg}:\nPassed! Max error: {max_err}" - try: - from termcolor import colored - except ImportError: - print(print_msg) - else: - print(colored(print_msg, "green", attrs=["bold"])) - - return passed - -def mixed_mm( - shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True -): - qcfg = { - **BASE_QUANT_CONFIG, - **dict(group_size=group_size, axis=axis), +def test_vs_hqqpack(): + #requires hqq to be installed + import hqq + import hqq.core.quantize as hqq_quantize + HQQLinear = hqq_quantize.HQQLinear + BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig + from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm + + BASE_QUANT_CONFIG = { + "optimize": True, + "view_as_float": False, + "nbits": 4, + "bitpack": False, + "axis": 1, } - M, N, K = shape - - linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") - - quant_config = BaseQuantizeConfig( - quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False - ) - quant_config.update({"weight_quant_params": qcfg}) - hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) - W_q, meta = hqq_linear.W_q, hqq_linear.meta - W_q = W_q.to(dtype=quant_dtype) - W_q = ( - W_q.reshape(meta["shape"]) - if quant_config["weight_quant_params"]["bitpack"] == False - else W_q - ) - W_dq = hqq_linear.dequantize() - - scales, zeros = meta["scale"], meta["zero"] - scales = scales.reshape(N, -1) - zeros = zeros.reshape(N, -1) - if pack_fn: - packed_w = pack(W_q.T,4,dim=0,order=False) - else: - packed_w = pack_2xint4(W_q.T) - # print(W_q.T[0:5,0:5], W_q.T.shape) - # print(packed_w[0:5,0:5], W_q.T.shape) - # print(packed_w2[0:5,0:5], W_q.T.shape) - if transposed: - x = torch.randn(M, N, dtype=dtype, device="cuda") - hqq_out = x @ W_dq - - tt_out = triton_mixed_mm( - x, - packed_w, - scales.T, - zeros.T, - transposed=True, - group_size=group_size, - fp8_fast_accum=False, - kernel_type=kernel_type, + + def mixed_mm( + shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True + ): + qcfg = { + **BASE_QUANT_CONFIG, + **dict(group_size=group_size, axis=axis), + } + M, N, K = shape + + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + + quant_config = BaseQuantizeConfig( + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False ) - - else: - x = torch.randn(M, K, dtype=dtype, device="cuda") - hqq_out = x @ W_dq.T - - tt_out = triton_mixed_mm( - x, - packed_w, - scales.T, - zeros.T, - transposed=False, - group_size=group_size, - fp8_fast_accum=False, - kernel_type=kernel_type, + quant_config.update({"weight_quant_params": qcfg}) + hqq_linear = HQQLinear(linear, 
quant_config, compute_dtype=dtype, del_orig=False) + W_q, meta = hqq_linear.W_q, hqq_linear.meta + W_q = W_q.to(dtype=quant_dtype) + W_q = ( + W_q.reshape(meta["shape"]) + if quant_config["weight_quant_params"]["bitpack"] == False + else W_q ) - # assert check( - # hqq_out, - # tt_out, - # max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3, - # verbose=True, - # ) + W_dq = hqq_linear.dequantize() + scales, zeros = meta["scale"], meta["zero"] + scales = scales.reshape(N, -1) + zeros = zeros.reshape(N, -1) + if pack_fn: + packed_w = pack(W_q.T,4,dim=0,order=False) + else: + packed_w = pack_2xint4(W_q.T) + + if transposed: + x = torch.randn(M, N, dtype=dtype, device="cuda") + hqq_out = x @ W_dq + + tt_out = triton_mixed_mm( + x, + packed_w, + scales.T, + zeros.T, + transposed=True, + group_size=group_size, + fp8_fast_accum=False, + kernel_type=kernel_type, + ) -def test_vs_hqqpack(): + else: + x = torch.randn(M, K, dtype=dtype, device="cuda") + hqq_out = x @ W_dq.T + + tt_out = triton_mixed_mm( + x, + packed_w, + scales.T, + zeros.T, + transposed=False, + group_size=group_size, + fp8_fast_accum=False, + kernel_type=kernel_type, + ) + shapes = [ [16, 128, 128], [16, 4096, 4096], @@ -257,7 +220,8 @@ def test_vs_hqqpack(): torch.uint8, pack_fn=False)) print("") - + + if __name__ == "__main__": - test_existing() + test_vs_existing()
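
Note on the `order` flag introduced in patch 08: with `order=True` the lower-index elements along the packing dimension land in the most significant bits of the container (`shift_amt = container_size - element_bit_width*(i+1)`); with `order=False` they land in the least significant bits (`shift_amt = element_bit_width*i`). The following is a minimal standalone sketch, using plain tensor bit ops rather than the torchao pack/unpack, that reproduces the 4-bit-into-8-bit docstring example; the helper names pack_4bit_rows/unpack_4bit_rows are hypothetical and not part of this series.

    import torch

    def pack_4bit_rows(x: torch.Tensor, order: bool = True) -> torch.Tensor:
        # x: (2, n) uint8 tensor of 4-bit values; returns a (1, n) uint8 tensor.
        # order=True puts the lower-index row in the high nibble, mirroring
        # shift_amt = container_size - element_bit_width * (i + 1).
        row0, row1 = x[0], x[1]
        if order:
            return ((row0 << 4) | row1).unsqueeze(0)
        return ((row1 << 4) | row0).unsqueeze(0)

    def unpack_4bit_rows(packed: torch.Tensor, order: bool = True) -> torch.Tensor:
        # inverse of pack_4bit_rows; (1 << 4) - 1 masks off the low nibble
        mask = (1 << 4) - 1
        if order:
            return torch.stack([(packed[0] >> 4) & mask, packed[0] & mask])
        return torch.stack([packed[0] & mask, (packed[0] >> 4) & mask])

    x = torch.tensor([[0x0, 0x9, 0xB, 0x4],
                      [0x3, 0x8, 0xF, 0xC]], dtype=torch.uint8)
    print([hex(v) for v in pack_4bit_rows(x, order=True)[0].tolist()])   # ['0x3', '0x98', '0xbf', '0x4c']
    print([hex(v) for v in pack_4bit_rows(x, order=False)[0].tolist()])  # ['0x30', '0x89', '0xfb', '0xc4']
    assert torch.equal(unpack_4bit_rows(pack_4bit_rows(x, order=False), order=False), x)

The two printed rows match the "(3, 98, BF, 4C)" and "(30, 89, FB, C4)" layouts shown in the pack docstring.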
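
For the trinary path, values in {-1, 0, 1} are shifted to {0, 1, 2} before packing and shifted back after unpacking, and `pad=True` pads the packing dimension up to a multiple of the scale, so the caller slices the unpacked result back to the original extent (this is what test_padding exercises). A rough CUDA round-trip sketch, assuming the prototype API keeps the keyword arguments used by the tests in this series:

    import torch
    from torchao.prototype.common.bitpacking import pack, unpack

    dim = 0
    shape = [5, 4, 4]   # 5 is not a multiple of the scale (8 bits // 2 bits = 4), so padding kicks in
    test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8).cuda()  # trinary values in {-1, 0, 1}

    packed = pack(test_tensor, 2, element_type="trinary", dim=dim,
                  container_dtype=torch.uint8, pad=True)
    unpacked = unpack(packed, 2, element_type="trinary", dim=dim)

    # unpack returns the padded extent; slice back to the original shape before comparing
    slices = [slice(None)] * unpacked.ndim
    slices[dim] = slice(None, shape[dim])
    assert unpacked[slices].allclose(test_tensor)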
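
One note on the benchmark harness reorganized in the last patch: the signature is now `benchmark(function, num_runs, setup=None)`, but the body still calls `args = setup()` unconditionally, so callers that pass no setup, such as `benchmark(new_, 1000)` in test_vs_existing, would raise a TypeError. A guarded variant of the same CUDA-event timing pattern, as a sketch only (benchmark_ms is a hypothetical name, not a proposed rename):

    import torch

    def benchmark_ms(fn, num_runs, setup=None):
        # Same CUDA-event pattern as benchmark() in this series, but the
        # setup=None default is actually honored instead of crashing.
        args = setup() if setup is not None else ()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_runs):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / num_runs   # average milliseconds per run

    # usage: benchmark_ms(new_, 1000), or benchmark_ms(pack_fn, 10, setup=make_inputs)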