[Trainer] Correct behavior of _load_best_model for PEFT models
#24103
Changes from 1 commit
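For context, the branch touched below is reached when `load_best_model_at_end=True` and the trained model is a PEFT model, so the best checkpoint folder contains only `adapter_model.bin` rather than a full `pytorch_model.bin`. A minimal sketch of such a run (the base model, dataset, and hyperparameters are illustrative assumptions, not taken from this PR):

```python
# Illustrative setup that exercises Trainer._load_best_model with a PEFT model.
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Wrapping with PEFT means checkpoints hold adapter_model.bin, not full weights.
model = get_peft_model(base, LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16))

dataset = load_dataset("glue", "sst2")
dataset = dataset.map(
    lambda batch: tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128),
    batched=True,
)

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=50,
    max_steps=200,
    load_best_model_at_end=True,  # triggers _load_best_model at the end of training
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.train()  # before this fix, reloading the adapter-only best checkpoint could misbehave
```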

    @@ -2177,11 +2177,18 @@ def _load_best_model(self):
             logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
             best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
    +        adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")

             model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    -        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
    +        if (
    +            os.path.exists(best_model_path)
    +            or os.path.exists(best_safe_model_path)
    +            or os.path.exists(adapter_model_path)
    +        ):
                 if self.is_deepspeed_enabled:
                     deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
                 else:
    +                has_been_loaded = True
                     if is_sagemaker_mp_enabled():
                         if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                             # If the 'user_content.pt' file exists, load with the new smp api.


    @@ -2207,10 +2214,10 @@ def _load_best_model(self):
                             self.accelerator, model, self.state.best_model_checkpoint
                         )
                     else:
    -                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
    -                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
    +                    if is_peft_available() and isinstance(model, PeftModel):
    +                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
    -                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
    +                            if os.path.exists(adapter_model_path):
                                     model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                     # Load_adapter has no return value present, modify it when appropriate.
                                     from torch.nn.modules.module import _IncompatibleKeys

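The `# Load_adapter has no return value present` comment refers to `PeftModel.load_adapter` returning `None`, so there is no result object to hand to the later warning logic. Presumably the lines cut off at the end of this hunk synthesize an empty one, along these lines (a sketch, not the verbatim continuation of the diff):

```python
# Sketch (assumption, not the verbatim continuation of the hunk): fabricate an
# "everything matched" result so the later warning step has something to inspect.
from torch.nn.modules.module import _IncompatibleKeys

load_result = _IncompatibleKeys(missing_keys=[], unexpected_keys=[])
```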

    @@ -2222,9 +2229,11 @@ def _load_best_model(self):
                                         "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
                                         "here are some examples https://github.com/huggingface/peft/issues/96"
                                     )
    +                                has_been_loaded = False
                             else:
                                 # We can't do pure 8bit training using transformers.
                                 logger.warning("Could not loading a quantized checkpoint.")
    +                            has_been_loaded = False

**Contributor:** Should this be removed now?

**Contributor (author):** I think this is needed so that it can be used in the check in the block below; otherwise it will throw an error similar to the one in #24096.

**Contributor (author):** Ah sorry, I see what you meant. Yes, I will remove it.

**Contributor (author):** Proposed something in bf31c5e.

                         else:
                             # We load the model state dict on the CPU to avoid an OOM error.
                             if self.args.save_safetensors and os.path.isfile(best_safe_model_path):


    @@ -2236,7 +2245,7 @@ def _load_best_model(self):
                             # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                             # which takes *args instead of **kwargs
                             load_result = model.load_state_dict(state_dict, False)
    -                if not is_sagemaker_mp_enabled():
    +                if not is_sagemaker_mp_enabled() and has_been_loaded:
                         self._issue_warnings_after_load(load_result)
             elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
                 load_result = load_sharded_checkpoint(

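The `has_been_loaded` flag introduced above exists so that the final warning step only runs when something was actually loaded; in the branches that set it to `False`, `load_result` is never assigned, and referencing it would raise an `UnboundLocalError`. A self-contained toy illustrating the guard (not the Trainer code itself):

```python
# Toy illustration of the has_been_loaded guard (not the actual Trainer code).
def load_best(checkpoint_has_adapter: bool, sagemaker: bool = False) -> None:
    has_been_loaded = True
    if checkpoint_has_adapter:
        load_result = "ok"          # stands in for load_adapter / load_state_dict
    else:
        has_been_loaded = False     # nothing was loaded, so load_result does not exist

    if not sagemaker and has_been_loaded:
        print("warnings after load:", load_result)


load_best(checkpoint_has_adapter=True)   # runs the warning step
load_best(checkpoint_has_adapter=False)  # safely skips it
```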
**Review comment:** It can also be safetensors checkpoints, right?

**Review comment:** Maybe adding `best_safe_adapter_model_path` should serve the purpose?

**Review comment:** Perfect, will refactor that a bit.
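A possible shape for that refactor, sketched under the assumption that the safetensors adapter file is named `adapter_model.safetensors` (the variable names and filenames here are illustrative, not necessarily what was merged):

```python
import os

# Sketch of the suggested refactor: accept either a pickle or a safetensors
# adapter file when deciding whether the best checkpoint can be reloaded.
# best_model_checkpoint stands in for self.state.best_model_checkpoint.
best_model_checkpoint = "out/checkpoint-50"

best_model_path = os.path.join(best_model_checkpoint, "pytorch_model.bin")
best_safe_model_path = os.path.join(best_model_checkpoint, "model.safetensors")
adapter_model_path = os.path.join(best_model_checkpoint, "adapter_model.bin")
best_safe_adapter_model_path = os.path.join(best_model_checkpoint, "adapter_model.safetensors")

# Any one of the four files is enough to attempt reloading the best checkpoint.
checkpoint_is_loadable = any(
    os.path.exists(path)
    for path in (
        best_model_path,
        best_safe_model_path,
        adapter_model_path,
        best_safe_adapter_model_path,
    )
)
print(checkpoint_is_loadable)
```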