@@ -117,27 +117,12 @@ class ModelMixin(torch.nn.Module):
117117 Base class for all models.
118118
119119 [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
120- and saving models as well as a few methods common to all models to:
120+ and saving models.
121121
122- - resize the input embeddings,
123- - prune heads in the self-attention heads.
122+ Class attributes:
124123
125- Class attributes (overridden by derived classes):
126-
127- - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
128- model architecture.
129- - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
130- taking as arguments:
131-
132- - **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
133- - **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
134- - **path** (`str`) -- A path to the TensorFlow checkpoint.
135-
136- - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
137- classes of the same architecture adding modules on top of the base model.
138- - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
139- - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
140- models, `pixel_values` for vision models and `input_values` for speech models).
124+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
125+ [`~modeling_utils.ModelMixin.save_pretrained`].
141126 """
142127 config_name = CONFIG_NAME
143128 _automatically_saved_args = ["_diffusers_version" , "_class_name" , "_name_or_path" ]
@@ -150,11 +135,10 @@ def save_pretrained(
150135 save_directory : Union [str , os .PathLike ],
151136 is_main_process : bool = True ,
152137 save_function : Callable = torch .save ,
153- ** kwargs ,
154138 ):
155139 """
156140 Save a model and its configuration file to a directory, so that it can be re-loaded using the
157- `[`~ModelMixin.from_pretrained`]` class method.
141+ `[`~modeling_utils. ModelMixin.from_pretrained`]` class method.
158142
159143 Arguments:
160144 save_directory (`str` or `os.PathLike`):
@@ -166,9 +150,6 @@ def save_pretrained(
166150 save_function (`Callable`):
167151 The function to use to save the state dictionary. Useful on distributed training like TPUs when one
168152 need to replace `torch.save` by another method.
169-
170- kwargs:
171- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
172153 """
173154 if os .path .isfile (save_directory ):
174155 logger .error (f"Provided path ({ save_directory } ) should be a directory, not a file" )
@@ -224,34 +205,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
224205 - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
225206 e.g., `./my_model_directory/`.
226207
227- config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
228- Can be either:
229-
230- - an instance of a class derived from [`ConfigMixin`],
231- - a string or path valid as input to [`~ConfigMixin.from_pretrained`].
232-
233- ConfigMixinuration for the model to use instead of an automatically loaded configuration.
234- ConfigMixinuration can be automatically loaded when:
235-
236- - The model is a model provided by the library (loaded with the *model id* string of a pretrained
237- model).
238- - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
239- directory.
240- - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
241- configuration JSON file named *config.json* is found in the directory.
242208 cache_dir (`Union[str, os.PathLike]`, *optional*):
243209 Path to a directory in which a downloaded pretrained model configuration should be cached if the
244210 standard cache should not be used.
245- from_tf (`bool`, *optional*, defaults to `False`):
246- Load the model weights from a TensorFlow checkpoint save file (see docstring of
247- `pretrained_model_name_or_path` argument).
248- from_flax (`bool`, *optional*, defaults to `False`):
249- Load the model weights from a Flax checkpoint save file (see docstring of
250- `pretrained_model_name_or_path` argument).
251- ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
252- Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
253- as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
254- checkpoint with 3 labels).
211+ torch_dtype (`str` or `torch.dtype`, *optional*):
212+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
213+ will be automatically derived from the model's weights.
255214 force_download (`bool`, *optional*, defaults to `False`):
256215 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
257216 cached versions if they exist.
@@ -267,7 +226,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
267226 Whether or not to only look at local files (i.e., do not try to download the model).
268227 use_auth_token (`str` or *bool*, *optional*):
269228 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
270- when running `transformers -cli login` (stored in `~/.huggingface`).
229+ when running `diffusers -cli login` (stored in `~/.huggingface`).
271230 revision (`str`, *optional*, defaults to `"main"`):
272231 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
273232 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -278,18 +237,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
278237 Please refer to the mirror site for more information.
279238
280239 kwargs (remaining dictionary of keyword arguments, *optional*):
281- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
282- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
283- automatically loaded:
284-
285- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
286- underlying model's `__init__` method (we assume all relevant updates to the configuration have
287- already been done)
288- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
289- initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
290- to a configuration attribute will be used to override said attribute with the supplied `kwargs`
291- value. Remaining keys that do not correspond to any configuration attribute will be passed to the
292- underlying model's `__init__` function.
240+ Can be used to update the [`ConfigMixin`] of the model (after it being loaded).
293241
294242 <Tip>
295243
@@ -299,8 +247,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
299247
300248 <Tip>
301249
302- Activate the special ["offline-mode"](https://huggingface.co/transformers /installation.html#offline-mode) to
303- use this method in a firewalled environment.
250+ Activate the special ["offline-mode"](https://huggingface.co/diffusers /installation.html#offline-mode) to use
251+ this method in a firewalled environment.
304252
305253 </Tip>
306254
@@ -404,7 +352,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
404352 f" in the cached files and it looks like { pretrained_model_name_or_path } is not the path to a"
405353 f" directory containing a file named { WEIGHTS_NAME } or"
406354 " \n Checkout your internet connection or see how to run the library in"
407- " offline mode at 'https://huggingface.co/docs/transformers /installation#offline-mode'."
355+ " offline mode at 'https://huggingface.co/docs/diffusers /installation#offline-mode'."
408356 )
409357 except EnvironmentError :
410358 raise EnvironmentError (
0 commit comments