Minor upgrades to bit pack #347


Merged
merged 3 commits on Jun 13, 2024
benchmarks/benchmark_bitpacking.py (250 changes: 58 additions & 192 deletions)
@@ -5,8 +5,7 @@
from torchao.dtypes.uint4 import unpack_uint4, pack_uint4


def benchmark(function, num_runs, setup =None):
args = setup()
def benchmark(function, args, num_runs):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
@@ -21,207 +20,74 @@ def benchmark(function, num_runs, setup =None):


def test_vs_existing():
def new_():
fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
def new_(scale):
fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
packed = pack(fake_tensor, 4, dim=1)
unpacked = unpack(packed, 4, dim=1)
def old_():
fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
def old_(scale):
fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
packed = pack_uint4(fake_tensor)
unpacked = unpack_uint4(packed)
new_ = torch.compile(new_, fullgraph=True)
old_ = torch.compile(old_, fullgraph=True)
new_()
old_()
print(f"new: {benchmark(new_, 1000)} ms ")
print(f"old: {benchmark(old_, 1000)} ms")



def test_iso_bitpack():
def load4x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda()
for scale in [256,512, 1024, 2048,4096, 8192]:
new_ = torch.compile(new_, fullgraph=True)
old_ = torch.compile(old_, fullgraph=True)
new_(scale)
old_(scale)
print("scale: ", scale)
print(f"new: {benchmark(new_,[scale], 10)} ms ")
print(f"old: {benchmark(old_,[scale], 10)} ms")


def compare_to_fp16():
class Linear16(torch.nn.Module):
def __init__(self, scale):
super().__init__()
scale += scale % 2
self.l1 = torch.nn.Linear(scale * 2, scale, bias=False,dtype=torch.float16).cuda()
self.l2 = torch.nn.Linear(scale, scale//2, bias=False,dtype=torch.float16).cuda()

def forward(self, x):
return self.l2(self.l1(x))

def load2x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda()

def loadx(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
class W4A16_symmetric_weight_only(torch.nn.Module):
def __init__(self, scale):
super().__init__()
assert scale % 4 == 0
self.l1 = torch.randint(2**8,(scale, scale), dtype=torch.uint8).cuda()
self.s1 = torch.tensor((scale),dtype=torch.float16).cuda()
self.l2 = torch.randint(2**8,(scale//2, scale//4), dtype=torch.uint8).cuda()
self.s2 = torch.tensor((scale//4),dtype=torch.float16).cuda()

def unpack8to2(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)

def unpack8to4(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

def t8to4wmm(scale=1024):
fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
def forward(self, x):
w = unpack(self.l1.detach(), 4, output_dtype=torch.float16)
x = x * self.s1
x = x @ w
w = unpack(self.l2.detach(), 4, output_dtype=torch.float16)
x = x * self.s2
x = x @ w

torch._dynamo.config.specialize_int = True
# _unpack_c = torch.compile(_unpack, fullgraph=True)
unpack_c = torch.compile(unpack, fullgraph=True)

scale = [16,64,256,1024,4096]
load4x_times = []
unpack8to2_times = []
load2x_times = []
unpack8to4_times = []
for s in scale:
res = benchmark(load4x, 50, scale=s)
load4x_times.append(res)
print(f"load(1, {4*s},{s}) time: {res} ms")

res=benchmark(unpack8to2, 50, scale=s)
unpack8to2_times.append(res)
print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")
return x

torch._dynamo.config.specialize_int = True
for scale in [256,512, 1024, 2048,4096, 8192]:
a = Linear16(scale)
b = W4A16_symmetric_weight_only(scale)
# a = torch.compile(a, fullgraph=True)
b = torch.compile(b, fullgraph=True)

res = benchmark(load2x, 50, scale=s)
load2x_times.append(res)
print(f"load(1, {2*s},{s}) time: {res} ms")

res = benchmark(unpack8to4, 50, scale=s)
unpack8to4_times.append(res)
print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
print()

# import matplotlib.pyplot as plt
# plt.plot(scale, load4x_times, label="load(1, 4x, x)")
# plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
# plt.plot(scale, load2x_times, label="load(1, 2x, x)")
# plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
# plt.xlabel("scale")
# plt.ylabel("time (ms)")
# plt.yscale("log")
# plt.legend()
# plt.savefig("benchmark_bitpacking.png")


def test_vs_hqqpack():
#requires hqq to be installed
import hqq
import hqq.core.quantize as hqq_quantize
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}
test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
forward_args = [test_input]
b.forward(test_input)
print("scale: ", scale)
print("fp16 time: ", benchmark(a.forward, forward_args, 100))
print("uint4 time: ", benchmark(b.forward, forward_args, 100))

def mixed_mm(
shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True
):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
W_q, meta = hqq_linear.W_q, hqq_linear.meta
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
W_dq = hqq_linear.dequantize()

scales, zeros = meta["scale"], meta["zero"]
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
if pack_fn:
packed_w = pack(W_q.T,4,dim=0,order=False)
else:
packed_w = pack_2xint4(W_q.T)

if transposed:
x = torch.randn(M, N, dtype=dtype, device="cuda")
hqq_out = x @ W_dq

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=True,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
)

else:
x = torch.randn(M, K, dtype=dtype, device="cuda")
hqq_out = x @ W_dq.T

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=False,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
)

shapes = [
[16, 128, 128],
[16, 4096, 4096],
]
group_sizes = [64, 128]
shape = [16, 128, 128]
group_size = 64
pack = torch.compile(pack, fullgraph=True)
for i in range(2):
shape = shapes[i]
group_size = group_sizes[i]
print("linear layer size: ", shape)
print("group size: ", group_size)
# run once to compile
test_mixed_mm(
shape,
group_size,
1,
torch.float16,
True,
"compute_bound",
torch.uint8,
)
# shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
print("pack time (ms): ", benchmark(test_mixed_mm, 100,
shape,
group_size,
1,
torch.float16,
True,
"compute_bound",
torch.uint8))

print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 100,
shape,
group_size,
1,
torch.float16,
True,
"compute_bound", #max autotune doesnt work?
torch.uint8,
pack_fn=False))
print("")



if __name__ == "__main__":
compare_to_fp16()
test_vs_existing()
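
For reference, a minimal sketch of how the reworked helper reads once the setup callback is replaced by an explicit args list. Only the first few lines of its body are visible in the diff above; the timed loop, the final synchronization, and the per-run averaging below are assumptions rather than the PR's exact code:

import torch

def benchmark(function, args, num_runs):
    # Time function(*args) over num_runs iterations using CUDA events.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_runs):
        function(*args)  # called positionally, e.g. benchmark(new_, [scale], 10) above
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs  # average milliseconds per run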

test/prototype/test_bitpacking.py (8 changes: 3 additions & 5 deletions)
@@ -8,7 +8,7 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
dimensions = (2, 1, 0)
dimensions = (2, 1, 0, -1)
orders = (True, False)


@@ -41,15 +41,13 @@ def test_CPU(dtype, dim, order):
element_type=element_type,
dim = dim,
order = order,
container_dtype = torch.uint8,
device='cpu')
container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size)
unpacked = unpack(packed,
element_bit_width,
element_type=element_type,
dim = dim,
order = order,
device='cpu')
order = order)
assert(unpacked.allclose(test_tensor))


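The extra -1 entry extends the round-trip test to negative dimension indexing. A minimal sketch of the case it exercises, mirroring the CPU test above; the shape, values, and bit width here are illustrative, not taken from the test:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

test_tensor = torch.randint(0, 2**4, (8, 8), dtype=torch.uint8)     # 4-bit values in a uint8 container
packed = pack(test_tensor, 4, dim=-1, container_dtype=torch.uint8)  # last dim shrinks from 8 to 4
unpacked = unpack(packed, 4, dim=-1)
assert unpacked.allclose(test_tensor)

The tuple-slicing version of mod_shape replaced in the next file does not handle dim=-1 (shape[dim+1:] becomes shape[0:]), which is presumably what the list-based rewrite addresses.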
torchao/prototype/common/bitpacking.py (18 changes: 11 additions & 7 deletions)
@@ -3,15 +3,16 @@

def mod_shape(shape, mod, dim):
"""changes a select dimension of the input shape to mod"""
return (*shape[:dim], mod, *shape[dim+1:])
a = list(shape)
a[dim] = mod
return tuple(a)

def unpack(data: torch.Tensor,
element_bit_width: int,
element_type: Optional[str] = None,
dim: Optional[int] = 0,
order: Optional[bool] = True,
output_dtype: Optional[torch.dtype] = None,
device: Optional[str] ="cuda") -> torch.Tensor:
output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Unpacks small dtype elements from a larger dtype.

@@ -27,8 +28,10 @@ def unpack(data: torch.Tensor,
"""
container_size = torch.iinfo(data.dtype).bits
scale = container_size // element_bit_width

device = data.device

unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device)

if element_type == "trinary":
unpacked = unpacked.to(torch.int8) - 1
elif output_dtype is not None:
@@ -59,8 +62,7 @@ def unpack(data: torch.Tensor,
dim: Optional[int] = 0,
container_dtype: Optional[torch.dtype] = None,
pad: Optional[bool] = False,
order: Optional[bool] = True,
device: Optional[str] = "cuda") -> torch.Tensor:
order: Optional[bool] = True) -> torch.Tensor:
"""
Packs small dtype elements into a container of a larger dtype.

@@ -93,6 +95,8 @@ def pack(data: torch.Tensor,
if container_dtype is not None:
data = data.to(container_dtype)

device = data.device

container_size = torch.iinfo(data.dtype).bits
scale = container_size // element_bit_width

@@ -117,4 +121,4 @@ def _pack(data, container_size, element_bit_width, scale, dim, order, device) ->
else:
packed |= data[slices] << element_bit_width*i
return packed
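
With the device keyword removed from both signatures, placement now follows the input tensor. A short usage sketch under that assumption; the shapes and bit width are arbitrary and the CUDA branch is only illustrative:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

data = torch.randint(0, 2**4, (1, 16, 16), dtype=torch.uint8)
packed = pack(data, 4, dim=1)                 # result stays on CPU
restored = unpack(packed, 4, dim=1)           # uint8 result, still on CPU

if torch.cuda.is_available():
    packed_gpu = pack(data.cuda(), 4, dim=1)  # result stays on CUDA
    restored_fp16 = unpack(packed_gpu, 4, dim=1, output_dtype=torch.float16)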