Description
Currently, we have two codepaths:
- For non-sharded checkpoints we do:
  `unexpected_keys = load_model_dict_into_meta(...)`
- For sharded checkpoints we do:
  `accelerate.load_checkpoint_and_dispatch(...)`

And then for (bnb) quantized checkpoints, we merge the sharded checkpoint:
`model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)`

Essentially, we shouldn't have to merge sharded checkpoints even if the model is quantized (see the sketch below).
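For reference, here is a rough sketch (not the actual implementation) of what a unified per-shard path could look like: iterate over the shard files listed in the sharded metadata and route every shard through `load_model_dict_into_meta()` instead of merging them first. The helper name `load_sharded_into_meta`, the import path, and the exact keyword arguments are assumptions on my part:

```python
import os

import safetensors.torch

# Assumed import path for the diffusers internal helper referenced above.
from diffusers.models.model_loading_utils import load_model_dict_into_meta


def load_sharded_into_meta(model, checkpoint_folder, sharded_metadata,
                           torch_dtype=None, hf_quantizer=None,
                           keep_in_fp32_modules=None, device=None):
    """Hypothetical per-shard loading loop: no _merge_sharded_checkpoints step."""
    unexpected_keys = []
    # The shard index's "weight_map" maps parameter names to shard filenames.
    shard_files = sorted(set(sharded_metadata["weight_map"].values()))
    for shard_file in shard_files:
        state_dict = safetensors.torch.load_file(
            os.path.join(checkpoint_folder, shard_file)
        )
        # Assumed call shape; the real keyword arguments may differ.
        unexpected_keys.extend(
            load_model_dict_into_meta(
                model,
                state_dict,
                device=device,
                dtype=torch_dtype,
                hf_quantizer=hf_quantizer,
                keep_in_fp32_modules=keep_in_fp32_modules,
            )
        )
        del state_dict  # release the shard before loading the next one
    return unexpected_keys
```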
This will also allow us to use `keep_in_fp32_modules` more generally for sharded checkpoints. Currently, we have this logic for casting a model (which is thoroughly tested):
`elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:`
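To spell out why that casting branch matters, here is an illustrative sketch (not the actual diffusers code) of the dtype-casting decision: a blanket `.to(torch_dtype)` is only safe when there is no quantizer and no modules need to stay in fp32; otherwise the cast has to be selective. All names below are placeholders:

```python
import torch
import torch.nn as nn


def cast_model(model: nn.Module, torch_dtype, hf_quantizer=None,
               keep_in_fp32_modules=None):
    """Illustrative sketch of the casting decision, with placeholder names."""
    use_keep_in_fp32_modules = bool(keep_in_fp32_modules)
    if hf_quantizer is not None:
        # The quantizer controls parameter dtypes; don't cast here.
        return model
    elif torch_dtype is not None and not use_keep_in_fp32_modules:
        # The simple path tested today: cast the whole model at once.
        return model.to(torch_dtype)
    elif torch_dtype is not None:
        # Selective path: cast everything, then move the keep-in-fp32
        # modules back to float32 (e.g. norm layers for numerical stability).
        model.to(torch_dtype)
        for name, module in model.named_modules():
            if any(key in name for key in keep_in_fp32_modules):
                module.to(torch.float32)
    return model
```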
When using `load_model_dict_into_meta()`, we do consider `keep_in_fp32_modules`:
`keep_in_fp32_modules=None,`
But since we use `load_checkpoint_and_dispatch()` for sharded checkpoints, there is no way to pass `keep_in_fp32_modules`:
https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.load_checkpoint_and_dispatch
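To make the gap concrete, the sharded path today goes roughly like this (placeholder names, and only a few of the documented parameters shown); note there is nothing equivalent to `keep_in_fp32_modules` to pass:

```python
from accelerate import load_checkpoint_and_dispatch


def load_sharded_with_accelerate(model, sharded_ckpt_cached_folder, torch_dtype):
    """Sketch of the current sharded codepath (placeholder names)."""
    # load_checkpoint_and_dispatch() exposes device placement and one global
    # dtype for the whole model, but nothing equivalent to
    # keep_in_fp32_modules, so per-module fp32 exceptions can't be expressed.
    return load_checkpoint_and_dispatch(
        model,                                   # model built on the meta device
        checkpoint=sharded_ckpt_cached_folder,   # folder containing the shards
        device_map="auto",
        dtype=torch_dtype,
        # keep_in_fp32_modules=...  <- no such parameter in this API
    )
```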
As discussed with @SunMarc, it's better to unify this so that we don't have to maintain two different codepaths and can rely entirely on `load_model_dict_into_meta()`. Marc has kindly agreed to open a PR to attempt this (it could be done as a series of PRs if needed), and I will pitch in if any help is needed.