Skip to content

Make developer experience better for extending AQT #749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def __repr__(self):
# Tensor Subclass Definition #
##############################


class QuantizedLinearNotImplementedError(NotImplementedError):
""" Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """
pass


_QLINEAR_DISPATCH_TABLE = {}
def _register_quantized_linear_dispatch(dispatch_condition, impl):
_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
Expand Down Expand Up @@ -158,8 +164,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
if dispatch_condition(input_tensor, weight_tensor, bias):
return impl(input_tensor, weight_tensor, bias)

raise NotImplementedError("No specialized dispatch found for quantized linear op")
raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")

def __tensor_flatten__(self):
return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
Expand Down Expand Up @@ -887,7 +892,7 @@ def _(func, types, args, kwargs):
# make the branches easier to understand in `_quantized_linear_op`
try:
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
except QuantizedLinearNotImplementedError:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
Expand All @@ -910,7 +915,7 @@ def _(func, types, args, kwargs):
try:
weight_tensor = weight_tensor.t()
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
except QuantizedLinearNotImplementedError:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
Expand All @@ -930,7 +935,7 @@ def _(func, types, args, kwargs):
try:
weight_tensor = weight_tensor.t()
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
except QuantizedLinearNotImplementedError:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
Expand Down
1 change: 1 addition & 0 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMi
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
@staticmethod
def _quantized_linear_op(act_mat, w_qtensor, bias):
orig_shape = act_mat.shape
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
Expand Down
Loading