diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index 82d3d2501d..da20b930d3 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -181,6 +181,9 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(
     def test_tp(self, dtype):
         return self._test_tp(dtype)
 
+common_utils.instantiate_parametrized_tests(
+    TestFloat8woAffineQuantizedTensorParallel
+)
 common_utils.instantiate_parametrized_tests(
     TestFloat8dqTensorAffineQuantizedTensorParallel
 )
diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py
index 46b48393a3..e86b2f8e64 100644
--- a/torchao/quantization/linear_activation_quantized_tensor.py
+++ b/torchao/quantization/linear_activation_quantized_tensor.py
@@ -147,8 +147,8 @@ def _(func, types, args, kwargs):
         )
         input_quant_func = weight_tensor.input_quant_func
         original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
-        return func(bias, aqt, original_weight_tensor)
+        qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
+        return func(bias, qtensor, original_weight_tensor)
     else:
         # aten.mm.default
         assert args[0].shape[-1] == args[1].shape[0], (
@@ -161,8 +161,8 @@
         )
         input_quant_func = weight_tensor.input_quant_func
         original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
-        return func(aqt, original_weight_tensor)
+        qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
+        return func(qtensor, original_weight_tensor)
 
 
 @implements(aten.detach.default)
@@ -203,7 +203,9 @@ def _(func, types, args, kwargs):
         args,
         kwargs,
         LinearActivationQuantizedTensor(
-            func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func
+            func(args[0].original_weight_tensor, *args[1:]),
+            args[0].input_quant_func,
+            args[0].quant_kwargs,
         ),
     )
 
@@ -216,7 +218,9 @@ def _(func, types, args, kwargs):
         args,
         kwargs,
         LinearActivationQuantizedTensor(
-            func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func
+            func(args[0].original_weight_tensor, *args[1:]),
+            args[0].input_quant_func,
+            args[0].quant_kwargs,
         ),
     )
 