 )
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
     CONFIG_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -2177,11 +2179,20 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
+        if (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(best_adapter_model_path)
+            or os.path.exists(best_safe_adapter_model_path)
+        ):
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
             else:
+                has_been_loaded = True
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                         # If the 'user_content.pt' file exists, load with the new smp api.
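Side note on the hunk above: the existence check now also accepts adapter-only checkpoints. A minimal standalone sketch of that logic, assuming the usual values of the filename constants in transformers.utils (the helper name is purely illustrative):

import os

# Assumed constant values, mirroring transformers.utils at the time of this change.
WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_NAME = "model.safetensors"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"

def checkpoint_has_loadable_weights(checkpoint_dir):
    # True if the checkpoint contains either full model weights or adapter-only
    # weights, in torch or safetensors format.
    candidates = (WEIGHTS_NAME, SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME)
    return any(os.path.exists(os.path.join(checkpoint_dir, name)) for name in candidates)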
@@ -2207,10 +2218,10 @@ def _load_best_model(self):
                         self.accelerator, model, self.state.best_model_checkpoint
                     )
                 else:
-                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
-                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
+                    if is_peft_available() and isinstance(model, PeftModel):
+                        # If the model was trained with PEFT & LoRA, assume the adapter has been saved properly.
                         if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
+                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                                 model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                 # Load_adapter has no return value present, modify it when appropriate.
                                 from torch.nn.modules.module import _IncompatibleKeys
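For readers unfamiliar with the PEFT side of the hunk above, here is a hedged, standalone sketch of what `load_adapter` does when the best checkpoint contains only adapter weights; the model name and checkpoint paths are placeholders, and `peft>=0.3.0` is assumed:

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholder identifiers; any causal LM with a saved LoRA adapter would do.
base = AutoModelForCausalLM.from_pretrained("gpt2")
model = PeftModel.from_pretrained(base, "output/checkpoint-500")

# Mirrors what _load_best_model does for PEFT models: overwrite the active
# adapter's weights with those stored in the best checkpoint directory.
best_model_checkpoint = "output/checkpoint-1000"
model.load_adapter(best_model_checkpoint, model.active_adapter)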
@@ -2219,12 +2230,13 @@ def _load_best_model(self):
                             else:
                                 logger.warning(
                                     "The intermediate checkpoints of PEFT may not be saved correctly, "
-                                    "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
+                                    f"consider using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                                     "here are some examples https://github.com/huggingface/peft/issues/96"
                                 )
+                                has_been_loaded = False
                         else:
-                            # We can't do pure 8bit training using transformers.
-                            logger.warning("Could not loading a quantized checkpoint.")
+                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                            has_been_loaded = False
                     else:
                         # We load the model state dict on the CPU to avoid an OOM error.
                         if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
@@ -2236,7 +2248,7 @@ def _load_best_model(self):
                         # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                         # which takes *args instead of **kwargs
                         load_result = model.load_state_dict(state_dict, False)
-                    if not is_sagemaker_mp_enabled():
+                    if not is_sagemaker_mp_enabled() and has_been_loaded:
                         self._issue_warnings_after_load(load_result)
             elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
                 load_result = load_sharded_checkpoint(
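Taken together, these hunks mean that `load_best_model_at_end=True` should now also restore best checkpoints that contain only adapter weights, and that the post-load key warnings are skipped (`has_been_loaded = False`) when no adapter could be loaded. A hedged usage sketch with placeholder model and output names, and no datasets wired up:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Placeholder base model; LoRA target modules come from PEFT's default mapping for gpt2.
base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,  # best checkpoint may now contain only adapter_model.bin
)
# Datasets omitted for brevity; trainer.train() would need train/eval datasets.
trainer = Trainer(model=peft_model, args=args)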