diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 5f34b761cd..3bc8ded793 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -61,7 +61,8 @@
     AQInt8DynamicallyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
-    AQWeightOnlyQuantizedLinearWeight3
+    AQWeightOnlyQuantizedLinearWeight3,
+    AutoQuantizableLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
@@ -1471,6 +1472,44 @@ def forward(self, x, y):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (16, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_double_access(self, device, dtype, m, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+
+        class DoubleAccess(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = torch.nn.Linear(k, n)
+                self.lin2 = torch.nn.Linear(n, k)
+                self.lin3 = torch.nn.Linear(k, n)
+                self.lin3.weight = self.lin1.weight
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.lin2(x)
+                x = self.lin3(x)
+                return x
+
+        x_in = torch.randn(m, k, device=device, dtype=dtype)
+        model = DoubleAccess().to(device).to(dtype)
+        model(x_in)
+        torchao.autoquant(model)
+        assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
+        model(x_in)
+
+
+
 class TestAOTI(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 4c0ae53ce8..808f7d89d3 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -252,7 +252,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         )
         q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
 
         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -384,7 +384,7 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     torch._dynamo.reset()
 
 @torch.no_grad()
-def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
+def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
     """
     wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
     AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a5a3a2b3db..20c52aa3f0 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -34,7 +34,7 @@
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
 )
-from .autoquant import autoquant
+from .autoquant import autoquant, AutoQuantizableLinearWeight
 
 
 __all__ = [
@@ -91,6 +91,7 @@ def _is_linear(mod, *args):
         isinstance(mod, torch.nn.Linear)
         and hasattr(mod, "weight")
         and not isinstance(mod.weight, QuantizedLinearWeightBase)
+        and not isinstance(mod.weight, AutoQuantizableLinearWeight)
     )
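
For readers following along outside the test suite, here is a minimal sketch of the tied-weight scenario that the new test_autoquant_double_access case exercises: two linears share one Parameter, and the _is_linear guard added above keeps the shared weight from being wrapped a second time. The TiedLinear module name and the concrete shapes are illustrative only; the torchao.autoquant call, the AutoQuantizableLinearWeight import path, and the .weight.weight inner-tensor check mirror the test and assume the post-patch behavior.

import torch
import torchao
from torchao.quantization.autoquant import AutoQuantizableLinearWeight

class TiedLinear(torch.nn.Module):
    # lin1 and lin3 deliberately share one Parameter, as in the test above
    def __init__(self, k=128, n=128):
        super().__init__()
        self.lin1 = torch.nn.Linear(k, n)
        self.lin2 = torch.nn.Linear(n, k)
        self.lin3 = torch.nn.Linear(k, n)
        self.lin3.weight = self.lin1.weight

    def forward(self, x):
        return self.lin3(self.lin2(self.lin1(x)))

model = TiedLinear().to("cuda").to(torch.bfloat16)
x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

torchao.autoquant(model)  # wraps eligible nn.Linear weights for autoquantization
# With the new _is_linear guard the shared weight is wrapped only once,
# so the inner tensor is not itself an AutoQuantizableLinearWeight:
assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
model(x)  # forward pass on the prepared model, as in the test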