From 669360566a869030b4ffaac83d03212178091bd0 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 21 Nov 2023 16:02:00 -0800
Subject: [PATCH] Adding tests for save/load support

Summary: we can now save a model quantized with a tensor subclass: quantize
the model, save its state dict, then later load the model as meta tensors
(i.e. load only tensor metadata, not the actual parameters), apply the
quantization API, and load the quantized state dict into it.

We change the dtype of the subclass to match the dtype of its dequantized
form, both to align with the subclass design guidelines and to make this
workflow work.

Test Plan: python test/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 test/test.py                     | 59 +++++++++++++++++++++++++++++++-
 torchao/quantization/subclass.py |  2 +-
 2 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/test/test.py b/test/test.py
index 917a85b3ee..c3febbcbaf 100644
--- a/test/test.py
+++ b/test/test.py
@@ -56,7 +56,7 @@ from torchao.quantization.weight_only import (
     WeightOnlyInt8QuantLinear
 )
-
+import os
 
 torch.manual_seed(0)
 
@@ -932,6 +932,63 @@ def test_weight_only_quant_use_mixed_mm(self):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 43.0)
 
+class TestSaveLoadMeta(unittest.TestCase):
+    @torch.no_grad()
+    def _test_handle_save_load_meta_impl(self, api):
+        m, k, n = 32, 64, 32
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+        # get float reference
+        model = test_model().to(torch.bfloat16).cuda().eval()
+        ref_f = model(x)
+
+        # save quantized state_dict
+        api(model)
+        torch.save(model.state_dict(), "test.pth")
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        ref_q = model_qc(x).detach()
+
+        assert SQNR(ref_f, ref_q) > 35
+
+        # load model structure
+        with torch.device('meta'):
+            model = test_model()
+        api(model)
+
+        # load quantized state_dict
+        state_dict = torch.load("test.pth", mmap=True)
+        os.remove("test.pth")
+        model.load_state_dict(state_dict, assign=True)
+        model = model.to(torch.bfloat16).cuda().eval()
+
+        # get quantized output of the loaded model
+        model_qc = torch.compile(model, mode="max-autotune")
+        test = model_qc(x).detach()
+
+        assert SQNR(ref_f, test) > 35
+        self.assertTrue(torch.equal(ref_q, test))
+
+    @torch.no_grad()
+    def test_save_load_dqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)
+
+    @torch.no_grad()
+    def test_save_load_woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
 
 class TorchCompileUnitTest(unittest.TestCase):
     def test_fullgraph(self):
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index aeef14a299..4f2cad6729 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -32,7 +32,7 @@ def __new__(cls, int_data, q_scales, transposed=False, **kwargs):
         # transposed/detached, instead we can just pass the int_data to the
         # new instance and alter the transposed flag where needed.
         kwargs["device"] = int_data.device
-        kwargs["dtype"] = kwargs.get("dtype", torch.int8)
+        kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
         size = int_data.shape[::-1] if transposed else int_data.shape
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
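
For reviewers, a rough standalone sketch of the workflow the new TestSaveLoadMeta tests exercise: quantize, save the state dict, rebuild the model on the meta device, re-apply the quantization API, and load with assign=True. The import path (torchao.quantization.quant_api), the make_model helper, the toy Sequential model, and the file name are illustrative assumptions, not taken verbatim from this diff.

import torch
import torch.nn as nn
# assumed import path for the API the tests pass in as `api`
from torchao.quantization.quant_api import change_linear_weights_to_dqtensors

def make_model():
    # toy model standing in for whatever the user actually wants to quantize
    return nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32))

# quantize a real (materialized) model and save its subclass-containing state dict
model = make_model().to(torch.bfloat16).cuda().eval()
change_linear_weights_to_dqtensors(model)
torch.save(model.state_dict(), "quantized.pth")

# later: build the model on the meta device (metadata only, no parameter storage),
# apply the same quantization API, then load the saved weights with assign=True so
# the meta tensors are replaced by the loaded quantized tensors
with torch.device("meta"):
    model = make_model()
change_linear_weights_to_dqtensors(model)
state_dict = torch.load("quantized.pth", mmap=True)
model.load_state_dict(state_dict, assign=True)
model = model.to(torch.bfloat16).cuda().eval()

out = model(torch.randn(8, 64, dtype=torch.bfloat16, device="cuda"))

Loading with assign=True swaps the meta-device parameters out for the saved quantized subclass tensors rather than copying into them (which is not possible for meta tensors); reporting the dequantized dtype from the subclass, as changed in subclass.py above, keeps this round trip consistent.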