-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[Reproducibility 1/3] Allow tensors to be generated on CPU #1902
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
70e3de6
b2d85ea
63353bf
a639448
80ba55a
1479424
f6db58b
c7a69e7
4ec71b0
b535251
575b74c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,7 @@ | |
| from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel | ||
| from ...pipelines import DiffusionPipeline, ImagePipelineOutput | ||
| from ...schedulers import UnCLIPScheduler | ||
| from ...utils import is_accelerate_available, logging | ||
| from ...utils import is_accelerate_available, logging, torch_randn | ||
| from .text_proj import UnCLIPTextProjModel | ||
|
|
||
|
|
||
|
|
@@ -105,11 +105,7 @@ def __init__( | |
|
|
||
| def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): | ||
| if latents is None: | ||
| if device.type == "mps": | ||
| # randn does not work reproducibly on mps | ||
| latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) | ||
| else: | ||
| latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) | ||
| latents = torch_randn(shape, generator=generator, device=device, dtype=dtype) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice |
||
| else: | ||
| if latents.shape != shape: | ||
| raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
| PyTorch utilities: Utilities related to PyTorch | ||
| """ | ||
| from typing import List, Optional, Tuple, Union | ||
|
|
||
| from . import logging | ||
| from .import_utils import is_torch_available | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| import torch | ||
|
|
||
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
|
|
||
|
|
||
def torch_randn(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
):
    """Create a random tensor of ``shape`` on the desired ``device`` with the desired ``dtype``.

    When a list of generators is passed, each batch item (``shape[0]`` entries) is sampled
    with its own generator, so individual batch items are reproducible independently.
    If a CPU generator is passed while a non-CPU ``device`` is requested, the tensor is
    created on the CPU first and then moved to ``device`` so the values stay reproducible.

    Args:
        shape: Shape of the tensor to create; ``shape[0]`` is treated as the batch size.
        generator: A single ``torch.Generator`` or a list with one generator per batch item.
        device: Device on which the returned tensor should live.
        dtype: Dtype of the returned tensor.

    Returns:
        A tensor of ``shape`` on ``device`` with ``dtype``.

    Raises:
        ValueError: If a CUDA generator is used to request a tensor on a different device
            type — only cpu -> other-device generation is supported (for reproducibility),
            not the reverse.
    """
    # Device on which the tensor is actually sampled; defaults to the requested device.
    rand_device = device
    batch_size = shape[0]

    if generator is not None:
        # `generator` may be a list; inspect the first element's device in that case.
        # (Accessing `generator.device` directly would raise AttributeError for a list.)
        gen_device_type = generator[0].device.type if isinstance(generator, list) else generator.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Generator will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            # Only cpu -> accelerator generation is allowed (for reproducibility reasons);
            # a CUDA generator cannot reproducibly populate a tensor on another device.
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        # Sample each batch item with its own generator, then concatenate along dim 0.
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)

    return latents
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New
black? 🤔There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, updated my black and then re-updated back to 22.8. We should maybe soon blackify the complete codebase once :-)