
Conversation

@patil-suraj
Contributor

@patil-suraj patil-suraj commented Nov 24, 2022

This PR adds StableDiffusionUpscalePipeline

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 24, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj patil-suraj marked this pull request as ready for review November 25, 2022 11:19
Member

@pcuenca pcuenca left a comment


Awesome!

noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level

# 6. Prepare latent variables
height, width = image.shape[2:]
Member

@pcuenca pcuenca Nov 25, 2022


If I understand this correctly, the latents have the same size as the input image, right? In Katherine's upscaler the low-res image was upscaled using bilinear interpolation and the latents were the size of the output image. Is this not happening here?

OK, I was wrong: in Katherine's upscaler the latents were upscaled and provided as conditioning. Here we create latents the same size as the low-res image, and the VAE decodes the final result to upscale it.
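
A minimal, shape-level sketch of the flow described above (channel counts and names are illustrative assumptions, not taken from the diff):

import torch

low_res = torch.randn(1, 3, 128, 128)             # low-resolution input image
latents = torch.randn(1, 4, 128, 128)              # latents share the low-res spatial size
unet_input = torch.cat([latents, low_res], dim=1)  # the UNet is conditioned on the low-res image via channel concat
# after the denoising loop, the VAE decoder upsamples by 4x, so the final output would be 512x512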

unet: UNet2DConditionModel,
low_res_scheduler: DDPMScheduler,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
max_noise_level: int,
Contributor


Note that with the new "optional" pipeline config arguments this has to have a default, otherwise it breaks:

optional_parameters = set({k for k, v in parameters.items() if v.default is True})

Suggested change:
- max_noise_level: int,
+ max_noise_level: int = 9

Not sure what a good default is here.

Overall I agree with @pcuenca's feedback that the from_pretrained code has become a bit too much of a black box / magic, but I don't really see a way around it. I'd also like to strongly advise against using optional arguments in the pipeline inits, but if it makes sense here, it's OK for me!

We should/could maybe jump on a call in a bit to discuss this
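
For context, a hedged sketch of the kind of introspection being referenced: init parameters that carry a default can be treated as optional config arguments (a generic illustration, not the exact diffusers code):

import inspect

def optional_init_parameters(cls):
    # parameters whose signature carries a default can be treated as optional
    parameters = inspect.signature(cls.__init__).parameters
    return {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}

class Example:
    def __init__(self, unet, max_noise_level: int = 350):
        ...

optional_init_parameters(Example)  # -> {'max_noise_level'}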

Contributor Author


Aah, thanks!

Here we could default to the value used for SD2, which is 350.

Member


Wouldn't 350 be too high for the upscaling pipeline?

Contributor Author


Yes, but that high value is only used during training; here we set it as an upper bound that shouldn't be crossed.
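
A hedged sketch of that idea: max_noise_level is a fixed config value that only bounds the user-provided noise_level at call time (the default of 350 mirrors the SD2 training value mentioned above; the exact check in the merged code may differ):

def check_noise_level(noise_level: int, max_noise_level: int = 350) -> int:
    # reject noise levels beyond what the model saw during training
    if noise_level > max_noise_level:
        raise ValueError(f"`noise_level` has to be <= {max_noise_level} but is {noise_level}")
    return noise_level

check_noise_level(20)     # typical inference value, well below the bound
# check_noise_level(400)  # would raise, since it exceeds the training range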

def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
Contributor


Can the image be a latent?

Contributor Author


No, the UNet is conditioned on the low-resolution image, not on latents.
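
For illustration, a minimal usage sketch with a PIL image as input (the checkpoint id, file names, and argument values are assumptions, not taken from this thread):

import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

low_res = Image.open("low_res_cat.png").convert("RGB").resize((128, 128))  # hypothetical local file
upscaled = pipe(prompt="a white cat", image=low_res, noise_level=20).images[0]
upscaled.save("upscaled_cat.png")  # roughly 4x the input resolution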

@patil-suraj patil-suraj changed the title from [wip] StableDiffusionUpscalePipeline to StableDiffusionUpscalePipeline Nov 25, 2022
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
Member


Should we divide by two here as well?
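
A hedged sketch of what the question seems to suggest, i.e. halving the smallest head size in the list branch as well (illustrative only, not the merged code):

def choose_slice_size(attention_head_dim):
    # halve the head dim whether it is a single int or a list of per-block sizes
    if isinstance(attention_head_dim, int):
        return attention_head_dim // 2
    return min(attention_head_dim) // 2

choose_slice_size(8)            # -> 4
choose_slice_size([5, 10, 20])  # -> 2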


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
Member


👍
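
For context, a sketch of how the copied decode_latents typically continues, with the 0.08333 scaling factor swapped in (it mirrors the StableDiffusionPipeline method referenced in the Copied-from comment; treat it as an approximation of the merged code):

def decode_latents(self, latents):
    latents = 1 / 0.08333 * latents          # the upscaler VAE uses a different scaling factor than SD's 0.18215
    image = self.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)    # map from [-1, 1] to [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image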

Contributor

@patrickvonplaten patrickvonplaten left a comment


Cool, good to merge for me!

@patil-suraj patil-suraj merged commit 9ec5084 into main Nov 25, 2022
@patil-suraj patil-suraj deleted the upscale-pipeline branch November 25, 2022 15:13
@averad averad mentioned this pull request Nov 25, 2022
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* StableDiffusionUpscalePipeline

* fix a few things

* make it better

* fix image batching

* run vae in fp32

* fix docstr

* resize to mul of 64

* doc

* remove safety_checker

* add max_noise_level

* fix Copied

* begin tests

* slow tests

* default max_noise_level

* remove kwargs

* doc

* fix

* fix fast tests

* fix fast tests

* no sf

* don't offload vae

Co-authored-by: Patrick von Platen <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023