Skip to content

Commit 8949fb2

Browse files
authored
Merge pull request #12 from pytorch-labs/gh/HDCharles/3/head
Adding tests for save/load support
2 parents 2a9f270 + 953679e commit 8949fb2

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

test/test.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from torchao.quantization.weight_only import (
5757
WeightOnlyInt8QuantLinear
5858
)
59-
59+
import os
6060

6161
torch.manual_seed(0)
6262

@@ -904,6 +904,63 @@ def test_weight_only_quant_use_mixed_mm(self):
904904
sqnr = compute_error(y_ref, y_wo)
905905
self.assertGreater(sqnr, 43.0)
906906

907+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestSaveLoadMeta(unittest.TestCase):
    """Check that a quantized state_dict survives a save/load round trip.

    Each test quantizes a small two-layer model with one of the
    ``change_linear_weights_to_*`` APIs, saves its state_dict, rebuilds the
    model structure on the ``meta`` device, loads the saved state_dict into
    it and verifies that the reloaded model reproduces the quantized
    reference output exactly.
    """

    @torch.no_grad()
    def _test_handle_save_load_meta_impl(self, api):
        """Run the save/load round trip for one quantization API.

        Args:
            api: Callable applied as ``api(model)`` that swaps the linear
                weights of ``model`` for their quantized form in place.
        """
        # Local import so this block does not depend on a file-level import.
        import tempfile

        m, k, n = 32, 64, 32

        class TwoLayerModel(nn.Module):
            """Linear -> ReLU -> Linear model used as the quantization target."""

            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(k, n)
                self.relu = nn.ReLU()
                self.lin2 = nn.Linear(n, n)

            def forward(self, x):
                x = self.lin1(x)
                x = self.relu(x)
                x = self.lin2(x)
                return x

        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")

        # get float reference
        model = TwoLayerModel().to(torch.bfloat16).cuda().eval()
        ref_f = model(x)

        # quantize the model in place
        api(model)

        # Use a private temporary directory rather than a fixed file in the
        # current working directory: the checkpoint is cleaned up even when
        # an assertion below fails, and parallel test runs cannot collide.
        with tempfile.TemporaryDirectory() as tmp_dir:
            checkpoint_path = os.path.join(tmp_dir, "test.pth")

            # save quantized state_dict
            torch.save(model.state_dict(), checkpoint_path)

            # get quantized reference
            model_qc = torch.compile(model, mode="max-autotune")
            ref_q = model_qc(x).detach()

            self.assertGreater(SQNR(ref_f, ref_q), 35)

            # load model structure (no real storage is allocated on meta)
            with torch.device("meta"):
                model = TwoLayerModel()
            api(model)

            # load quantized state_dict; assign=True replaces the meta
            # tensors with the loaded ones instead of copying into them
            state_dict = torch.load(checkpoint_path, mmap=True)
            model.load_state_dict(state_dict, assign=True)
            # Move to the GPU while the memory-mapped file still exists.
            model = model.to(torch.bfloat16).cuda().eval()

        # get output of the reloaded quantized model
        model_qc = torch.compile(model, mode="max-autotune")
        test = model_qc(x).detach()

        self.assertGreater(SQNR(ref_f, test), 35)
        # The reloaded model must match the pre-save quantized model exactly.
        self.assertTrue(torch.equal(ref_q, test))

    @torch.no_grad()
    def test_save_load_dqtensors(self):
        """Round trip with ``change_linear_weights_to_dqtensors``."""
        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)

    @torch.no_grad()
    def test_save_load_woqtensors(self):
        """Round trip with ``change_linear_weights_to_woqtensors``."""
        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
907964

908965
class TorchCompileUnitTest(unittest.TestCase):
909966
def test_fullgraph(self):

torchao/quantization/subclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __new__(cls, int_data, q_scales, transposed=False, **kwargs):
3535
# transposed/detached, instead we can just pass the int_data to the
3636
# new instance and alter the transposed flag where needed.
3737
kwargs["device"] = int_data.device
38-
kwargs["dtype"] = kwargs.get("dtype", torch.int8)
38+
kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
3939
size = int_data.shape[::-1] if transposed else int_data.shape
4040
kwargs["layout"] = (
4141
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout

0 commit comments

Comments
 (0)