
Conversation

@patrickvonplaten (Contributor) commented Nov 17, 2022

This PR does two things:

    1. It adds the first second-order scheduler to diffusers without changing the logic of the scheduler / model API:
    • a. The loop is still defined by iterating over self.scheduler.timesteps
    • b. The scheduler step function still returns only the computed model values
    2. It fixes the progress bar of PNDM and also makes it compatible with 2nd-order schedulers. Now PNDM shows 50 update steps instead of 51.

Note that this design might evolve in the future as discussed in #1308
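
To make 1a concrete, here is a minimal self-contained sketch of how a pipeline loop drives the new scheduler (the dummy model and the config values are illustrative placeholders, not the actual pipeline code):

    import torch
    from diffusers import HeunDiscreteScheduler

    scheduler = HeunDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    scheduler.set_timesteps(num_inference_steps=20)
    # interior timesteps are duplicated internally, so len(scheduler.timesteps) > 20,
    # but the loop below is exactly the same as for a first-order scheduler
    print(len(scheduler.timesteps))

    def dummy_unet(x, t):  # stand-in for the real UNet's epsilon prediction
        return torch.zeros_like(x)

    sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(sample, t)
        noise_pred = dummy_unet(model_input, t)
        sample = scheduler.step(noise_pred, t, sample).prev_sample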

self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]])

Contributor Author:
@patil-suraj @pcuenca @anton-l - this just repeats timesteps for the second order
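
For intuition, a tiny standalone example of what this interleaving does (made-up timestep values):

    import torch

    timesteps = torch.tensor([999, 749, 499, 249])  # made-up values
    doubled = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]])
    print(doubled)  # tensor([999, 749, 749, 499, 499, 249, 249, 249])
    # each interior timestep appears twice (one model call per Heun half-step);
    # with this exact snippet the last value appears three times, which mirrors
    # the trailing-sigma question below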

Member:
That's cool! But why do we need to append the final sigma again? We end up with three trailing zeros here, not two.

Member:
Also, I think repeat_interleave was not compatible / efficient with mps and maybe onnx; let's just take a note to deal with that later.
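
Not something for this PR, but for the record: for a 1-D tensor the same interleaving can be written without repeat_interleave, e.g. via stack + reshape, if mps/onnx support turns out to be a problem (just a sketch of one possible workaround):

    import torch

    x = torch.tensor([5.0, 3.0, 1.0])
    interleaved = torch.stack([x, x], dim=1).reshape(-1)  # same result as x.repeat_interleave(2)
    assert torch.equal(interleaved, x.repeat_interleave(2))
    print(interleaved)  # tensor([5., 5., 3., 3., 1., 1.])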

Contributor:
Same question as Pedro about the final sigma.

Contributor Author:
True, I should remove it.

sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

Contributor Author:
@patil-suraj @pcuenca @anton-l - this just repeats sigmas for the second order
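
A small numeric illustration of the interp + interleave above, with made-up sigma values:

    import numpy as np
    import torch

    sigmas = np.array([0.1, 0.5, 1.0, 5.0, 14.6])  # made-up "training" sigmas (ascending)
    timesteps = np.array([4.0, 2.0, 0.0])          # descending inference timesteps

    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)  # [14.6, 1.0, 0.1]
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)       # [14.6, 1.0, 0.1, 0.0]
    sigmas = torch.from_numpy(sigmas)
    sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
    print(sigmas)  # 14.6, 1.0, 1.0, 0.1, 0.1, 0.0 -- endpoints stay single, interior sigmas doubled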


@property
def state_in_first_order(self):
return self.dt is None

Contributor Author:
simple property defining the mode of the scheduler
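
In other words (a toy illustration, not the scheduler code): dt doubles as the state flag, None before and after a full Heun step and set in between the two halves:

    # Toy illustration of the idea -- not the actual scheduler implementation.
    class Toy:
        def __init__(self):
            self.dt = None

        @property
        def state_in_first_order(self):
            return self.dt is None

    toy = Toy()
    print(toy.state_in_first_order)  # True  -> the next step() call does the Euler half
    toy.dt = -0.5                    # the first half stores dt here
    print(toy.state_in_first_order)  # False -> the next step() call does the Heun correction
    toy.dt = None                    # the second half resets it
    print(toy.state_in_first_order)  # True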

@HuggingFaceDocBuilderDev commented Nov 17, 2022

The documentation is not available anymore as the PR was closed or merged.

[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = (self.timesteps == timestep).nonzero().item()

Contributor:
Does this need to be adjusted to account for duplicate entries in timesteps? Because == will match more than one entry, .item() complains about that.

Then the result is ambiguous, but I guess we can take the first and then add one if not state_in_first_order?
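
A quick illustration of the ambiguity with made-up values:

    import torch

    timesteps = torch.tensor([999, 749, 749, 499, 499, 249])  # duplicated interior entries
    matches = (timesteps == 749).nonzero()
    print(matches)  # tensor([[1], [2]]) -- two hits, so calling .item() on this raises
    step_index = matches[0].item()  # picking one explicitly works, e.g. the first hit
    print(step_index)  # 1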

Member:
I agree with @keturn, see index_for_timestep in a comment above.

Suggested change
step_index = (self.timesteps == timestep).nonzero().item()
step_index = self.index_for_timestep(timestep)

Member:
The hack to get the timestep index seems to become even harder to read and debug here. Should we maybe return the timestep index arguments at some point? (planting a thought for 1.0.0)

@pcuenca (Member) left a comment:
This is a great effort, thanks a lot! I find it much easier to discuss based on some code than on the ether, so it's very helpful.

I think it's a bit finicky to get it right with all the indexing and stuff, but the end result looks really understandable. I could get what was going on in a first pass, even though I didn't realize some details were a bit off until I looked more carefully. I think this is acceptable for the "write-once, read-many" approach we want to achieve here.

As an exercise, I'll try to find some time to build the alternative example with the scheduler returning a tuple that I proposed in the other thread, and see how it compares to this.


sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""

Member:
Compute the index here instead of receiving as an argument?

Suggested change
"""
"""
step_index = self.index_for_timestep(timestep)

where index_for_timestep would be something like:

    def index_for_timestep(self, timestep):
        pos = -1 if self.state_in_first_order else 0
        return (self.timesteps == timestep).nonzero()[pos].item()

It's better to use a function because it's a bit non-trivial and step requires the index too.

An alternative would be to call scale_model_input from the scheduler, but that's an API breaking change. (Why didn't we do it like that?)

Contributor:
Agree with Pedro: scale_model_input in other schedulers like LMSDiscreteScheduler doesn't take step_index as input, so it would be nice to follow the same API.

Also @pcuenca

An alternative would be to call scale_model_input from the scheduler, but that's an API breaking change. (Why didn't we do it like that?)

scale_model_input is not called from scheduler because the scaled input needs to be passed to the model, and in the first iteration the model call happens before the scheduler step.

Contributor Author:
Sorry yeah this was a bug


# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma = 0

Member:
Undecided if it's better to just remove the gamma var (but leave the comment) or keep it for clarity.
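
For context on what gamma would do if it were supported: this is my paraphrase of the "churn" term from Karras et al. (2022) / k-diffusion's sample_heun, not anything implemented in this PR, so treat the details as an assumption:

    import math
    import torch

    # Sketch of the Karras gamma/"churn" step; parameter names follow the paper.
    # With gamma = 0 (as in this PR) sigma_hat == sigma and no extra noise is added.
    def churn(sample, sigma, num_steps, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0):
        gamma = min(s_churn / num_steps, math.sqrt(2) - 1) if s_tmin <= sigma <= s_tmax else 0.0
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            eps = torch.randn_like(sample) * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # the model would then be evaluated at sigma_hat rather than sigma, which is
        # the "scale the timestep before passing it to the model" API issue noted above
        return sample, sigma_hat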

@patil-suraj (Contributor) left a comment:

This is really nice! Pretty much the same comments as @pcuenca about the step_index.
Also, let's try to compare this with the alternate approach that @pcuenca proposed, and then we can decide on the final API. Will try this more and see if I have any other comments.

Also, my main comment is that the formula for prev_sample seems wrong here.

Instead of prev_sample = model_output + derivative * self.dt

it should be prev_sample = sample + derivative * self.dt as per the paper and k-diffusion.
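
For reference, the two-stage update as described in the Karras paper / k-diffusion, written out schematically (plain tensors, not the scheduler code; denoise(x, sigma) stands in for the model's predicted clean sample):

    import torch

    def heun_step(sample, sigma, sigma_next, denoise):
        dt = sigma_next - sigma
        # 1st order (Euler) half: derivative at the current sample
        derivative = (sample - denoise(sample, sigma)) / sigma
        sample_2 = sample + derivative * dt
        if sigma_next == 0:
            return sample_2  # final step is Euler only, no correction possible at sigma = 0
        # 2nd order half: average the two derivatives, but step from the ORIGINAL sample,
        # i.e. prev_sample = sample + averaged_derivative * dt (not model_output + ...)
        derivative_2 = (sample_2 - denoise(sample_2, sigma_next)) / sigma_next
        return sample + 0.5 * (derivative + derivative_2) * dt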


@pcuenca (Member) left a comment:

Submitted changes to make it work (I think).

@patrickvonplaten (Contributor Author):

Thanks a lot for the corrections @pcuenca @patil-suraj @keturn.

The PR as it is now is functional and gives 1-to-1 the same results as k-diffusion:

#!/usr/bin/env python3
from diffusers import DiffusionPipeline, StableDiffusionPipeline, HeunDiscreteScheduler
import torch

seed = 33

# Reference: Heun sampling via the k-diffusion community pipeline
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
pipe = pipe.to("cuda")

prompt = "an astronaut riding a horse on mars"
pipe.set_sampler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

image.save("./astronaut_heun_k_diffusion_comp.png")

# This PR: the same components, but with the new HeunDiscreteScheduler
pipe = StableDiffusionPipeline(**pipe.components)
pipe = pipe.to("cuda")
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

image.save("./astronaut_heun_comp.png")

K Diffusion: [image: astronaut_heun_k_diffusion_comp]

This PR (Diffusers): [image: astronaut_heun_comp]

@patrickvonplaten merged commit 4c54519 into main Nov 28, 2022
@patrickvonplaten deleted the add_heun branch November 28, 2022 21:56
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* Add heun

* Finish first version of heun

* remove bogus

* finish

* finish

* improve

* up

* up

* fix more

* change progress bar

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* finish

* up

* up

* up
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023