2 changes: 0 additions & 2 deletions src/diffusers/onnx_utils.py
@@ -38,8 +38,6 @@
 
 
 class OnnxRuntimeModel:
-    base_model_prefix = "onnx_model"
-
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
21 changes: 20 additions & 1 deletion src/diffusers/pipeline_utils.py
@@ -30,7 +30,10 @@
 from tqdm.auto import tqdm
 
 from .configuration_utils import ConfigMixin
-from .utils import DIFFUSERS_CACHE, BaseOutput, logging
+from .modeling_utils import WEIGHTS_NAME
+from .onnx_utils import ONNX_WEIGHTS_NAME
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -285,6 +288,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from `from_pretrained`
         if not os.path.isdir(pretrained_model_name_or_path):
+            config_dict = cls.get_config_dict(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            # make sure we only download sub-folders and `diffusers` filenames
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+
+            # download all files matching `allow_patterns`
             cached_folder = snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
@@ -293,6 +311,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 revision=revision,
+                allow_patterns=allow_patterns,
             )
         else:
             cached_folder = pretrained_model_name_or_path
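For context on the change above: `from_pretrained` now reads the pipeline config (`model_index.json`) via `get_config_dict`, turns every non-meta key into a sub-folder glob, and appends the known top-level `diffusers` filenames; the resulting list is passed to `huggingface_hub.snapshot_download` as `allow_patterns`, so only matching files are fetched. A minimal sketch of the pattern derivation, using a hypothetical dummy config (the literal filenames in the comments are what the imported constants resolve to in `diffusers` at this point):

```python
# Sketch, not part of the PR: deriving `allow_patterns` from a pipeline
# config. The config dict below is a hypothetical stand-in for a repo's
# model_index.json.
import os

config_dict = {
    "_class_name": "SomePipeline",   # meta entries start with "_"
    "_diffusers_version": "0.4.0",
    "unet": ["diffusers", "UNet2DModel"],
    "scheduler": ["diffusers", "PNDMScheduler"],
}

# every non-meta key names a component sub-folder: allow its whole folder
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]

# plus the known top-level `diffusers` filenames
allow_patterns += [
    "diffusion_pytorch_model.bin",  # WEIGHTS_NAME
    "scheduler_config.json",        # SCHEDULER_CONFIG_NAME
    "config.json",                  # CONFIG_NAME
    "model.onnx",                   # ONNX_WEIGHTS_NAME
    "model_index.json",             # cls.config_name for DiffusionPipeline
]

print(allow_patterns)
# ['unet/*', 'scheduler/*', 'diffusion_pytorch_model.bin', ...]
```

Anything in the repo that matches none of these globs, such as the `big_array.npy` file checked in the test below, is simply never downloaded.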
25 changes: 25 additions & 0 deletions tests/test_pipelines.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import os
 import random
 import tempfile
 import unittest
@@ -45,8 +46,11 @@
     UNet2DModel,
     VQModel,
 )
+from diffusers.modeling_utils import WEIGHTS_NAME
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils import CONFIG_NAME
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -707,6 +711,27 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_smart_download(self):
+        model_id = "hf-internal-testing/unet-pipeline-dummy"
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
+            local_repo_name = "--".join(["models"] + model_id.split("/"))
+            snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
+            snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
+
+            # inspect all downloaded files to make sure that everything is included
+            assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
+            assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
+            # make sure the super large numpy file
+            # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
+            # is not downloaded, while all the expected files are
+            assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
+
Contributor (review comment on lines +715 to +733): great test


     @property
     def dummy_safety_checker(self):
         def check(images, *args, **kwargs):
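One note on the path arithmetic in `test_smart_download`: it follows the `huggingface_hub` cache convention, where a repo is cached under `models--<org>--<name>/snapshots/<revision>/`. A standalone sketch of that navigation (the cache directory here is hypothetical; the test uses a `TemporaryDirectory`):

```python
# Sketch, not part of the PR: walking the huggingface_hub cache layout
# the test relies on. Convention:
#   <cache_dir>/models--<org>--<name>/snapshots/<revision>/<repo files>
import os

cache_dir = "/tmp/hf-cache"  # hypothetical stand-in for tmpdirname
model_id = "hf-internal-testing/unet-pipeline-dummy"

# "hf-internal-testing/unet-pipeline-dummy"
#   -> "models--hf-internal-testing--unet-pipeline-dummy"
local_repo_name = "--".join(["models"] + model_id.split("/"))

snapshots_root = os.path.join(cache_dir, local_repo_name, "snapshots")
# one sub-directory per cached revision; after a single fresh download
# there is exactly one, so the test takes the first entry
if os.path.isdir(snapshots_root):
    snapshot_dir = os.path.join(snapshots_root, os.listdir(snapshots_root)[0])
```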