diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 4e5dcaa815..6c36d98c4c 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -70,6 +70,12 @@ def __repr__(self):
 # Tensor Subclass Definition #
 ##############################
 
+
+class QuantizedLinearNotImplementedError(NotImplementedError):
+    """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """
+    pass
+
+
 _QLINEAR_DISPATCH_TABLE = {}
 def _register_quantized_linear_dispatch(dispatch_condition, impl):
     _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
@@ -158,8 +164,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
         for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
             if dispatch_condition(input_tensor, weight_tensor, bias):
                 return impl(input_tensor, weight_tensor, bias)
-
-        raise NotImplementedError("No specialized dispatch found for quantized linear op")
+        raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
 
     def __tensor_flatten__(self):
         return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
@@ -887,7 +892,7 @@ def _(func, types, args, kwargs):
     # make the branches easier to understand in `_quantized_linear_op`
     try:
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
+    except QuantizedLinearNotImplementedError:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -910,7 +915,7 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
+    except QuantizedLinearNotImplementedError:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -930,7 +935,7 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
+    except QuantizedLinearNotImplementedError:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index dd6d319931..cc51dd5ced 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -413,6 +413,7 @@ class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel
     """
+    @staticmethod
     def _quantized_linear_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
         y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
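
Note: a minimal, runnable sketch of the pattern this patch moves to, reusing the dispatch names from the diff but with hypothetical stand-in values instead of real torchao tensors (no torch dependency). The point of the dedicated NotImplementedError subclass is that the fallback branch now catches only the "no kernel matched" signal from the dispatch loop, while a genuine bug inside a matched kernel propagates instead of being silently swallowed by the old bare `except:`.

    class QuantizedLinearNotImplementedError(NotImplementedError):
        """Raised by the dispatch loop when no registered condition matches."""

    _QLINEAR_DISPATCH_TABLE = {}

    def _register_quantized_linear_dispatch(dispatch_condition, impl):
        _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl

    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # First matching condition wins; no match raises the sentinel error.
        for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
            if dispatch_condition(input_tensor, weight_tensor, bias):
                return impl(input_tensor, weight_tensor, bias)
        raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")

    def linear(input_tensor, weight_tensor, bias):
        try:
            return _quantized_linear_op(input_tensor, weight_tensor, bias)
        except QuantizedLinearNotImplementedError:
            # General fallback: in torchao this dequantizes the affine
            # quantized tensors and runs the plain floating-point linear.
            return ("fallback", input_tensor, weight_tensor, bias)

    # Register one specialized path (a stand-in for e.g. an int8 weight-only kernel).
    _register_quantized_linear_dispatch(
        lambda x, w, b: w == "int8",
        lambda x, w, b: ("int8 kernel", x),
    )

    print(linear(1.0, "int8", None))  # hits the specialized kernel
    print(linear(1.0, "int4", None))  # no condition matches -> clean fallback

The same narrowing appears three times in the diff (the F.linear, aten.mm, and aten.addmm implementations), each replacing a bare `except:` with this single exception type.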