diff --git a/tests/recipes/test_configs.py b/tests/recipes/test_configs.py
index f0835434de..fcb242ec4c 100644
--- a/tests/recipes/test_configs.py
+++ b/tests/recipes/test_configs.py
@@ -9,8 +9,8 @@
 import torchtune
 
 from omegaconf import OmegaConf
-from torchao.utils import TORCH_VERSION_AFTER_2_4
 from torchtune import config
+from torchtune.utils import torch_version_ge
 
 CONFIG_DIR = Path(torchtune.__file__).parent.parent / "recipes" / "configs"
 
@@ -24,7 +24,7 @@ def test_instantiate(self) -> None:
         ]
         for config_path in all_configs:
             # QAT config is only compatible with PyTorch 2.4+
-            if config_path.endswith("qat_full.yaml") and not TORCH_VERSION_AFTER_2_4:
+            if config_path.endswith("qat_full.yaml") and not torch_version_ge("2.4.0"):
                 continue
             cfg = OmegaConf.load(config_path)
             config.validate(cfg)
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 4e884ed86d..5c46804357 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -19,20 +19,15 @@
     Float8RowwiseParallel,
 )
 from torchao.quantization import (
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
     quantize_,
 )
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
-from torchao.quantization.qat.linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-)
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
@@ -58,6 +53,24 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
+def _enable_linear_fake_quant(
+    mod: nn.Module,
+    enabled: bool = True,
+):
+    """
+    Helper function to enable fake quantization in `FakeQuantizedLinear`.
+    """
+    if isinstance(mod, FakeQuantizedLinear):
+        if mod.activation_fake_quantizer is not None:
+            mod.activation_fake_quantizer.enabled = enabled
+        if mod.weight_fake_quantizer is not None:
+            mod.weight_fake_quantizer.enabled = enabled
+
+
+def _disable_linear_fake_quant(mod: nn.Module):
+    _enable_linear_fake_quant(mod, enabled=False)
+
+
 # ========================================
 # int8 dynamic activations + int4 weight |
 # ========================================
@@ -73,15 +86,15 @@ def __init__(self, groupsize: int = 256):
         self.groupsize = groupsize
 
     def quantize(self, model):
-        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_fn = Int8DynamicActivationInt4WeightConfig(self.groupsize)
         quantize_(model, quantize_fn)
         return model
 
 
 _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
-_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
-_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
+_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = _disable_linear_fake_quant
+_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = _enable_linear_fake_quant
 
 
 # ==================
@@ -101,15 +114,15 @@ def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
 
     def quantize(self, model):
         layout_type = TensorCoreTiledLayout(self.inner_k_tiles)
-        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_fn = Int4WeightOnlyConfig(self.groupsize, layout_type)
         quantize_(model, quantize_fn)
         return model
 
 
 _quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
 _quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
-_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
-_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = _disable_linear_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = _enable_linear_fake_quant
 
 
 # ====================== #
@@ -122,8 +135,8 @@ class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
     pass
 
 
-disable_4w_fake_quant_module_swap = disable_4w_fake_quant
-enable_4w_fake_quant_module_swap = enable_4w_fake_quant
+disable_4w_fake_quant_module_swap = _disable_linear_fake_quant
+enable_4w_fake_quant_module_swap = _enable_linear_fake_quant
 _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "4w-qat-module-swap"
@@ -138,8 +151,8 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantize
     pass
 
 
-disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
-enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
+disable_8da4w_fake_quant_module_swap = _disable_linear_fake_quant
+enable_8da4w_fake_quant_module_swap = _enable_linear_fake_quant
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
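
Note: a minimal usage sketch (not part of the patch) of the new mode-agnostic helpers. Assumptions: the toy model and groupsize are illustrative, and the sketch relies on the QAT quantizer's prepare() swapping nn.Linear for FakeQuantizedLinear subclasses, which is the premise of this change; recipes would normally reach these helpers through the registered _quantizer_mode_to_*_fake_quant callbacks rather than importing the private names directly.

from torch import nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchtune.training.quantization import (
    _disable_linear_fake_quant,
    _enable_linear_fake_quant,
)

# Illustrative toy model; any model whose linears are swapped for
# FakeQuantizedLinear subclasses by a QAT quantizer behaves the same way.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
model = Int8DynActInt4WeightQATQuantizer(groupsize=256).prepare(model)

# One enable/disable pair now covers both the 8da4w and 4w QAT modes:
# each call walks every submodule and flips the fake quantizers' `enabled` flag.
model.apply(_disable_linear_fake_quant)  # e.g. during a fake-quant warmup phase
model.apply(_enable_linear_fake_quant)   # re-enable fake quant for QAT steps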