diff --git a/test/test.py b/test/test.py
index bb4c032c02..d57519ba28 100644
--- a/test/test.py
+++ b/test/test.py
@@ -56,7 +56,7 @@
 from torchao.quantization.weight_only import (
     WeightOnlyInt8QuantLinear
 )
-
+import os
 
 torch.manual_seed(0)
 
@@ -904,6 +904,63 @@ def test_weight_only_quant_use_mixed_mm(self):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 43.0)
 
+class TestSaveLoadMeta(unittest.TestCase):
+    @torch.no_grad()
+    def _test_handle_save_load_meta_impl(self, api):
+        m, k, n = 32, 64, 32
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+        # get float reference
+        model = test_model().to(torch.bfloat16).cuda().eval()
+        ref_f = model(x)
+
+        # save quantized state_dict
+        api(model)
+        torch.save(model.state_dict(), "test.pth")
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        ref_q = model_qc(x).detach()
+
+        assert SQNR(ref_f, ref_q) > 35
+
+        # load model structure
+        with torch.device('meta'):
+            model = test_model()
+        api(model)
+
+        # load quantized state_dict
+        state_dict = torch.load("test.pth", mmap=True)
+        os.remove("test.pth")
+        model.load_state_dict(state_dict, assign=True)
+        model = model.to(torch.bfloat16).cuda().eval()
+
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        test = model_qc(x).detach()
+
+        assert SQNR(ref_f, test) > 35
+        self.assertTrue(torch.equal(ref_q, test))
+
+    @torch.no_grad()
+    def test_save_load_dqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)
+
+    @torch.no_grad()
+    def test_save_load_woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
 
 class TorchCompileUnitTest(unittest.TestCase):
     def test_fullgraph(self):
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index a914da5741..3649b8e029 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -35,7 +35,7 @@ def __new__(cls, int_data, q_scales, transposed=False, **kwargs):
         # transposed/detached, instead we can just pass the int_data to the
         # new instance and alter the transposed flag where needed.
         kwargs["device"] = int_data.device
-        kwargs["dtype"] = kwargs.get("dtype", torch.int8)
+        kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
         size = int_data.shape[::-1] if transposed else int_data.shape
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout