@@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
     _auto_class = None
     _no_split_modules = None
     _skip_keys_device_placement = None
+
     _keep_in_fp32_modules = None
+    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
+    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
+    _keep_in_fp32_modules_strict = None
 
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
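Note, not part of the diff: a minimal sketch of how a model class would opt into the two flags, using made-up `MyConfig`/`MyModel` names and assuming a transformers version that already contains this change. `_keep_in_fp32_modules` protects the listed submodules from fp16 casting but still tolerates bfloat16, while the new `_keep_in_fp32_modules_strict` keeps them in float32 in both cases.

```python
# Illustrative only -- `MyConfig` and `MyModel` are made-up names, not part of this diff.
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel


class MyConfig(PretrainedConfig):
    model_type = "my-toy-model"


class MyModel(PreTrainedModel):
    config_class = MyConfig

    # kept out of fp16 casts, but may still be cast to bfloat16 (pre-existing flag)
    _keep_in_fp32_modules = ["layer_norm"]
    # pinned to float32, never cast to fp16 or bf16 (flag added by this diff)
    _keep_in_fp32_modules_strict = ["router"]

    def __init__(self, config):
        super().__init__(config)
        self.layer_norm = nn.LayerNorm(4)
        self.router = nn.Linear(4, 2)
        # post_init() verifies that "layer_norm" and "router" exist as submodule names
        self.post_init()


model = MyModel(MyConfig())
```

Both attributes list bare submodule names rather than full parameter paths; the `post_init` validation in the hunks below checks those names against the module graph.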
@@ -2049,6 +2053,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
         # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
         # when a different component (e.g. language_model) is used.
         self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
 
         self._no_split_modules = self._no_split_modules or []
 
@@ -2061,7 +2066,7 @@ def post_init(self):
         self._backward_compatibility_gradient_checkpointing()
 
         # Make sure the modules correctly exist if the flag is active
-        if self._keep_in_fp32_modules is not None:
+        if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
             all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
             unique_module_names = set()
             # Get all unique module names in the module graph, without the prefixes
@@ -2070,12 +2075,21 @@ def post_init(self):
                     [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
                 )
             # Check that every module in the keep_in_fp32 list is part of the module graph
-            for module in self._keep_in_fp32_modules:
-                if module not in unique_module_names:
-                    raise ValueError(
-                        f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
-                        f" {self.__class__.__name__}"
-                    )
+            if self._keep_in_fp32_modules is not None:
+                for module in self._keep_in_fp32_modules:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
+
+            if self._keep_in_fp32_modules_strict is not None:
+                for module in self._keep_in_fp32_modules_strict:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
 
         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
         self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
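Note, not part of the diff: the validation above derives the set of valid module names from `named_parameters()`. A standalone sketch of that reduction, with made-up parameter names:

```python
# Module names are recovered by splitting each parameter path on "." and dropping
# numeric indices (list positions) and the "weight"/"bias" leaves.
all_parameters = {
    "encoder.layers.0.layer_norm.weight",
    "encoder.layers.0.layer_norm.bias",
    "router.weight",
}

unique_module_names = set()
for param in all_parameters:
    unique_module_names.update(
        [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
    )

print(unique_module_names)  # {'encoder', 'layers', 'layer_norm', 'router'} (set order may vary)

# A typo such as "layernorm" in either list would not appear in this set,
# so post_init() raises a ValueError instead of silently skipping the module.
```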
@@ -4757,20 +4771,24 @@ def from_pretrained(
         config = model.config
 
         # Find fp32 modules if needed
-        keep_in_fp32_regex = None
+        keep_in_fp32_modules = []
         # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
         # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
         # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
-        # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
         if model._keep_in_fp32_modules is not None and (
-            torch_dtype == torch.float16
-            or torch_dtype == torch.bfloat16
-            or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+            torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+        ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
+
+        if model._keep_in_fp32_modules_strict is not None and (
+            torch_dtype == torch.float16 or torch_dtype == torch.bfloat16
         ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
+
+        keep_in_fp32_regex = None
+        if keep_in_fp32_modules:
             # We need to match exact layers, so we add either `.` on each side, or start/end of string
-            keep_in_fp32_regex = re.compile(
-                "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
-            )
+            keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
 
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
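Note, not part of the diff: a small sketch of how the merged list is turned into `keep_in_fp32_regex`, with made-up module and parameter names. Since both lists feed one regex, a module from either attribute is matched the same way; only the dtype conditions under which each list is collected differ.

```python
import re

# Each module name is anchored between dots or string boundaries, so "router"
# matches "router.weight" but not "encoder.router_proj.weight".
keep_in_fp32_modules = ["layer_norm", "router"]
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

for name in ["encoder.layer_norm.weight", "router.weight", "encoder.router_proj.weight"]:
    print(name, bool(keep_in_fp32_regex.search(name)))
# encoder.layer_norm.weight True
# router.weight True
# encoder.router_proj.weight False
```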