diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 4bd65edc3b..5d508e1281 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -19,6 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
+from torchao.dtypes import TensorCoreTiledLayoutType
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -852,18 +853,20 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
 
                     def api(mod):
+                        kwargs_copy = kwargs.copy()
                         if TORCH_VERSION_AFTER_2_4:
-                            kwargs_copy = kwargs.copy()
                             kwargs_copy["group_size"] = groupsize
                             del kwargs_copy["groupsize"]
                             quantize_(mod, int4_weight_only(**kwargs_copy))
                             if not TORCH_VERSION_AFTER_2_5:
                                 unwrap_tensor_subclass(mod)
                         else:
-                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
+                            del kwargs_copy["layout_type"]
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)
 
                     self._test_lin_weight_subclass_api_impl(
                         api,
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 3a329989a4..36b5440de3 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -21,7 +21,14 @@
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional
 
-from torchao.dtypes import PlainLayoutType
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
+from torchao.dtypes import (
+    to_affine_quantized,
+    TensorCoreTiledLayoutType,
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
+)
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -182,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(
 
 
 def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -328,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -339,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight
 
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -373,7 +371,7 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def int4_weight_only(group_size=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -389,16 +387,12 @@
     Args:
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, choices are [256, 128, 64, 32]
-       `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
+       `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
     """
     def apply_int4_weight_only_quant(weight):
         if weight.shape[-1] % group_size != 0:
             return weight
 
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        from torchao.dtypes import TensorCoreTiledLayoutType
-
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int32
@@ -408,7 +402,6 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)
 
     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
@@ -419,9 +412,6 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -432,8 +422,6 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)
 
 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -453,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
         if in_features <= 16:
             return weight
 
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
         # weight settings
         mapping_type = MappingType.SYMMETRIC
         def get_weight_block_size(x):
@@ -479,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
@@ -495,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
     def apply_uintx_weight_only_quant(weight):
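For reference, a minimal sketch of how the updated `int4_weight_only` API is called after this change. The toy model, shapes, device, and parameter values below are illustrative assumptions, not part of the patch:

```python
# Sketch of the layout_type-based API introduced by this diff.
# The model, shapes, and device here are assumptions for illustration.
import torch

from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_

# Toy model; the tinygemm int4 kernel targets bfloat16 weights on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))
model = model.to(device="cuda", dtype=torch.bfloat16)

# inner_k_tiles now travels inside TensorCoreTiledLayoutType instead of
# being a separate keyword argument of int4_weight_only.
quantize_(
    model,
    int4_weight_only(
        group_size=32,
        layout_type=TensorCoreTiledLayoutType(inner_k_tiles=4),
    ),
)
```

Folding `inner_k_tiles` into the layout type keeps kernel-packing details out of the quantization entry point, mirroring how `SemiSparseLayoutType` selects the sparse path in `int8_dynamic_activation_int8_semi_sparse_weight`.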