@@ -336,7 +336,7 @@ def __getitem__(self, i):

         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
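This hunk (repeated in the two textual-inversion scripts below) is purely black's new tuple-unpacking style; both spellings behave identically. A minimal sketch confirming the equivalence (`img`, `h2`, `w2` are illustrative names, not from the diff):

```python
import numpy as np

img = np.zeros((480, 640, 3))  # hypothetical image array, for illustration only

# Old black style: trailing comma in a bare target list.
h, w, = (
    img.shape[0],
    img.shape[1],
)

# New black style: the same targets, parenthesized. Unpacking is identical.
(h2, w2,) = (
    img.shape[0],
    img.shape[1],
)

assert (h, w) == (h2, w2) == (480, 640)
```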
2 changes: 1 addition & 1 deletion examples/textual_inversion/textual_inversion.py
@@ -381,7 +381,7 @@ def __getitem__(self, i):

         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
2 changes: 1 addition & 1 deletion examples/textual_inversion/textual_inversion_flax.py
@@ -306,7 +306,7 @@ def __getitem__(self, i):

         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
1 change: 1 addition & 0 deletions scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,

 # unet utils

+
 # <original>.time_embed -> <diffusers>.time_embedding
Comment on lines 566 to 568

Member: New black? 🤔

Contributor Author: Yeah, I updated my black and then reverted back to 22.8. We should probably blackify the complete codebase soon :-)

 def unet_time_embeddings(checkpoint, original_unet_prefix):
     diffusers_checkpoint = {}
1 change: 1 addition & 0 deletions src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -37,6 +37,7 @@ def rename_key(key):
 # PyTorch => Flax #
 #####################

+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
8 changes: 2 additions & 6 deletions src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -24,7 +24,7 @@
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel


@@ -105,11 +105,7 @@ def __init__(

     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)

Member: Very nice

         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
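For context, a minimal sketch of the pattern the deleted branch hand-rolled, and which `torch_randn` now centralizes (shape and device choice here are illustrative, not from the diff): sample on the CPU with the seeded generator, then move the tensor, so `mps` runs stay reproducible.

```python
import torch

shape, dtype = (1, 4, 64, 64), torch.float32  # illustrative latent shape
generator = torch.Generator(device="cpu").manual_seed(0)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# What the removed branch did inline: draw the noise on the CPU,
# then move it to the target device.
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
```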
@@ -29,7 +29,7 @@
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel


@@ -113,11 +113,7 @@ def __init__(
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
14 changes: 4 additions & 10 deletions src/diffusers/schedulers/scheduling_unclip.py
@@ -20,7 +20,7 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, torch_randn
 from .scheduling_utils import SchedulerMixin


@@ -273,15 +273,9 @@ def step(
         # 6. Add noise
         variance = 0
         if t > 0:
-            device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = torch_randn(
+                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+            )

             variance = self._get_variance(
                 t,
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -64,6 +64,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
+from .torch_utils import torch_randn


 if is_torch_available():
64 changes: 64 additions & 0 deletions src/diffusers/utils/torch_utils.py
@@ -0,0 +1,64 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch utilities: Utilities related to PyTorch
+"""
+from typing import List, Optional, Tuple, Union
+
+from . import logging
+from .import_utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def torch_randn(

Member: My only concern here is that torch_randn is easy to confuse (both visually and inadvertently while typing) with torch.randn. Would it make sense to make the name slightly more different? Can't think of anything great though, diffusers_randn feels kind of ugly.

Contributor Author: Maybe just randn_tensor?

+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
+):
+    """Helper function to create random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, one can seed each batch element individually. If CPU generators are
+    passed, the tensor is always created on the CPU.
+    """
+    # device on which the tensor is created defaults to `device`
+    rand_device = device
+    batch_size = shape[0]
+
+    if generator is not None:
+        gen_device = generator[0].device if isinstance(generator, list) else generator.device
+        if gen_device != device and gen_device.type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device.type != device.type and gen_device.type == "cuda":

Member: I would add a comment here that we only allow cpu->cuda generation for reproducibility reasons, which is why the other way around is not supported / doesn't make sense.

Contributor Author: Yes, I'll make a whole doc page about this in the follow-up PR :-)

raise ValueError(f"Cannot generate a {device} tensor from a generator of type {generator.device.type}.")

+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
Comment on lines +57 to +60

Member: Very cool to include per-item reproducibility too

+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+
+    return latents
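
A usage sketch of the new helper (shapes and seeds are illustrative; the `randn_tensor` name suggested above is what later releases adopted): a seeded CPU generator reproduces the whole batch on any device, and a list of generators seeds each batch element independently via the `isinstance(generator, list)` branch.

```python
import torch

from diffusers.utils import torch_randn

shape = (2, 4, 8, 8)  # illustrative (batch, channels, height, width)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single seeded CPU generator: the full batch is reproducible across devices.
noise_a = torch_randn(shape, generator=torch.Generator("cpu").manual_seed(0), device=device)
noise_b = torch_randn(shape, generator=torch.Generator("cpu").manual_seed(0), device=device)
assert torch.equal(noise_a, noise_b)

# A list of generators: batch element i depends only on seed i.
gens = [torch.Generator("cpu").manual_seed(i) for i in range(shape[0])]
per_sample = torch_randn(shape, generator=gens, device=device)
```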
4 changes: 3 additions & 1 deletion tests/pipelines/unclip/test_unclip.py
@@ -382,7 +382,7 @@ def test_unclip_karlo(self):
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             "horse",
             num_images_per_prompt=1,
@@ -392,6 +392,8 @@

         image = output.images[0]

+        np.save("/home/patrick_huggingface_co/diffusers-images/unclip/karlo_v1_alpha_horse_fp16.npy", image)
+
         assert image.shape == (256, 256, 3)
         assert np.abs(expected_image - image).max() < 1e-2

2 changes: 1 addition & 1 deletion tests/pipelines/unclip/test_unclip_image_variation.py
@@ -480,7 +480,7 @@ def test_unclip_image_variation_karlo(self):
         pipeline.set_progress_bar_config(disable=None)
         pipeline.enable_sequential_cpu_offload()

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             input_image,
             num_images_per_prompt=1,
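Both Karlo tests switch from a device-local generator to a seeded CPU generator, which is what makes the expected-image assertions device-independent; a minimal sketch of the invariant (shape is illustrative):

```python
import torch

shape = (1, 4, 8, 8)

# A seeded CPU generator always yields the same noise, so reference outputs
# recorded on one runner (e.g. CUDA) stay valid on another (e.g. CPU or MPS).
ref = torch.randn(shape, generator=torch.Generator(device="cpu").manual_seed(0))
again = torch.randn(shape, generator=torch.Generator(device="cpu").manual_seed(0))
assert torch.equal(ref, again)
```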
2 changes: 2 additions & 0 deletions utils/custom_init_isort.py
@@ -96,6 +96,7 @@ def _inner(x):

 def sort_objects(objects, key=None):
     "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+
     # If no key is provided, we use a noop.
     def noop(x):
         return x
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
"""
Return the same `import_statement` but with objects properly sorted.
"""

# This inner function sort imports between [ ].
def _replace(match):
imports = match.groups()[0]