[Single File] Add GGUF support #9964
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|  | @@ -17,8 +17,10 @@ | |||
| from contextlib import nullcontext | ||||
| from typing import Optional | ||||
|  | ||||
| import torch | ||||
| from huggingface_hub.utils import validate_hf_hub_args | ||||
|  | ||||
| from ..quantizers import DiffusersAutoQuantizer | ||||
| from ..utils import deprecate, is_accelerate_available, logging | ||||
| from .single_file_utils import ( | ||||
| SingleFileComponentError, | ||||
|  | @@ -202,6 +204,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = | |||
| subfolder = kwargs.pop("subfolder", None) | ||||
| revision = kwargs.pop("revision", None) | ||||
| torch_dtype = kwargs.pop("torch_dtype", None) | ||||
| quantization_config = kwargs.pop("quantization_config", None) | ||||
|  | ||||
| if isinstance(pretrained_model_link_or_path_or_dict, dict): | ||||
| checkpoint = pretrained_model_link_or_path_or_dict | ||||
|  | @@ -215,6 +218,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = | |||
| local_files_only=local_files_only, | ||||
| revision=revision, | ||||
| ) | ||||
| if quantization_config is not None: | ||||
| hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) | ||||
| hf_quantizer.validate_environment() | ||||
|  | ||||
| else: | ||||
| hf_quantizer = None | ||||
|  | ||||
| mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] | ||||
|  | ||||
|  | @@ -295,8 +304,34 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = | |||
| with ctx(): | ||||
| model = cls.from_config(diffusers_model_config) | ||||
|  | ||||
| # Check if `_keep_in_fp32_modules` is not None | ||||
| use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( | ||||
| (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") | ||||
| ) | ||||
| if use_keep_in_fp32_modules: | ||||
| keep_in_fp32_modules = cls._keep_in_fp32_modules | ||||
| if not isinstance(keep_in_fp32_modules, list): | ||||
| keep_in_fp32_modules = [keep_in_fp32_modules] | ||||
|  | ||||
| else: | ||||
| keep_in_fp32_modules = [] | ||||
|  | ||||
| if hf_quantizer is not None: | ||||
| hf_quantizer.preprocess_model( | ||||
| model=model, | ||||
| device_map=None, | ||||
| state_dict=diffusers_format_checkpoint, | ||||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||||
| ) | ||||
|  | ||||
| if is_accelerate_available(): | ||||
| unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) | ||||
| unexpected_keys = load_model_dict_into_meta( | ||||
| model, | ||||
| diffusers_format_checkpoint, | ||||
| dtype=torch_dtype, | ||||
| hf_quantizer=hf_quantizer, | ||||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||||
| ) | ||||
|  | ||||
| else: | ||||
| _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) | ||||
|  | @@ -310,6 +345,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = | |||
| f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" | ||||
| ) | ||||
|  | ||||
| if hf_quantizer is not None: | ||||
| hf_quantizer.postprocess_model(model) | ||||
|  | ||||
| if torch_dtype is not None: | ||||
| model.to(torch_dtype) | ||||
|          | ||||
| # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will | 
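For orientation, a minimal usage sketch of the new `quantization_config` path. The `GGUFQuantizationConfig` class and the checkpoint filename come from the discussion further down this page and are assumptions rather than part of this diff:

```python
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Any single-file GGUF checkpoint of a supported model; the filename here is a placeholder.
ckpt_path = "flux1-dev-Q4_K_S.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```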
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|  | @@ -17,26 +17,30 @@ | |||
| import importlib | ||||
| import inspect | ||||
| import os | ||||
| from array import array | ||||
| from collections import OrderedDict | ||||
| from pathlib import Path | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
| import safetensors | ||||
| import torch | ||||
| from huggingface_hub.utils import EntryNotFoundError | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from ..quantizers.quantization_config import QuantizationMethod | ||||
| from ..utils import ( | ||||
| GGUF_FILE_EXTENSION, | ||||
| SAFE_WEIGHTS_INDEX_NAME, | ||||
| SAFETENSORS_FILE_EXTENSION, | ||||
| WEIGHTS_INDEX_NAME, | ||||
| _add_variant, | ||||
| _get_model_file, | ||||
| deprecate, | ||||
| is_accelerate_available, | ||||
| is_torch_available, | ||||
| is_torch_version, | ||||
| logging, | ||||
| ) | ||||
| from ..utils.import_utils import is_gguf_available | ||||
|          | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | @@ -140,6 +144,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ | |||
| file_extension = os.path.basename(checkpoint_file).split(".")[-1] | ||||
| if file_extension == SAFETENSORS_FILE_EXTENSION: | ||||
| return safetensors.torch.load_file(checkpoint_file, device="cpu") | ||||
| elif file_extension == GGUF_FILE_EXTENSION: | ||||
| return load_gguf_checkpoint(checkpoint_file) | ||||
| else: | ||||
| weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} | ||||
| return torch.load( | ||||
|  | @@ -176,11 +182,9 @@ def load_model_dict_into_meta( | |||
| hf_quantizer=None, | ||||
| keep_in_fp32_modules=None, | ||||
| ) -> List[str]: | ||||
| if hf_quantizer is None: | ||||
| device = device or torch.device("cpu") | ||||
| device = device or torch.device("cpu") | ||||
|          | ||||
| param_device = torch.cuda.current_device() | 
@a-r-r-o-w has an open PR for this #10069
Same for these two. Additionally, read_field() sounds a bit ambiguous -- could do with a better name?
Do we need to check the gguf version as well (in addition to is_gguf_available)?
Agree. Let's always suggest installing the latest stable build of gguf like we do for bitsandbytes.
| if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): | 
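A sketch of the analogous guard for gguf; the `is_gguf_version` helper and the `0.10.0` floor are assumptions mirroring the bitsandbytes check quoted above:

```python
# is_gguf_available is added by this PR; is_gguf_version is assumed to mirror is_bitsandbytes_version.
from diffusers.utils.import_utils import is_gguf_available, is_gguf_version

if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
    raise ImportError(
        "Loading GGUF checkpoints requires a recent `gguf` release. Please upgrade with `pip install -U gguf`."
    )
```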
        
          
              
Do we need a tqdm here? Not typical, no?
@DN6 I don't think it's typical of us to use a tqdm here, no?
We could create a NON_TORCH_GGUF_DTYPE enum or set with these two values (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) and use NON_TORCH_GGUF_DTYPE here instead.
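A sketch of that suggestion (the set name is the reviewer's proposal; nothing here is defined by the PR):

```python
import gguf

# The two GGUF tensor types that are stored unquantized (float32 / float16); the loader
# can turn these into plain torch tensors instead of packed, quantized parameters.
NON_TORCH_GGUF_DTYPE = {
    gguf.GGMLQuantizationType.F32,
    gguf.GGMLQuantizationType.F16,
}

# Usage: `if tensor.tensor_type in NON_TORCH_GGUF_DTYPE: ...`
```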
        
          
              
Trying to understand this check here: if we removed names from reader_keys as we iterated through the tensors, the check would make sense, but I didn't see any code that removes anything. So maybe we forgot to remove them? Or is that not the case? Did I miss something?
Copied this from transformers. But these aren't tensor keys. They're metadata keys
['GGUF.version', 'GGUF.tensor_count', 'GGUF.kv_count', 'general.architecture', 'general.quantization_version', 'general.file_type']
This can probably just be removed since the info isn't too relevant.
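For context on what these keys are, a rough sketch of walking a GGUF file with the `gguf` package; attribute names follow `gguf.GGUFReader`, and the state-dict handling below is a simplification rather than the PR's actual loader:

```python
import gguf
import numpy as np
import torch

reader = gguf.GGUFReader("model.gguf")  # any GGUF checkpoint path

# Metadata such as 'GGUF.version', 'general.architecture', 'general.file_type' lives in
# reader.fields, separate from the tensor entries in reader.tensors.
print(list(reader.fields.keys()))

state_dict, quant_types = {}, {}
for tensor in reader.tensors:
    # tensor.data is a numpy view over the (possibly packed) on-disk data.
    state_dict[tensor.name] = torch.from_numpy(np.array(tensor.data))
    if tensor.tensor_type not in (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16):
        # Quantized blocks stay packed; remember the quant type so the weights can be
        # dequantized on the fly at compute time.
        quant_types[tensor.name] = tensor.tensor_type
```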
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -204,7 +204,10 @@ def create_quantized_param( | |
|  | ||
| module._parameters[tensor_name] = new_value | ||
|  | ||
| def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape): | ||
| def check_quantized_param_shape(self, param_name, current_param, loaded_param): | ||
GGUF needs to access the tensor quant type to run a shape check. So this needs to change from passing in shapes to passing in params directly.
Why not add this method to the GGUF quantizer class instead?
I see you're already adding this to the GGUF quantizer class. So, maybe okay to not modify this?
It definitely makes sense to keep this method's signature the same across all quantizers; it would be confusing otherwise.
I think no deprecation is fine since this method is only called internally.
| current_param_shape = current_param.shape | ||
| loaded_param_shape = loaded_param.shape | ||
|  | ||
| n = current_param_shape.numel() | ||
| inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1) | ||
| if loaded_param_shape != inferred_shape: | ||
|  | ||
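As a worked instance of the inferred-shape rule above (pure arithmetic, no bitsandbytes needed): bnb 4-bit packs two values per byte, so an `n`-element weight is stored as an `((n + 1) // 2, 1)` tensor.

```python
# Hypothetical (3072, 3072) linear weight:
n = 3072 * 3072                   # 9_437_184 elements
packed_shape = ((n + 1) // 2, 1)  # two 4-bit values per byte
assert packed_shape == (4_718_592, 1)

# A "bias" parameter is left unpacked, so its inferred shape is simply (n,).
```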
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .gguf_quantizer import GGUFQuantizer | 
For GGUF files, I'm wondering if it would be nice to allow the user to load the model without necessarily having to specify `quantization_config=GGUFQuantizationConfig(compute_dtype=xxx)`. If we detect that this is a GGUF file, we can default to `quantization_config = GGUFQuantizationConfig(compute_dtype=torch.float32)`. I'm suggesting this because usually, when you pass a `quantization_config`, it means either that the model is not quantized (bnb), or that the model is quantized (there is a quantization_config in the config.json) but we want to change a few arguments. Also, what happens when the user passes a GGUF file without specifying the `quantization_config`?
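A minimal sketch of that suggestion; the helper name is hypothetical and the GGUF detection by file extension is an assumption, not something the PR implements:

```python
import torch
from diffusers import GGUFQuantizationConfig  # config class introduced in this PR


def default_quantization_config(pretrained_model_link_or_path, quantization_config=None):
    """Return the user's config, or a float32 GGUF default when a .gguf file is detected."""
    if quantization_config is None and str(pretrained_model_link_or_path).endswith(".gguf"):
        quantization_config = GGUFQuantizationConfig(compute_dtype=torch.float32)
    return quantization_config
```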
Yeah, this is a good point! I think for most users the entry point for GGUF files is going to be through `from_single_file()`, and I agree with the logic you mentioned.
I agree that this is a nice convenience. GGUF does have all the information we need to auto-fetch the config (honestly, it's possible to skip the config altogether), but it would mean that loading semantics would be different for GGUF vs. other quant types, e.g. GGUF vs. BnB and TorchAO (assuming these can be supported); see the sketch below.
GGUF can also be used through `from_pretrained` (assuming quants of diffusers-format checkpoints show up at some point), and we would have to pass a quant config in that case. I understand it's not ideal, but I feel it's better to preserve consistency across the different quant loading methods.
@SunMarc if the config isn't passed, you get shape mismatch errors when you hit `load_model_dict_into_meta`, since the quant shapes are different from the expected shapes.
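A sketch of the contrast being described; the model class, paths, and the bnb config are illustrative stand-ins, not code from the PR:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# GGUF: the checkpoint carries its own quant metadata, so with auto-detection the config
# could be optional (this is the convenience being discussed, not current behaviour).
gguf_model = FluxTransformer2DModel.from_single_file("flux1-dev-Q4_K_S.gguf")  # placeholder path

# BnB and TorchAO (assuming these can be supported in from_single_file): the checkpoint is
# unquantized, so an explicit config would always be required to quantize on load.
bnb_model = FluxTransformer2DModel.from_single_file(
    "flux1-dev.safetensors",  # placeholder path
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
)
```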
Yeah, I thought about that too, but I think the API for `from_single_file` and `from_pretrained` might just have to be different. It is a bit confusing, but I'm not sure there is a way to make it consistent between `from_single_file` and `from_pretrained` if we also want the same API to be consistent across different quant types.
GGUF is a special case here because it has a built-in config. Normally a single-file checkpoint is just a checkpoint without a config, so you will always have to pass a config (at least I think so, is that right? @DN6). So for loading a regular quantized model (e.g. BnB) we can load it with `from_pretrained` without passing a config, but with `from_single_file` we will have to manually pass a config.
So I agree with @DN6 here: I think it's more important to make the same API (`from_pretrained` or `from_single_file`) consistent across different quant types, if we have to choose one. But if there is a way to make it consistent between `from_pretrained` and `from_single_file` and across all quant types, that would be great!
Also, I want to know: do we plan to support quantizing a model in `from_single_file`? @DN6
Would it make sense to at least make the user aware when the passed config and the determined config mismatch, and whether that could lead to unintended consequences?
Supporting quantizing in the GGUF format (regardless of `from_pretrained()` or `from_single_file()`) would be really nice.
@yiyixuxu Yeah, we can definitely support quantizing a model via single file. For GGUF I can look into it in a follow-up, because we would have to port the quantize functions to torch (the gguf library uses numpy). We could use the gguf library internally to quantize, but it's quite slow since we would have to move tensors off the GPU, convert to numpy, and then quantize.
With torchao I'm pretty sure it would work just out of the box.
You would have to save it with `save_pretrained`, though, since we don't support serializing single-file checkpoints.
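A sketch of the workflow described here, assuming quantize-on-load via `from_single_file` with a bnb config were supported (the thread treats that as future work; the model class and paths are placeholders):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Quantize while loading a single-file checkpoint (hypothetical support)...
model = FluxTransformer2DModel.from_single_file(
    "flux1-dev.safetensors",  # placeholder path
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
)

# ...then serialize in the diffusers folder format, since single-file serialization isn't supported.
model.save_pretrained("flux1-dev-4bit")
```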
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, what I am hearing is that saving a GGUF-quantized model would be added in a follow-up PR? That is also okay, but it could be quite an enabling factor for the community. I think the porting option is preferable.
You mean serializing with torchao but with quantization configs similar to the ones provided in GGUF?