Expose hqq through uintx_weight_only API #786

Merged: 1 commit, merged on Sep 6, 2024
78 changes: 30 additions & 48 deletions test/hqq/test_hqq_affine.py
@@ -11,16 +11,19 @@
)

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_3,
)
from torchao.quantization import (
uintx_weight_only,
int4_weight_only,
)

cuda_available = torch.cuda.is_available()

#Parameters
device = 'cuda:0'
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) #axis=1
preserve_zero = False
@@ -34,81 +37,60 @@

def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
torch.random.manual_seed(torch_seed)
linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
y_ref = linear_layer(x)
W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
return W, x, y_ref

def _eval_hqq(nbits, layout_type):
W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)

#Plain layout
target_dtype = torch.uint8
#Tensorcore layout
if isinstance(layout_type, TensorCoreTiledLayoutType):
target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32

q_tensor_hqq = to_affine_quantized_intx(
input_float=W,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=target_dtype,
quant_min=0,
quant_max=2**nbits - 1,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
layout_type=layout_type,
use_hqq=True,
)
def _eval_hqq(dtype):
W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)

dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
else:
q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight

quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
del quant_linear_layer.weight
quant_linear_layer.weight = q_tensor_hqq
dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()

return dequantize_error, dot_product_error


class TestHQQBase(unittest.TestCase):
@unittest.skipIf(not cuda_available, "Need CUDA available")
def test_hqq(self, nbits=None, layout_type=None, ref_dequantize_error=None, ref_dot_product_error=None):
if(nbits is None): return
dequantize_error, dot_product_error = _eval_hqq(nbits=nbits, layout_type=layout_type)
@unittest.skipIf(not cuda_available, "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
class TestHQQ(unittest.TestCase):
def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
if(dtype is None): return
dequantize_error, dot_product_error = _eval_hqq(dtype)
self.assertTrue(dequantize_error < ref_dequantize_error)
self.assertTrue(dot_product_error < ref_dot_product_error)

class TestHQQ8Bit(TestHQQBase):
def test_hqq_plain_8bit(self):
self.test_hqq(nbits=8, layout_type=PlainLayoutType(), ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)

class TestHQQ7Bit(TestHQQBase):
def test_hqq_plain_7bit(self):
self.test_hqq(nbits=7, layout_type=PlainLayoutType(), ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)

class TestHQQ6Bit(TestHQQBase):
def test_hqq_plain_6bit(self):
self.test_hqq(nbits=6, layout_type=PlainLayoutType(), ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)

class TestHQQ5Bit(TestHQQBase):
def test_hqq_plain_5bit(self):
self.test_hqq(nbits=5, layout_type=PlainLayoutType(), ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)

class TestHQQ4bit(TestHQQBase):
def test_hqq_plain_4bit(self):
self.test_hqq(nbits=4, layout_type=PlainLayoutType(), ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)

def test_hqq_tensorcore_4bit(self):
self.test_hqq(nbits=4, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), ref_dequantize_error=0.000487, ref_dot_product_error=0.00147)
self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)

class TestHQQ3Bit(TestHQQBase):
def test_hqq_plain_3bit(self):
self.test_hqq(nbits=3, layout_type=PlainLayoutType(), ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)

class TestHQQ2Bit(TestHQQBase):
def test_hqq_plain_2bit(self):
self.test_hqq(nbits=2, layout_type=PlainLayoutType(), ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)

if __name__ == "__main__":
unittest.main()
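For readers skimming the diff, the sketch below (not part of the PR) shows the flow the rewritten test now exercises: quantize a weight through the public `uintx_weight_only` API with `use_hqq=True` and measure the mean dequantization error. The shapes, the CUDA device, and the bfloat16 compute dtype are illustrative assumptions; for 4-bit the test instead goes through `int4_weight_only(group_size, use_hqq=True)`.

```python
# Minimal sketch mirroring the rewritten test above; assumes torch 2.3+, a CUDA device,
# and the torchao APIs imported in the test file.
import torch
from torchao.quantization import uintx_weight_only

in_features, out_features, group_size = 1024, 1024, 64  # illustrative sizes
device, compute_dtype = "cuda:0", torch.bfloat16

linear = torch.nn.Linear(in_features, out_features, bias=False).to(device)
W = linear.weight.data.clone().to(dtype=compute_dtype)

# As in the test: build a dummy linear, let the uintx_weight_only inserter replace
# its weight with an HQQ-quantized tensor, then read the quantized weight back.
dummy = torch.nn.Linear(in_features, out_features, bias=False)
dummy.weight.data = W
q_weight = uintx_weight_only(torch.uint2, group_size=group_size, use_hqq=True)(dummy).weight

dequantize_error = (W - q_weight.dequantize()).abs().mean().item()
print(f"mean abs dequantize error: {dequantize_error:.6f}")
```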
15 changes: 10 additions & 5 deletions torchao/_models/llama/eval.py
@@ -81,14 +81,19 @@ def run_evaluation(
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
if "uintx" in quantization:
# uintx-nbits-group_size
# uintx-nbits-groupsize
# "uintx-2-64"
if "hqq" in quantization:
use_hqq = True
quantization = quantization[:-4]
else:
use_hqq = False
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
nbits = int(_quant_args[0])
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
group_size = int(_quant_args[1])
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -135,7 +140,7 @@ def run_evaluation(
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<group_size>")
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq")
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
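For context, the eval.py hunk above (and the matching generate.py hunk below) accepts a `-q` string of the form `uintx-<nbits>-<groupsize>[-hqq]`. The following is a hedged sketch of that parsing, not the scripts' exact code; `parse_uintx_arg` is a hypothetical helper introduced only for illustration.

```python
# Illustrative parsing of "uintx-<nbits>-<groupsize>[-hqq]"; requires a torch build
# that defines the sub-byte dtypes torch.uint1 .. torch.uint7 (torch 2.3+).
import torch

_NBITS_TO_DTYPE = {
    1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4,
    5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8,
}

def parse_uintx_arg(quantization: str):
    use_hqq = quantization.endswith("-hqq")
    if use_hqq:
        quantization = quantization[:-4]  # strip the "-hqq" suffix
    _, nbits_str, group_size_str = quantization.split("-")
    nbits = int(nbits_str)
    assert 1 <= nbits <= 8, "nbits must be 1 to 8"
    return _NBITS_TO_DTYPE[nbits], int(group_size_str), use_hqq

print(parse_uintx_arg("uintx-2-64-hqq"))  # (torch.uint2, 64, True)
```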
17 changes: 11 additions & 6 deletions torchao/_models/llama/generate.py
@@ -273,15 +273,20 @@ def main(
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "uintx" in quantization:
# uintx-nbits-group_size
# "uintx-2-64"
# uintx-nbits-groupsize, e.g. "uintx-2-64"
if "hqq" in quantization:
# uintx-nbits-groupsize-hqq
quantization = quantization[:-4]
use_hqq = True
else:
use_hqq = False
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
nbits = int(_quant_args[0])
assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
group_size = int(_quant_args[1])
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
@@ -451,7 +456,7 @@ def callback(x):
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<group_size>')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
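In API terms, passing `-q uintx-2-64-hqq` to generate.py now amounts to the call sketched below. The tiny model is a hypothetical stand-in for the Llama checkpoint the script actually loads from `--checkpoint_path`.

```python
import torch
from torchao.quantization import quantize_, uintx_weight_only

# Hypothetical stand-in for the model generate.py builds from --checkpoint_path.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False)).to("cuda", torch.bfloat16)

# Roughly what "-q uintx-2-64-hqq" expands to after this PR (sketch, not the script's exact code).
quantize_(model, uintx_weight_only(torch.uint2, 64, use_hqq=True))
```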
4 changes: 2 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -10,7 +10,7 @@
ZeroPointDomain,
MappingType,
int_scaled_matmul,
quantize_affine_hqq,
choose_qparams_and_quantize_affine_hqq,
FP8_TYPES,
choose_qparams_affine_fpx,
quantize_affine_fpx,
@@ -264,7 +264,7 @@ def from_hp_to_intx(
group_size = max(block_size)
compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype
device = input_float.device
data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
data = data.to(target_dtype)
else:
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
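The change above is a pure rename of the HQQ primitive; the sketch below shows how the renamed function is called, mirroring the keyword arguments at the `from_hp_to_intx` call site. The import path and the CUDA/bfloat16 setup are assumptions.

```python
# Sketch of the renamed primitive; keyword arguments copied from the call site above.
import torch
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_hqq

W = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
    W,
    nbits=4,
    group_size=64,
    axis=1,                        # block_size = (1, group_size) quantizes along axis 1
    compute_dtype=torch.bfloat16,
    device="cuda",
    verbose=False,
    raw_output=False,
)
data = data.to(torch.uint8)        # from_hp_to_intx then casts to the target dtype
```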
7 changes: 6 additions & 1 deletion torchao/quantization/README.md
@@ -227,7 +227,12 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
```python
# for torch 2.4+
from torchao.quantization import quantize_, int4_weight_only
quantize_(model, int4_weight_only())
group_size = 32

# You can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization, which is expected to improve accuracy,
# via the `use_hqq` flag of `int4_weight_only` quantization
use_hqq = False
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
Review comments on this hunk:

Contributor: Why is this different from the way we enable auto-round, which is its own function like apply_auto_round?

Contributor (author): I think this depends on whether we want to expose hqq just for int4 weight-only quant or for all bitwidths. This PR only enables hqq for int4, so it's more convenient to add it to the existing int4_weight_only quant; but if we want to support all bitwidths, then we should follow what auto_round is doing.

cc @mobicham, please let me know which one makes more sense.

Contributor (@mobicham, Sep 2, 2024): Maybe we can keep that flag in int4_weight_only and have a call like this for the more general intx case?

def to_hqq_quantized(input_float, nbits: int, group_size: int):
    return to_affine_quantized_intx(
        input_float=input_float,
        mapping_type=MappingType.ASYMMETRIC,
        block_size=(1, group_size),
        target_dtype=torch.bfloat16,
        quant_min=0,
        quant_max=2**nbits - 1,
        zero_point_domain=ZeroPointDomain.FLOAT,
        preserve_zero=False,
        layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8) if nbits in [4] else PlainLayoutType(),
        use_hqq=True,
    )

Member: My sense is that we should separate the implementation details from the algorithm name. Internally HQQ can be implemented by calling int4_weight_only, but there is no reason to leak this detail to end users.

Contributor (author): @mobicham sure, that would align with what auto_round is doing now, I think.

@msaroufim, you are also suggesting having a separate hqq_weight_only(dtype, group_size, layout_type) method, right?

Contributor: Since this is user facing, should the current bool use_hqq instead be an enum, something like SingleWeightQuantizationAlgorithm.HQQ?

Contributor (author): Yeah, this is not ideal; I'm planning to have a separate hqq config and remove the flag.

Contributor (author): @vkuzo I'm integrating hqq into the uintx_weight_only API now, and I'm keeping the boolean flag for the moment to keep it simpler. We can make this an enum if more algorithms are added in the future; please let me know if that sounds OK.

Contributor: If you're OK with potentially changing it later, SGTM.


# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
38 changes: 31 additions & 7 deletions torchao/quantization/quant_api.py
@@ -596,7 +596,7 @@ def input_quant_func(x: torch.Tensor):
return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)


def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by `dtype`
@@ -606,23 +606,46 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, defaults to 64
`pack_dim`: the dimension we use for packing, defaults to -1
`use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
"""
def apply_uintx_weight_only_quant(weight):
layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

SUPPORTED_DTYPES = {torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8}
assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

def apply_uintx_weight_only_quant(weight, dtype):
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT

if use_hqq:
if dtype == torch.uint4:
logger.warn(f"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance")
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
zero_point_dtype = None
zero_point_domain = ZeroPointDomain.FLOAT
preserve_zero = False
layout_type = PlainLayoutType()
else:
quant_min, quant_max = None, None
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)

return to_affine_quantized_intx(
weight, mapping_type, block_size, dtype,
quant_min=quant_min, quant_max=quant_max,
eps=eps, zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
layout_type=layout_type,
use_hqq=use_hqq,
)

return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)
return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)

def fpx_weight_only(ebits: int, mbits: int):
"""Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
@@ -652,5 +675,6 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
return to_affine_quantized_fpx(weight, layout_type)
return _get_linear_subclass_inserter(apply_quant_llm)


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
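For comparison with the default path, here is a sketch (not part of the diff) of the `to_affine_quantized_intx` call that the `use_hqq=True` branch of `apply_uintx_weight_only_quant` effectively makes for a sub-8-bit dtype other than uint4. The argument values are copied from the branch above, while the import paths and the example weight are assumptions.

```python
import torch
from torchao.dtypes import to_affine_quantized_intx, PlainLayoutType  # import paths assumed
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    _DTYPE_TO_QVALUE_BOUNDS,
)

def hqq_uintx_weight_sketch(weight, dtype=torch.uint2, group_size=64):
    # Mirrors the use_hqq=True branch above: quant bounds come from the requested
    # sub-byte dtype, the container dtype becomes uint8, the zero point lives in
    # the float domain, zero is not preserved, and the plain layout is used.
    quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
    return to_affine_quantized_intx(
        weight,
        MappingType.ASYMMETRIC,
        (1, group_size),
        torch.uint8,
        quant_min=quant_min,
        quant_max=quant_max,
        eps=None,
        zero_point_dtype=None,
        zero_point_domain=ZeroPointDomain.FLOAT,
        preserve_zero=False,
        layout_type=PlainLayoutType(),
        use_hqq=True,
    )

q = hqq_uintx_weight_sketch(torch.randn(256, 256, dtype=torch.bfloat16, device="cuda"))
```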
4 changes: 2 additions & 2 deletions torchao/quantization/quant_primitives.py
Expand Up @@ -30,7 +30,7 @@
"dequantize_affine_fpx",
"fake_quantize_affine",
"fake_quantize_affine_cachemask",
"quantize_affine_hqq",
"choose_qparams_and_quantize_affine_hqq",
]

class MappingType(Enum):
@@ -840,7 +840,7 @@ def _convert_to_affinequantized_format(W_q: torch.Tensor, scale: torch.Tensor, z
return W_q_ao, scale_ao, zero_ao

# Main hqq quantizer function
def quantize_affine_hqq(
def choose_qparams_and_quantize_affine_hqq(
tensor: torch.Tensor,
nbits: float = 4,
group_size: int = 64,