24 changes: 24 additions & 0 deletions test/prototype/test_quant_llm.py
@@ -11,6 +11,7 @@
from torchao.prototype.quant_llm import (
QuantLlmLinearWeight,
quant_llm_fpx_weight_only,
fp6_llm_weight_only,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
)
@@ -65,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
x = torch.randn(256, 64)
fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
assert fpx.device.type == "cuda"
fpx = fpx.cpu()
assert fpx.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("leading_dims", [(4,), (2, 4)])
@@ -98,6 +108,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_fp6_llm_quantize(self):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, device=device)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fp6_llm_weight_only())

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestQuantLlmLinearWeight)

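The new tests exercise the end-user workflow this PR enables: quantize a Linear layer to FP6 with fp6_llm_weight_only and move the quantized weight between devices. Below is a minimal sketch of that workflow, assuming a CUDA build with the prototype quant_llm module available and quantize_ imported from torchao.quantization.quant_api (the test file's full import list is elided in the hunk above):

import copy
import torch
from torchao.quantization.quant_api import quantize_
from torchao.prototype.quant_llm import QuantLlmLinearWeight, fp6_llm_weight_only

# quantize a Linear to FP6 (E3M2), mirroring test_fp6_llm_quantize above
linear = torch.nn.Linear(64, 256, device="cuda")
fp6_linear = copy.deepcopy(linear)
quantize_(fp6_linear, fp6_llm_weight_only())

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
y = fp6_linear(x)

# the new aten._to_copy override lets the quantized weight subclass move
# between devices, mirroring test_to_copy_device above
w = QuantLlmLinearWeight.from_float(torch.randn(256, 64), 3, 2).cuda()
w = w.cpu()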
21 changes: 19 additions & 2 deletions torchao/prototype/quant_llm/quant_llm.py
@@ -10,6 +10,7 @@
from torchao.quantization.quant_api import _get_linear_subclass_inserter


aten = torch.ops.aten
_ONES_TABLE = [_n_ones(i) for i in range(8)]


@@ -430,11 +431,27 @@ def _(func, types, args, kwargs):
return out.view(*act.shape[:-1], out_dim).to(act.dtype)


@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default)
@QuantLlmLinearWeight.implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))


@QuantLlmLinearWeight.implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))


@QuantLlmLinearWeight.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
# only support device kwargs, ignore the rest
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
)


def quant_llm_fpx_weight_only(ebits: int, mbits: int):
def apply_quant_llm(weight: Tensor) -> Tensor:
out_dim, in_dim = weight.shape
@@ -445,4 +462,4 @@ def apply_quant_llm(weight: Tensor) -> Tensor:


def fp6_llm_weight_only():
return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
return quant_llm_fpx_weight_only(3, 2)
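For context, a sketch of what the simplified fp6_llm_weight_only is expected to be equivalent to for callers. This assumes quant_llm_fpx_weight_only already returns a linear-subclass inserter consumable by quantize_ (its return statement is elided in the hunk above), which would make the previous extra _get_linear_subclass_inserter wrapping redundant:

import torch
from torchao.quantization.quant_api import quantize_
from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only

model = torch.nn.Sequential(torch.nn.Linear(64, 256)).cuda()

# FP6 here means E3M2 (3 exponent bits, 2 mantissa bits), so these two calls
# should now configure the same quantization
quantize_(model, fp6_llm_weight_only())
# equivalent: quantize_(model, quant_llm_fpx_weight_only(3, 2))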