4 changes: 2 additions & 2 deletions tests/recipes/test_configs.py
@@ -9,8 +9,8 @@
 import torchtune
 
 from omegaconf import OmegaConf
-from torchao.utils import TORCH_VERSION_AFTER_2_4
 from torchtune import config
+from torchtune.utils import torch_version_ge
 
 CONFIG_DIR = Path(torchtune.__file__).parent.parent / "recipes" / "configs"
 
@@ -24,7 +24,7 @@ def test_instantiate(self) -> None:
         ]
         for config_path in all_configs:
             # QAT config is only compatible with PyTorch 2.4+
-            if config_path.endswith("qat_full.yaml") and not TORCH_VERSION_AFTER_2_4:
+            if config_path.endswith("qat_full.yaml") and not torch_version_ge("2.4.0"):
                 continue
             cfg = OmegaConf.load(config_path)
             config.validate(cfg)
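For reference, `torch_version_ge` takes a version string and compares it against the installed PyTorch, so the same gate generalizes to future cutoffs. A minimal standalone sketch of the new-style check (illustrative only, not part of the test):

from torchtune.utils import torch_version_ge

# True when the installed PyTorch is 2.4.0 or newer; the QAT config
# in the loop above is only validated when this holds.
qat_supported = torch_version_ge("2.4.0")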
49 changes: 31 additions & 18 deletions torchtune/training/quantization.py
@@ -19,20 +19,15 @@
     Float8RowwiseParallel,
 )
 from torchao.quantization import (
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
     quantize_,
 )
 from torchao.quantization.qat import (
+    FakeQuantizedLinear,
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
-from torchao.quantization.qat.linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-)
 
 
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear

@@ -58,6 +53,24 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
+def _enable_linear_fake_quant(
+    mod: nn.Module,
+    enabled: bool = True,
+):
+    """
+    Helper function to enable fake quantization in `FakeQuantizedLinear`.
+    """
+    if isinstance(mod, FakeQuantizedLinear):
+        if mod.activation_fake_quantizer is not None:
+            mod.activation_fake_quantizer.enabled = enabled
+        if mod.weight_fake_quantizer is not None:
+            mod.weight_fake_quantizer.enabled = enabled
+
+
+def _disable_linear_fake_quant(mod: nn.Module):
+    _enable_linear_fake_quant(mod, enabled=False)
+
+
 # ========================================
 # int8 dynamic activations + int4 weight |
 # ========================================
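Because the new helper only touches the module it is given, callers can fan it out over a whole model with `nn.Module.apply`. A minimal usage sketch (assuming `model` is an `nn.Module` whose linear layers were swapped to `FakeQuantizedLinear` by a QAT quantizer):

# Turn fake quantization on for every FakeQuantizedLinear in the model,
model.apply(_enable_linear_fake_quant)
# and back off again, e.g. to compare against the float baseline.
model.apply(_disable_linear_fake_quant)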
@@ -73,15 +86,15 @@ def __init__(self, groupsize: int = 256):
         self.groupsize = groupsize
 
     def quantize(self, model):
-        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_fn = Int8DynamicActivationInt4WeightConfig(self.groupsize)
         quantize_(model, quantize_fn)
         return model
 
 
 _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
-_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
-_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
+_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = _disable_linear_fake_quant
+_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = _enable_linear_fake_quant
 
 
 # ==================
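The hunk above swaps torchao's deprecated `int8_dynamic_activation_int4_weight()` factory for the `Int8DynamicActivationInt4WeightConfig` class; the call shape of `quantize_` is unchanged. A minimal standalone sketch of the new-style call (toy model for illustration, not torchtune code):

import torch.nn as nn
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_

model = nn.Sequential(nn.Linear(256, 256))
# quantize_ rewrites the model in place; the config carries the group size.
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=256))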
@@ -101,15 +114,15 @@ def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
 
     def quantize(self, model):
         layout_type = TensorCoreTiledLayout(self.inner_k_tiles)
-        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_fn = Int4WeightOnlyConfig(self.groupsize, layout_type)
         quantize_(model, quantize_fn)
         return model
 
 
 _quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
 _quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
-_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
-_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = _disable_linear_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = _enable_linear_fake_quant
 
 
 # ====================== #
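The 4w path is the same migration, with the tensor-core layout now carried by the config. A minimal standalone sketch (toy model for illustration; the int4 tinygemm kernels expect bf16 weights on CUDA):

import torch
import torch.nn as nn
from torchao.dtypes import TensorCoreTiledLayout
from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = nn.Sequential(nn.Linear(128, 128)).to(torch.bfloat16)
# group_size sets the quantization granularity; inner_k_tiles controls
# how the packed int4 weights are tiled for the tensor cores.
quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)))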
@@ -122,8 +135,8 @@ class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
     pass
 
 
-disable_4w_fake_quant_module_swap = disable_4w_fake_quant
-enable_4w_fake_quant_module_swap = enable_4w_fake_quant
+disable_4w_fake_quant_module_swap = _disable_linear_fake_quant
+enable_4w_fake_quant_module_swap = _enable_linear_fake_quant
 _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "4w-qat-module-swap"
@@ -138,8 +151,8 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
     pass
 
 
-disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
-enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
+disable_8da4w_fake_quant_module_swap = _disable_linear_fake_quant
+enable_8da4w_fake_quant_module_swap = _enable_linear_fake_quant
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"