Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
}

Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
}

Expand Down Expand Up @@ -579,6 +581,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
# else we just import it from the library.
library = importlib.import_module(library_name)

class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
Expand Down
44 changes: 44 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,50 @@ def test_download_only_pytorch(self):
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert not any(f.endswith(".msgpack") for f in files)

def test_download_no_safety_checker(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
generator = torch.Generator(device=torch_device).manual_seed(0)
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
generator_2 = torch.Generator(device=torch_device).manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

assert np.max(np.abs(out - out_2)) < 1e-3

def test_load_no_safety_checker_explicit_locally(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
generator = torch.Generator(device=torch_device).manual_seed(0)
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
generator_2 = torch.Generator(device=torch_device).manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

assert np.max(np.abs(out - out_2)) < 1e-3

def test_load_no_safety_checker_default_locally(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
generator = torch.Generator(device=torch_device).manual_seed(0)
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
generator_2 = torch.Generator(device=torch_device).manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

assert np.max(np.abs(out - out_2)) < 1e-3


class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
Expand Down