Minor fix for logical operator precedence in _aqt_is_* checks. #899

Merged 1 commit on Sep 17, 2024
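
In Python, `and` binds more tightly than `or`, so an unparenthesized chain like `a and b or c and d` parses as `(a and b) or (c and d)`, not `a and (b or c) and d`. The pre-fix `_aqt_is_*` checks relied on the intended grouping rather than the actual one, which means a single clause such as `aqt.quant_max == 127` could satisfy the whole check on its own, regardless of the dtype. This PR adds explicit parentheses (or drops the redundant `is None` alternatives) so that every condition actually gates the result.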
16 changes: 8 additions & 8 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1119,24 +1119,24 @@ def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min is None or aqt.quant_min == -128 and
aqt.quant_max is None or aqt.quant_max == 127
(aqt.quant_min is None or aqt.quant_min == -128) and
(aqt.quant_max is None or aqt.quant_max == 127)
)

def _aqt_is_int8_reduced_range(aqt):
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min == -127 and
aqt.quant_max is None or aqt.quant_max == 127
(aqt.quant_max is None or aqt.quant_max == 127)
)

def _aqt_is_uint4(aqt):
def _aqt_is_tensor_core_tile_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
# TODO: use torch.uint4
return (
aqt.layout_tensor.dtype == torch.int32 and
aqt.quant_min is None or aqt.quant_min == 0 and
aqt.quant_max is None or aqt.quant_max == 15
aqt.quant_min == 0 and
aqt.quant_max == 15
)
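
To see why the parentheses matter, here is a minimal standalone sketch of the `_aqt_is_int8` condition. The variable values are hypothetical, chosen so the check should fail, and the two forms are written out exactly as Python parses them:

# Hypothetical inputs: dtype check fails, quant_min is wrong,
# but quant_max happens to be 127.
dtype_ok = False      # stands in for `aqt.layout_tensor.dtype == torch.int8`
quant_min = -5        # neither None nor -128
quant_max = 127

# Pre-fix: parses as
#   (dtype_ok and quant_min is None)
#   or (quant_min == -128 and quant_max is None)
#   or (quant_max == 127)
old = (
    dtype_ok and
    quant_min is None or quant_min == -128 and
    quant_max is None or quant_max == 127
)

# Post-fix: every clause must hold.
new = (
    dtype_ok and
    (quant_min is None or quant_min == -128) and
    (quant_max is None or quant_max == 127)
)

print(old)  # True  -- `quant_max == 127` alone satisfies the final `or` alternative
print(new)  # False -- the failed dtype check now gates the result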


@@ -1228,7 +1228,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
         input_tensor.dtype == torch.bfloat16 and
         # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor) and
-        _aqt_is_uint4(weight_tensor) and
+        _aqt_is_tensor_core_tile_uint4(weight_tensor) and
         weight_tensor.dtype == torch.bfloat16 and
         len(weight_tensor.shape) == 2 and
         weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and
@@ -1429,7 +1429,7 @@ def _linear_fp_act_fp8_weight_impl(
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
         isinstance(weight_tensor, AffineQuantizedTensor) and
-        _aqt_is_uint4(weight_tensor) and
+        _aqt_is_tensor_core_tile_uint4(weight_tensor) and
         input_tensor.dtype == torch.float16 and
         len(weight_tensor.shape) == 2 and
         weight_tensor.zero_point_domain == ZeroPointDomain.INT and