Commit a73180d
Implicit conversion
1 parent 62bd5da commit a73180d

4 files changed: +36 -21 lines changed


test/float8/test_base.py

Lines changed: 16 additions & 1 deletion

@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -531,6 +531,21 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)
+
 
 class TestScaledMM:
     @unittest.skipIf(

torchao/float8/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -14,7 +14,6 @@
 from torchao.float8.float8_linear import WeightWithDelayedFloat8CastTensor
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
-    dequantize_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
@@ -55,6 +54,5 @@
     "linear_requires_sync",
     "sync_float8_amax_and_scale_history",
     "precompute_float8_dynamic_scale_for_fsdp",
-    "dequantize_float8_training",
     # note: Float8Tensor and Float8Linear are not public APIs
 ]

torchao/float8/float8_linear_utils.py

Lines changed: 0 additions & 18 deletions

@@ -321,21 +321,3 @@ def inner_func():
     for child in fp8_layers:
         # Set a flag to signal that initialization is done
         child.is_amax_initialized = True
-
-
-def dequantize_float8_training(model: nn.Module) -> nn.Module:
-    """
-    Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
-    """
-
-    def dequant_func(mod: Float8Linear) -> nn.Linear:
-        new_module = nn.Linear(mod.in_features, mod.out_features)
-        new_module.weight = mod.weight
-        new_module.bias = mod.bias
-        return new_module
-
-    return swap_linear_layers(
-        model,
-        dequant_func,
-        target_module=Float8Linear,
-    )

torchao/quantization/quant_api.py

Lines changed: 20 additions & 0 deletions

@@ -39,6 +39,8 @@
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
 )
+from torchao.float8.float8_linear import Float8Linear
+from torchao.float8.float8_linear_utils import swap_linear_layers
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
@@ -199,6 +201,22 @@ def change_linear_weights_to_int4_woqtensors(
 ########
 # TO BE DEPRECATED END
 ########
+def dequantize_float8_training(model: nn.Module) -> nn.Module:
+    """
+    Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
+    """
+
+    def dequant_func(mod: Float8Linear) -> nn.Linear:
+        new_module = nn.Linear(mod.in_features, mod.out_features)
+        new_module.weight = mod.weight
+        new_module.bias = mod.bias
+        return new_module
+
+    return swap_linear_layers(
+        model,
+        dequant_func,
+        target_module=Float8Linear,
+    )
 
 
 def _replace_with_custom_fn_if_matches_filter(
@@ -222,6 +240,8 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
+    if isinstance(model, Float8Linear):
+        model = dequantize_float8_training(model)
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
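
Taken together, these diffs let quantize_ be called directly on a model that was converted for float8 training: _replace_with_custom_fn_if_matches_filter now swaps any Float8Linear back to a plain nn.Linear (via dequantize_float8_training) before applying the quantization config. Below is a minimal end-to-end sketch mirroring the new test_quantize test; the 32x32 model shape and the float8_weight_only config are taken from that test, and running it assumes a CUDA device with arch 8.9 or newer (matching the test's skip condition).

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.quantization.quant_api import float8_weight_only, quantize_

# Float8 training swap: nn.Linear -> Float8Linear.
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)

# With this commit, quantize_ implicitly converts Float8Linear modules back
# to nn.Linear (via dequantize_float8_training) before applying the config,
# so no manual unwrapping step is needed.
quantize_(m, float8_weight_only())

with torch.no_grad():
    m(torch.randn(32, 32, device="cuda"))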
