diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py
index e974efca58..9e3d57d508 100644
--- a/benchmarks/benchmark_bitpacking.py
+++ b/benchmarks/benchmark_bitpacking.py
@@ -5,8 +5,7 @@ from torchao.dtypes.uint4 import unpack_uint4, pack_uint4
 
 
-def benchmark(function, num_runs, setup =None):
-    args = setup()
+def benchmark(function, args, num_runs):
     torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
@@ -21,207 +20,74 @@ def benchmark(function, num_runs, setup =None):
 
 
 def test_vs_existing():
-    def new_():
-        fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
+    def new_(scale):
+        fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
         packed = pack(fake_tensor, 4, dim=1)
         unpacked = unpack(packed, 4, dim=1)
-    def old_():
-        fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
+    def old_(scale):
+        fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
         packed = pack_uint4(fake_tensor)
         unpacked = unpack_uint4(packed)
-    new_ = torch.compile(new_, fullgraph=True)
-    old_ = torch.compile(old_, fullgraph=True)
-    new_()
-    old_()
-    print(f"new: {benchmark(new_, 1000)} ms ")
-    print(f"old: {benchmark(old_, 1000)} ms")
-    
+
-def test_iso_bitpack():
-    def load4x(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda()
+    for scale in [256,512, 1024, 2048,4096, 8192]:
+        new_ = torch.compile(new_, fullgraph=True)
+        old_ = torch.compile(old_, fullgraph=True)
+        new_(scale)
+        old_(scale)
+        print("scale: ", scale)
+        print(f"new: {benchmark(new_,[scale], 10)} ms ")
+        print(f"old: {benchmark(old_,[scale], 10)} ms")
+
+
+def compare_to_fp16():
+    class Linear16(torch.nn.Module):
+        def __init__(self, scale):
+            super().__init__()
+            scale += scale % 2
+            self.l1 = torch.nn.Linear(scale * 2, scale, bias=False,dtype=torch.float16).cuda()
+            self.l2 = torch.nn.Linear(scale, scale//2, bias=False,dtype=torch.float16).cuda()
+
+        def forward(self, x):
+            return self.l2(self.l1(x))
-    def load2x(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda()
-        
-    def loadx(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
+    class W4A16_symmetric_weight_only(torch.nn.Module):
+        def __init__(self, scale):
+            super().__init__()
+            assert scale % 4 == 0
+            self.l1 = torch.randint(2**8,(scale, scale), dtype=torch.uint8).cuda()
+            self.s1 = torch.tensor((scale),dtype=torch.float16).cuda()
+            self.l2 = torch.randint(2**8,(scale//2, scale//4), dtype=torch.uint8).cuda()
+            self.s2 = torch.tensor((scale//4),dtype=torch.float16).cuda()
-    def unpack8to2(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)
-    def unpack8to4(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
-        
-    def t8to4wmm(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
+        def forward(self, x):
+            w = unpack(self.l1.detach(), 4, output_dtype=torch.float16)
+            x = x * self.s1
+            x = x @ w
+            w = unpack(self.l2.detach(), 4, output_dtype=torch.float16)
+            x = x * self.s2
+            x = x @ w
-    torch._dynamo.config.specialize_int = True
-    # _unpack_c = torch.compile(_unpack, fullgraph=True)
-    unpack_c = torch.compile(unpack, fullgraph=True)
-    
-    scale = [16,64,256,1024,4096]
-    load4x_times = []
-    unpack8to2_times = []
-    load2x_times = []
-    unpack8to4_times = []
-    for s in scale:
-        res = benchmark(load4x, 50, scale=s)
-        load4x_times.append(res)
-        print(f"load(1, {4*s},{s}) time: {res} ms")
-        
-        res=benchmark(unpack8to2, 50, scale=s)
-        unpack8to2_times.append(res)
-        print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")
+            return x
+
+    torch._dynamo.config.specialize_int = True
+    for scale in [256,512, 1024, 2048,4096, 8192]:
+        a = Linear16(scale)
+        b = W4A16_symmetric_weight_only(scale)
+        # a = torch.compile(a, fullgraph=True)
+        b = torch.compile(b, fullgraph=True)
-        res = benchmark(load2x, 50, scale=s)
-        load2x_times.append(res)
-        print(f"load(1, {2*s},{s}) time: {res} ms")
-        
-        res = benchmark(unpack8to4, 50, scale=s)
-        unpack8to4_times.append(res)
-        print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
-        print()
-        
-    # import matplotlib.pyplot as plt
-    # plt.plot(scale, load4x_times, label="load(1, 4x, x)")
-    # plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
-    # plt.plot(scale, load2x_times, label="load(1, 2x, x)")
-    # plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
-    # plt.xlabel("scale")
-    # plt.ylabel("time (ms)")
-    # plt.yscale("log")
-    # plt.legend()
-    # plt.savefig("benchmark_bitpacking.png")
-
-
-def test_vs_hqqpack():
-    #requires hqq to be installed
-    import hqq
-    import hqq.core.quantize as hqq_quantize
-    HQQLinear = hqq_quantize.HQQLinear
-    BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
-    from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
-
-    BASE_QUANT_CONFIG = {
-        "optimize": True,
-        "view_as_float": False,
-        "nbits": 4,
-        "bitpack": False,
-        "axis": 1,
-    }
+        test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
+        forward_args = [test_input]
+        b.forward(test_input)
+        print("scale: ", scale)
+        print("fp16 time: ", benchmark(a.forward, forward_args, 100))
+        print("uint4 time: ", benchmark(b.forward, forward_args, 100))
-    def mixed_mm(
-        shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True
-    ):
-        qcfg = {
-            **BASE_QUANT_CONFIG,
-            **dict(group_size=group_size, axis=axis),
-        }
-        M, N, K = shape
-
-        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
-
-        quant_config = BaseQuantizeConfig(
-            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
-        )
-        quant_config.update({"weight_quant_params": qcfg})
-        hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
-        W_q, meta = hqq_linear.W_q, hqq_linear.meta
-        W_q = W_q.to(dtype=quant_dtype)
-        W_q = (
-            W_q.reshape(meta["shape"])
-            if quant_config["weight_quant_params"]["bitpack"] == False
-            else W_q
-        )
-        W_dq = hqq_linear.dequantize()
-
-        scales, zeros = meta["scale"], meta["zero"]
-        scales = scales.reshape(N, -1)
-        zeros = zeros.reshape(N, -1)
-        if pack_fn:
-            packed_w = pack(W_q.T,4,dim=0,order=False)
-        else:
-            packed_w = pack_2xint4(W_q.T)
-
-        if transposed:
-            x = torch.randn(M, N, dtype=dtype, device="cuda")
-            hqq_out = x @ W_dq
-
-            tt_out = triton_mixed_mm(
-                x,
-                packed_w,
-                scales.T,
-                zeros.T,
-                transposed=True,
-                group_size=group_size,
-                fp8_fast_accum=False,
-                kernel_type=kernel_type,
-            )
-
-        else:
-            x = torch.randn(M, K, dtype=dtype, device="cuda")
-            hqq_out = x @ W_dq.T
-
-            tt_out = triton_mixed_mm(
-                x,
-                packed_w,
-                scales.T,
-                zeros.T,
-                transposed=False,
-                group_size=group_size,
-                fp8_fast_accum=False,
-                kernel_type=kernel_type,
-            )
-
-    shapes = [
-        [16, 128, 128],
-        [16, 4096, 4096],
-    ]
-    group_sizes = [64, 128]
-    shape = [16, 128, 128]
-    group_size = 64
-    pack = torch.compile(pack, fullgraph=True)
-    for i in range(2):
-        shape = shapes[i]
-        group_size = group_sizes[i]
-        print("linear layer size: ", shape)
-        print("group size: ", group_size)
-        # run once to compile
-        test_mixed_mm(
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",
-            torch.uint8,
-        )
-        # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
-        print("pack time (ms): ", benchmark(test_mixed_mm, 100,
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",
-            torch.uint8))
-        print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 100,
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound", #max autotune doesnt work?
-            torch.uint8,
-            pack_fn=False))
-        print("")
-    
-    
+
 if __name__ == "__main__":
+    compare_to_fp16()
     test_vs_existing()
-    
+    
\ No newline at end of file
diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py
index 7facc97dc9..52413bf4bf 100644
--- a/test/prototype/test_bitpacking.py
+++ b/test/prototype/test_bitpacking.py
@@ -8,7 +8,7 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
-dimensions = (2, 1, 0)
+dimensions = (2, 1, 0, -1)
 orders = (True, False)
 
@@ -41,15 +41,13 @@ def test_CPU(dtype, dim, order):
                   element_type=element_type,
                   dim = dim,
                   order = order,
-                  container_dtype = torch.uint8,
-                  device='cpu')
+                  container_dtype = torch.uint8)
     assert(packed.shape[dim] == expected_pack_size)
 
     unpacked = unpack(packed,
                       element_bit_width,
                       element_type=element_type,
                       dim = dim,
-                      order = order,
-                      device='cpu')
+                      order = order)
     assert(unpacked.allclose(test_tensor))
 
diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py
index 60009d0e63..867fecdda6 100644
--- a/torchao/prototype/common/bitpacking.py
+++ b/torchao/prototype/common/bitpacking.py
@@ -3,15 +3,16 @@
 
 def mod_shape(shape, mod, dim):
     """changes a select dimension of the input shape to mod"""
-    return (*shape[:dim], mod, *shape[dim+1:])
+    a = list(shape)
+    a[dim] = mod
+    return tuple(a)
 
 def unpack(data: torch.Tensor,
            element_bit_width: int,
           element_type: Optional[str] = None,
           dim: Optional[int] = 0,
           order: Optional[bool] = True,
-          output_dtype: Optional[torch.dtype] = None,
-          device: Optional[str] ="cuda") -> torch.Tensor:
+          output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
     """
     Unpacks small dtype elements from a larger dtype.
 
@@ -27,8 +28,10 @@ def unpack(data: torch.Tensor,
     """
     container_size = torch.iinfo(data.dtype).bits
     scale = container_size // element_bit_width
-    
+    device = data.device
+
     unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device)
+
     if element_type == "trinary":
         unpacked = unpacked.to(torch.int8) - 1
     elif output_dtype is not None:
@@ -59,8 +62,7 @@ def pack(data: torch.Tensor,
          dim: Optional[int] = 0,
          container_dtype: Optional[torch.dtype] = None,
          pad: Optional[bool] = False,
-         order: Optional[bool] = True,
-         device: Optional[str] = "cuda") -> torch.Tensor:
+         order: Optional[bool] = True) -> torch.Tensor:
     """
     Packs small dtype elements into a container of a larger dtype.
 
@@ -93,6 +95,8 @@ def pack(data: torch.Tensor,
     if container_dtype is not None:
         data = data.to(container_dtype)
 
+    device = data.device
+
     container_size = torch.iinfo(data.dtype).bits
     scale = container_size // element_bit_width
 
@@ -117,4 +121,4 @@ def _pack(data, container_size, element_bit_width, scale, dim, order, device) ->
     else:
         packed |= data[slices] << element_bit_width*i
     return packed
-    
+    
\ No newline at end of file
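
A minimal usage sketch (not part of the patch, and assuming a CUDA device): the calls below simply mirror the signatures this diff introduces. benchmark(function, args, num_runs) now takes the positional-argument list explicitly, and pack()/unpack() infer the device from the input tensor instead of accepting a device keyword. The roundtrip name is hypothetical; benchmark refers to the helper defined in benchmarks/benchmark_bitpacking.py above.

# sketch only: exercises the updated pack/unpack and benchmark signatures
import torch
from torchao.prototype.common.bitpacking import pack, unpack

def roundtrip(scale):
    # device is taken from `t`; no device= argument is passed anymore
    t = torch.randint(2**8, (1, scale, scale), dtype=torch.uint8).cuda()
    packed = pack(t, 4, dim=1)
    return unpack(packed, 4, dim=1)

# with the helper from benchmarks/benchmark_bitpacking.py in scope:
# elapsed_ms = benchmark(roundtrip, [1024], 10)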