Skip to content

Commit 8930013

Browse files
authored
Fix Flax from_pt (#1436)
Fix Flax `from_pt`. It worked for models but not for pipelines. Accidentally broken in #1107.
1 parent 6c56f05 commit 8930013

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def from_pretrained(
332332
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
333333
raise EnvironmentError(
334334
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
335-
" using `from_pt=True`."
335+
" using `from_pt=True`."
336336
)
337337
else:
338338
raise EnvironmentError(

src/diffusers/pipeline_flax_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
317317
allow_patterns = [os.path.join(k, "*") for k in folder_names]
318318
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
319319

320-
# make sure we don't download PyTorch weights
321-
ignore_patterns = "*.bin"
320+
# make sure we don't download PyTorch weights, unless when using from_pt
321+
ignore_patterns = "*.bin" if not from_pt else []
322322

323323
if cls != FlaxDiffusionPipeline:
324324
requested_pipeline_class = cls.__name__

0 commit comments

Comments
 (0)