From 2242aa8ec9c2b59db79c19aaf6b7cd59b332fea5 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Thu, 8 Jun 2023 07:23:37 +0000
Subject: [PATCH 1/4] v1

---
 src/transformers/trainer.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3231558dec18..e36161f81545 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2177,11 +2177,18 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
+
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
+        if (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(adapter_model_path)
+        ):
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
             else:
+                has_been_loaded = True
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                         # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2207,10 +2214,10 @@ def _load_best_model(self):
                         self.accelerator, model, self.state.best_model_checkpoint
                     )
                 else:
-                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
-                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
+                    if is_peft_available() and isinstance(model, PeftModel):
+                        # If the model was trained using PEFT & LoRA, assume the adapter has been saved properly.
                         if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
+                            if os.path.exists(adapter_model_path):
                                 model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                 # Load_adapter has no return value present, modify it when appropriate.
                                 from torch.nn.modules.module import _IncompatibleKeys
@@ -2222,9 +2229,11 @@ def _load_best_model(self):
                                     "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
                                     "here are some examples https://github.com/huggingface/peft/issues/96"
                                 )
+                                has_been_loaded = False
                         else:
                             # We can't do pure 8bit training using transformers.
                             logger.warning("Could not loading a quantized checkpoint.")
+                            has_been_loaded = False
                     else:
                         # We load the model state dict on the CPU to avoid an OOM error.
                         if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
@@ -2236,7 +2245,7 @@ def _load_best_model(self):
                         # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                         # which takes *args instead of **kwargs
                         load_result = model.load_state_dict(state_dict, False)
-                if not is_sagemaker_mp_enabled():
+                if not is_sagemaker_mp_enabled() and has_been_loaded:
                     self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
             load_result = load_sharded_checkpoint(

From 44be183cfad8977abf128a12d240eedaa98c20b2 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Thu, 8 Jun 2023 09:33:05 +0000
Subject: [PATCH 2/4] some refactor - add ST format as well

---
 src/transformers/trainer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e36161f81545..c335aafcfeb3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2177,13 +2177,15 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
-        adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
+        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
+        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.safetensors")
 
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if (
             os.path.exists(best_model_path)
             or os.path.exists(best_safe_model_path)
-            or os.path.exists(adapter_model_path)
+            or os.path.exists(best_adapter_model_path)
+            or os.path.exists(best_safe_adapter_model_path)
         ):
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
@@ -2217,7 +2219,7 @@ def _load_best_model(self):
                     if is_peft_available() and isinstance(model, PeftModel):
                         # If the model was trained using PEFT & LoRA, assume the adapter has been saved properly.
                         if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(adapter_model_path):
+                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                                 model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                 # Load_adapter has no return value present, modify it when appropriate.
                                 from torch.nn.modules.module import _IncompatibleKeys

From bf31c5e5f192e34b93eb1817e03fafdce204cef6 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Thu, 8 Jun 2023 10:10:27 +0000
Subject: [PATCH 3/4] fix

---
 src/transformers/trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c335aafcfeb3..4abed6a2222c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2233,8 +2233,7 @@ def _load_best_model(self):
                                 )
                                 has_been_loaded = False
                         else:
-                            # We can't do pure 8bit training using transformers.
-                            logger.warning("Could not loading a quantized checkpoint.")
+                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
                             has_been_loaded = False
                     else:
                         # We load the model state dict on the CPU to avoid an OOM error.
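Taken together, patches 1-3 teach `_load_best_model` to accept a best checkpoint that contains only a PEFT adapter ("adapter_model.bin" or "adapter_model.safetensors") and to skip `_issue_warnings_after_load` whenever nothing was actually loaded. The sketch below mirrors that control flow outside of `Trainer`; `load_best_adapter` is an illustrative helper rather than Trainer API, and the filenames are hardcoded because the shared constants only arrive in the next patch:

import os

# Hardcoded here; [PATCH 4/4] promotes these to transformers.utils constants.
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"


def load_best_adapter(model, checkpoint_dir: str) -> bool:
    """Mirror of the PEFT branch in `_load_best_model`: the boolean return
    plays the role of the `has_been_loaded` flag added in [PATCH 1/4]."""
    adapter_exists = any(
        os.path.exists(os.path.join(checkpoint_dir, name))
        for name in (ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME)
    )
    # `active_adapter` and `load_adapter` only exist on recent peft releases
    # (>=0.3.0), hence the defensive hasattr checks mirrored from the patch.
    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
        if adapter_exists:
            model.load_adapter(checkpoint_dir, model.active_adapter)
            return True
        # Adapter files missing: the intermediate checkpoints may not have
        # been saved correctly (https://github.com/huggingface/peft/issues/96).
        return False
    # Model exposes no adapter API, e.g. peft < 0.3.0.
    return False

The flag matters because `load_adapter` returns None: in the failure branches there is no meaningful `load_result` to validate, so [PATCH 1/4] sets `has_been_loaded = False` and gates `self._issue_warnings_after_load(load_result)` on it.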
From f01a06d37fb8d83c7b8bcc118e0f5a0d36ff6435 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Thu, 8 Jun 2023 12:47:18 +0000
Subject: [PATCH 4/4] add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`

---
 src/transformers/trainer.py        | 8 +++++---
 src/transformers/utils/__init__.py | 2 ++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4abed6a2222c..d93e6b587de0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -134,6 +134,8 @@
 )
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
     CONFIG_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -2177,8 +2179,8 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
-        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
-        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.safetensors")
+        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
 
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if (
@@ -2228,7 +2230,7 @@ def _load_best_model(self):
                             else:
                                 logger.warning(
                                     "The intermediate checkpoints of PEFT may not be saved correctly, "
-                                    "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
+                                    f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                                     "here are some examples https://github.com/huggingface/peft/issues/96"
                                 )
                                 has_been_loaded = False
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 7169c7daf969..3aa1f8aeb926 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -177,6 +177,8 @@
 
 WEIGHTS_NAME = "pytorch_model.bin"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
+ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
 TF2_WEIGHTS_NAME = "tf_model.h5"
 TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
 TF_WEIGHTS_NAME = "model.ckpt"
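With [PATCH 4/4] applied, the adapter filenames are no longer magic strings: code that inspects checkpoint folders can import them from `transformers.utils`. A minimal usage sketch (`checkpoint_has_adapter` is an illustrative helper and the checkpoint path a placeholder, neither being part of the PR):

import os

from transformers.utils import ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME


def checkpoint_has_adapter(checkpoint_dir: str) -> bool:
    """True when a checkpoint folder stores PEFT adapter weights in either
    the torch (.bin) or the safetensors format."""
    return any(
        os.path.exists(os.path.join(checkpoint_dir, name))
        for name in (ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME)
    )


if __name__ == "__main__":
    # Placeholder path: any folder produced by `Trainer` checkpointing.
    print(checkpoint_has_adapter("output/checkpoint-500"))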