From 00f9137eded6f023709bc6a726d4d004b1c200f5 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 25 Nov 2024 16:49:32 -0800 Subject: [PATCH 1/5] Add support for quantize_() with Float8Linear module --- torchao/float8/float8_linear_utils.py | 29 ++++++++++++++++++++++----- torchao/quantization/quant_api.py | 5 +++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index c4fc88eb37..5890f7bf6b 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -59,6 +59,7 @@ def _update_history_stack( def swap_linear_layers( module: nn.Module, from_float_func: Callable[[nn.Linear], nn.Linear], + target_module: nn.Module = nn.Linear, *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, ) -> nn.Module: @@ -71,20 +72,21 @@ def swap_linear_layers( Args: module: Module to modify. - from_float_func: Function that accepts a linear layer and returns a new type of linear layer. + from_float_func: Function that accepts some type of linear layer and returns a new type of linear layer. module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that that pass the filter function will be swapped. The inputs to the filter function are the module instance, and the FQN. + target_module: Replace these modules Returns: nn.Module: The modified module with swapped linear layers. 
""" - if isinstance(module, nn.Linear) and ( + if isinstance(module, target_module) and ( module_filter_fn is None or module_filter_fn(module, "") ): if len(list(module.children())) > 0: raise AssertionError( - f"Does not support a root nn.Linear with children: {module}" + f"Does not support a root {target_module} with children: {module}" ) return from_float_func( module, @@ -108,12 +110,12 @@ def post_order_traversal( post_order_traversal(child_module, new_fqn, module) - if isinstance(module, nn.Linear) and ( + if isinstance(module, target_module) and ( module_filter_fn is None or module_filter_fn(module, cur_fqn) ): assert ( parent_module is not None - ), f"Linear root module should return early: {module}" + ), f"{target_module} root module should return early: {module}" new_linear_module = from_float_func(module) cur_module_name = cur_fqn.split(".")[-1] setattr(parent_module, cur_module_name, new_linear_module) @@ -319,3 +321,20 @@ def inner_func(): for child in fp8_layers: # Set a flag to signal that initialization is done child.is_amax_initialized = True + +def dequantize_float8_training(model: nn.Module) -> nn.Module: + """ + Converts `Float8Linear` modules in `model` to `torch.nn.Linear`. 
+ """ + + def dequant_func(mod: Float8Linear) -> nn.Linear: + new_module = nn.Linear(mod.in_features, mod.out_features) + new_module.weight = mod.weight + new_module.bias = mod.bias + return new_module + + return swap_linear_layers( + model, + dequant_func, + target_module=Float8Linear, + ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c730ec9046..e5ab1de5dc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -39,6 +39,8 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.float8_linear_utils import dequantize_float8_training from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -222,6 +224,9 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ + # If model is Float8Linear, convert it to Linear before moving forward + if isinstance(model, Float8Linear): + model = dequantize_float8_training(model) if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization From 62bd5da2a009aa10fcd8ba1c1cb1fdfcb1f23a75 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 26 Nov 2024 13:18:11 -0800 Subject: [PATCH 2/5] Expose fp8linear dequantize api --- test/float8/test_base.py | 2 +- torchao/float8/__init__.py | 2 ++ torchao/float8/float8_linear_utils.py | 5 +++-- torchao/quantization/quant_api.py | 5 ----- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d00b96d3bb..43d01c6ffb 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -751,7 +751,7 @@ def test_swap_root_linear_with_children_raises(self): config = Float8LinearConfig(emulate=emulate) with self.assertRaisesRegex( AssertionError, - "Does not support a root nn.Linear with children", + "Does 
not support a root torch.nn.modules.linear with children", ): convert_to_float8_training(module, config=config) diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index bb8b38c0b9..b7ae3abd2b 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -14,6 +14,7 @@ from torchao.float8.float8_linear import WeightWithDelayedFloat8CastTensor from torchao.float8.float8_linear_utils import ( convert_to_float8_training, + dequantize_float8_training, linear_requires_sync, sync_float8_amax_and_scale_history, ) @@ -54,5 +55,6 @@ "linear_requires_sync", "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", + "dequantize_float8_training", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 5890f7bf6b..7f7d271745 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -59,9 +59,9 @@ def _update_history_stack( def swap_linear_layers( module: nn.Module, from_float_func: Callable[[nn.Linear], nn.Linear], - target_module: nn.Module = nn.Linear, *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, + target_module: nn.Module = nn.Linear, ) -> nn.Module: """ Generic function to swap linear layers in a module with a new type of linear layer. @@ -86,7 +86,7 @@ def swap_linear_layers( ): if len(list(module.children())) > 0: raise AssertionError( - f"Does not support a root {target_module} with children: {module}" + f"Does not support a root {target_module.__module__} with children: {module.__module__}" ) return from_float_func( module, @@ -322,6 +322,7 @@ def inner_func(): # Set a flag to signal that initialization is done child.is_amax_initialized = True + def dequantize_float8_training(model: nn.Module) -> nn.Module: """ Converts `Float8Linear` modules in `model` to `torch.nn.Linear`. 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e5ab1de5dc..c730ec9046 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -39,8 +39,6 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) -from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import dequantize_float8_training from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -224,9 +222,6 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ - # If model is Float8Linear, convert it to Linear before moving forward - if isinstance(model, Float8Linear): - model = dequantize_float8_training(model) if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization From a73180d760e12c2d41e5165119a99a9ba00893a7 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 26 Nov 2024 16:12:53 -0800 Subject: [PATCH 3/5] Implicit conversion --- test/float8/test_base.py | 17 ++++++++++++++++- torchao/float8/__init__.py | 2 -- torchao/float8/float8_linear_utils.py | 18 ------------------ torchao/quantization/quant_api.py | 20 ++++++++++++++++++++ 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 43d01c6ffb..4d589113ef 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -531,6 +531,21 @@ def test_inference_mode(self): with torch.inference_mode(mode=True): m(x) + @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available") + def test_quantize(self): + x = 
torch.randn(32, 32, device="cuda") + m = nn.Sequential(nn.Linear(32, 32)).cuda() + m = convert_to_float8_training(m) + assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" + from torchao.quantization.quant_api import float8_weight_only, quantize_ + + quantize_(m, float8_weight_only()) + assert ( + m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn + ), "Post quantization dtype should be torch.float8_e4m3fn" + with torch.no_grad(): + m(x) + class TestScaledMM: @unittest.skipIf( diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index b7ae3abd2b..bb8b38c0b9 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -14,7 +14,6 @@ from torchao.float8.float8_linear import WeightWithDelayedFloat8CastTensor from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - dequantize_float8_training, linear_requires_sync, sync_float8_amax_and_scale_history, ) @@ -55,6 +54,5 @@ "linear_requires_sync", "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", - "dequantize_float8_training", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 7f7d271745..5e2a4f68f6 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -321,21 +321,3 @@ def inner_func(): for child in fp8_layers: # Set a flag to signal that initialization is done child.is_amax_initialized = True - - -def dequantize_float8_training(model: nn.Module) -> nn.Module: - """ - Converts `Float8Linear` modules in `model` to `torch.nn.Linear`. 
- """ - - def dequant_func(mod: Float8Linear) -> nn.Linear: - new_module = nn.Linear(mod.in_features, mod.out_features) - new_module.weight = mod.weight - new_module.bias = mod.bias - return new_module - - return swap_linear_layers( - model, - dequant_func, - target_module=Float8Linear, - ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c730ec9046..a59a38b2b0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -39,6 +39,8 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) +from torchao.float8.float8_linear import Float8Linear +from torchao.float8.float8_linear_utils import swap_linear_layers from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -199,6 +201,22 @@ def change_linear_weights_to_int4_woqtensors( ######## # TO BE DEPRECATED END ######## +def dequantize_float8_training(model: nn.Module) -> nn.Module: + """ + Converts `Float8Linear` modules in `model` to `torch.nn.Linear`. 
+ """ + + def dequant_func(mod: Float8Linear) -> nn.Linear: + new_module = nn.Linear(mod.in_features, mod.out_features) + new_module.weight = mod.weight + new_module.bias = mod.bias + return new_module + + return swap_linear_layers( + model, + dequant_func, + target_module=Float8Linear, + ) def _replace_with_custom_fn_if_matches_filter( @@ -222,6 +240,8 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ + if isinstance(model, Float8Linear): + model = dequantize_float8_training(model) if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization From dfdba9267d6a8085736af900a7bb0a431ee92e2c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 27 Nov 2024 15:42:10 -0800 Subject: [PATCH 4/5] Inline function --- torchao/quantization/quant_api.py | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e626489d35..dba33da2db 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -201,22 +201,6 @@ def change_linear_weights_to_int4_woqtensors( ######## # TO BE DEPRECATED END ######## -def dequantize_float8_training(model: nn.Module) -> nn.Module: - """ - Converts `Float8Linear` modules in `model` to `torch.nn.Linear`. 
- """ - - def dequant_func(mod: Float8Linear) -> nn.Linear: - new_module = nn.Linear(mod.in_features, mod.out_features) - new_module.weight = mod.weight - new_module.bias = mod.bias - return new_module - - return swap_linear_layers( - model, - dequant_func, - target_module=Float8Linear, - ) def _replace_with_custom_fn_if_matches_filter( @@ -240,6 +224,22 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ + + def dequantize_float8_training(model: nn.Module) -> nn.Module: + """Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.""" + + def dequant_func(mod: Float8Linear) -> nn.Linear: + new_module = nn.Linear(mod.in_features, mod.out_features) + new_module.weight = mod.weight + new_module.bias = mod.bias + return new_module + + return swap_linear_layers( + model, + dequant_func, + target_module=Float8Linear, + ) + if isinstance(model, Float8Linear): model = dequantize_float8_training(model) if filter_fn(model, cur_fqn[:-1]): From f66e7ed5042864ad93f8ddb95d20dd3e477bf62c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 27 Nov 2024 16:41:44 -0800 Subject: [PATCH 5/5] Updated logic --- test/float8/test_base.py | 2 +- torchao/float8/float8_linear_utils.py | 12 +++++------- torchao/quantization/quant_api.py | 23 +++++------------------ 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e07e936bcf..245abe0d02 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -766,7 +766,7 @@ def test_swap_root_linear_with_children_raises(self): config = Float8LinearConfig(emulate=emulate) with self.assertRaisesRegex( AssertionError, - "Does not support a root torch.nn.modules.linear with children", + "Does not support a root nn.Linear with children", ): convert_to_float8_training(module, config=config) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 5e2a4f68f6..c4fc88eb37 100644 ---
a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -61,7 +61,6 @@ def swap_linear_layers( from_float_func: Callable[[nn.Linear], nn.Linear], *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, - target_module: nn.Module = nn.Linear, ) -> nn.Module: """ Generic function to swap linear layers in a module with a new type of linear layer. @@ -72,21 +71,20 @@ def swap_linear_layers( Args: module: Module to modify. - from_float_func: Function that accepts some type of linear layer and returns a new type of linear layer. + from_float_func: Function that accepts a linear layer and returns a new type of linear layer. module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that that pass the filter function will be swapped. The inputs to the filter function are the module instance, and the FQN. - target_module: Replace these modules Returns: nn.Module: The modified module with swapped linear layers. """ - if isinstance(module, target_module) and ( + if isinstance(module, nn.Linear) and ( module_filter_fn is None or module_filter_fn(module, "") ): if len(list(module.children())) > 0: raise AssertionError( - f"Does not support a root {target_module.__module__} with children: {module.__module__}" + f"Does not support a root nn.Linear with children: {module}" ) return from_float_func( module, @@ -110,12 +108,12 @@ def post_order_traversal( post_order_traversal(child_module, new_fqn, module) - if isinstance(module, target_module) and ( + if isinstance(module, nn.Linear) and ( module_filter_fn is None or module_filter_fn(module, cur_fqn) ): assert ( parent_module is not None - ), f"{target_module} root module should return early: {module}" + ), f"Linear root module should return early: {module}" new_linear_module = from_float_func(module) cur_module_name = cur_fqn.split(".")[-1] setattr(parent_module, cur_module_name, new_linear_module) diff --git a/torchao/quantization/quant_api.py 
b/torchao/quantization/quant_api.py index dba33da2db..60a7341e39 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -40,7 +40,6 @@ to_marlinqqq_quantized_intx, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import swap_linear_layers from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -224,24 +223,12 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ - - def dequantize_float8_training(model: nn.Module) -> nn.Module: - """Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.""" - - def dequant_func(mod: Float8Linear) -> nn.Linear: - new_module = nn.Linear(mod.in_features, mod.out_features) - new_module.weight = mod.weight - new_module.bias = mod.bias - return new_module - - return swap_linear_layers( - model, - dequant_func, - target_module=Float8Linear, - ) - if isinstance(model, Float8Linear): - model = dequantize_float8_training(model) + with torch.device("meta"): + new_module = nn.Linear(model.in_features, model.out_features) + new_module.weight = model.weight + new_module.bias = model.bias + model = new_module if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization