
Commit 02ecdcf

add _keep_in_fp32_modules_strict (#39058)
* add _keep_in_fp32_modules_strict
* complete test
1 parent: d973e62 · commit: 02ecdcf

File tree: 4 files changed (+111 -17 lines changed)

src/transformers/modeling_utils.py
src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py

src/transformers/modeling_utils.py

Lines changed: 33 additions & 15 deletions
@@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     _auto_class = None
     _no_split_modules = None
     _skip_keys_device_placement = None
+
     _keep_in_fp32_modules = None
+    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
+    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
+    _keep_in_fp32_modules_strict = None

     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -2049,6 +2053,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
         # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
         # when a different component (e.g. language_model) is used.
         self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)

         self._no_split_modules = self._no_split_modules or []

@@ -2061,7 +2066,7 @@ def post_init(self):
         self._backward_compatibility_gradient_checkpointing()

         # Make sure the modules correctly exist if the flag is active
-        if self._keep_in_fp32_modules is not None:
+        if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
             all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
             unique_module_names = set()
             # Get all unique module names in the module graph, without the prefixes
@@ -2070,12 +2075,21 @@ def post_init(self):
                     [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
                 )
             # Check that every module in the keep_in_fp32 list is part of the module graph
-            for module in self._keep_in_fp32_modules:
-                if module not in unique_module_names:
-                    raise ValueError(
-                        f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
-                        f" {self.__class__.__name__}"
-                    )
+            if self._keep_in_fp32_modules is not None:
+                for module in self._keep_in_fp32_modules:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
+
+            if self._keep_in_fp32_modules_strict is not None:
+                for module in self._keep_in_fp32_modules_strict:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )

         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
         self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
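
For reference, the module-name normalisation this check relies on can be run standalone: parameter paths are split on ".", numeric indices and the "weight"/"bias" leaves are dropped, and the flags are validated against the remaining components. A minimal sketch with illustrative parameter names (not taken from a real model):

    all_parameters = {"codec_model.layers.0.conv.weight", "lm_head.weight"}  # illustrative names
    unique_module_names = set()
    for param in all_parameters:
        unique_module_names.update(
            [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
        )
    print(unique_module_names)  # {'codec_model', 'layers', 'conv', 'lm_head'}
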
@@ -4757,20 +4771,24 @@ def from_pretrained(
         config = model.config

         # Find fp32 modules if needed
-        keep_in_fp32_regex = None
+        keep_in_fp32_modules = []
         # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
         # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
         # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
-        # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
         if model._keep_in_fp32_modules is not None and (
-            torch_dtype == torch.float16
-            or torch_dtype == torch.bfloat16
-            or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+            torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+        ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
+
+        if model._keep_in_fp32_modules_strict is not None and (
+            torch_dtype == torch.float16 or torch_dtype == torch.bfloat16
         ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
+
+        keep_in_fp32_regex = None
+        if keep_in_fp32_modules:
             # We need to match exact layers, so we add either `.` on each side, or start/end of string
-            keep_in_fp32_regex = re.compile(
-                "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
-            )
+            keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
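
To make the new control flow concrete: `_keep_in_fp32_modules` now only triggers for fp16 loads (or when a quantizer requests it), while `_keep_in_fp32_modules_strict` also triggers for bf16 loads; both feed the same regex, which matches whole dotted path components only. A minimal standalone sketch of that matching, with illustrative parameter names (not from a real checkpoint):

    import re

    keep_in_fp32_modules = ["codec_model"]  # union of the two flags, as collected above
    keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

    for name in ["codec_model.encoder.weight", "model.layers.0.codec_model_proj.weight", "lm_head.weight"]:
        # only exact path components match, so "codec_model_proj" is not forced to float32
        print(name, bool(keep_in_fp32_regex.search(name)))
    # codec_model.encoder.weight True
    # model.layers.0.codec_model_proj.weight False
    # lm_head.weight False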

src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -1103,7 +1103,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
-    _keep_in_fp32_modules = ["codec_model"]
+    _keep_in_fp32_modules_strict = ["codec_model"]

     def __init__(self, config):
         super().__init__(config)

src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ def __init__(self, config):


 class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
-    _keep_in_fp32_modules = ["codec_model"]
+    _keep_in_fp32_modules_strict = ["codec_model"]

     def __init__(self, config):
         super().__init__(config)
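
Under the new semantics, `_keep_in_fp32_modules` alone would no longer protect the codec on a bf16 load (it only covers fp16 and quantized loads), so the Kyutai model switches to the strict flag to keep the codec in float32 in both half-precision cases. A short usage sketch mirroring the assertions in the new test below (checkpoint name taken from that test):

    import torch
    from transformers import KyutaiSpeechToTextForConditionalGeneration

    model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
        "kyutai/stt-2.6b-en-trfs", torch_dtype=torch.bfloat16
    )
    print(model.codec_model.dtype)  # torch.float32
    print(model.model.dtype)        # torch.bfloat16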

tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py

Lines changed: 76 additions & 0 deletions
@@ -30,6 +30,7 @@
 )
 from transformers.testing_utils import (
     cleanup,
+    require_accelerate,
     require_torch,
     require_torch_accelerator,
     require_torch_sdpa,
@@ -615,6 +616,81 @@ def _test_attention_implementation(self, attn_implementation):
         self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)


+@require_torch
+@require_accelerate
+@slow
+class KyutaiSpeechToTextBf16Test(unittest.TestCase):
+    def test_bf16_fp32_conversion(self):
+        r"""
+        A test to check whether the argument `keep_in_fp32_modules` correctly does its job
+        """
+        model_checkpoint = "kyutai/stt-2.6b-en-trfs"
+        orig_import = __import__
+        accelerate_mock = unittest.mock.Mock()
+
+        # mock import of accelerate
+        def import_accelerate_mock(name, *args, **kwargs):
+            if name == "accelerate":
+                if accelerate_available:
+                    return accelerate_mock
+                else:
+                    raise ImportError
+            return orig_import(name, *args, **kwargs)
+
+        # Load without using `accelerate`
+        with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
+            accelerate_available = False
+
+            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+                model_checkpoint, torch_dtype=torch.float16
+            )
+            self.assertTrue(model.codec_model.dtype == torch.float32)
+            self.assertTrue(model.model.dtype == torch.float16)
+            self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+            # Load without in bf16
+            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+                model_checkpoint, torch_dtype=torch.bfloat16
+            )
+            self.assertTrue(model.codec_model.dtype == torch.float32)
+            self.assertTrue(model.model.dtype == torch.bfloat16)
+            self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.bfloat16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint,
+            torch_dtype=torch.bfloat16,
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.bfloat16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load without using `accelerate`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint,
+            torch_dtype=torch.float16,
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.float16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+        # Load using `accelerate`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint, torch_dtype=torch.float16, device_map="auto"
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.float16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+
 class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
     _dataset = None
