Add 2nd order heun scheduler #1336
Conversation
```python
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]])
```
@patil-suraj @pcuenca @anton-l - this just repeats timesteps for the second order
That's cool! But why do we need to append the final sigma again? We end up with three trailing zeros here, not two.
Also, I think repeat_interleave was not compatible / efficient with mps and maybe onnx; let's just take a note to deal with that later.
Same question as Pedro about the final sigma.
True I should remove it
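To make the counting concrete, here is an editorial illustration with made-up values (not code from the PR), assuming the `[1:]` + append pattern of the quoted timesteps line were also applied to a sigma schedule ending in the appended 0.0:

```python
import torch

# Made-up sigma schedule ending in the appended 0.0 (illustrative values only).
sigmas = torch.tensor([14.6, 5.0, 1.0, 0.0])

# Repeat everything after the first entry, then append the last entry again:
# the final 0.0 ends up appearing three times.
duplicated = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
print(duplicated)
# tensor([14.6000,  5.0000,  5.0000,  1.0000,  1.0000,  0.0000,  0.0000,  0.0000])
```

Slicing with `[1:-1]` instead, as in the sigmas line quoted above, keeps a single trailing zero.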
```python
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
```
@patil-suraj @pcuenca @anton-l - this just repeats sigmas for the second order
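As a quick illustration of what the interleaving produces (made-up values, not code from the PR): every interior sigma appears twice, once for the first-order (Euler) half of the step and once for the Heun correction, while the endpoints stay single.

```python
import torch

# Made-up sigma schedule after the interpolation + trailing 0.0 above.
sigmas = torch.tensor([14.6, 9.1, 5.0, 1.0, 0.0])

interleaved = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
print(interleaved)
# tensor([14.6000,  9.1000,  9.1000,  5.0000,  5.0000,  1.0000,  1.0000,  0.0000])
```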
```python
@property
def state_in_first_order(self):
    return self.dt is None
```
simple property defining the mode of the scheduler
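For readers skimming the thread, a minimal, hypothetical illustration of how such a flag toggles (not the scheduler's actual code): `dt` starts out as `None`, is set by the first-order (Euler) half of a step, and is cleared again once the Heun correction has been applied.

```python
# Hypothetical two-phase state machine; names mirror the property above,
# but this is not the scheduler implementation.
class TwoPhaseState:
    def __init__(self):
        self.dt = None  # no pending step -> first-order phase

    @property
    def state_in_first_order(self):
        return self.dt is None

    def step(self, dt):
        if self.state_in_first_order:
            self.dt = dt    # remember the step size for the Heun correction
        else:
            self.dt = None  # correction done, the next call starts a new step


s = TwoPhaseState()
for i, dt in enumerate([0.1, 0.1, 0.2, 0.2]):
    print(i, s.state_in_first_order)
    s.step(dt)
# prints: 0 True / 1 False / 2 True / 3 False
```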
The documentation is not available anymore as the PR was closed or merged.
```python
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = (self.timesteps == timestep).nonzero().item()
```
Does this need to be adjusted to account for duplicate entries in `timesteps`? Because `==` will match more than one entry, `.item()` complains about that.
Then the result is ambiguous, but I guess we can take the first and then add one if not `state_in_first_order`?
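A quick demonstration of the concern (made-up values): with duplicated timesteps, `.nonzero()` returns more than one index, so `.item()` raises, and the caller has to decide which occurrence to use.

```python
import torch

# Illustrative duplicated timesteps, as produced by the repeat_interleave above.
timesteps = torch.tensor([999, 666, 666, 333, 333, 0])

matches = (timesteps == 666).nonzero()
print(matches)            # tensor([[1], [2]]) -> two matches
# matches.item()          # would raise: only one-element tensors can be converted to a scalar
print(matches[0].item())  # 1 -> take the first occurrence, then offset if needed
```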
I agree with @keturn, see index_for_timestep in a comment above.
```diff
- step_index = (self.timesteps == timestep).nonzero().item()
+ step_index = self.index_for_timestep(timestep)
```
The hack to get the timestep index seems to become even harder to read and debug here. Should we maybe return the timestep index arguments at some point? (planting a thought for 1.0.0)
This is a great effort, thanks a lot! I find it much easier to discuss based on some code than on the ether, so it's very helpful.
I think it's a bit finicky to get it right with all the indexing and stuff, but the end result looks really understandable. I could get what was going on in a first pass, even though I didn't realize some details were a bit off until I looked more carefully. I think this is acceptable for the "write-once, read-many" approach we want to achieve here.
As an exercise, I'll try to find some time to build the alternative example with the scheduler returning a tuple that I proposed in the other thread, and see how it compares to this.
```python
    sample (`torch.FloatTensor`): input sample
    timestep (`int`, optional): current timestep

Returns:
    `torch.FloatTensor`: scaled input sample
"""
```
Compute the index here instead of receiving it as an argument?
| """ | |
| """ | |
| step_index = self.index_for_timestep(timestep) |
where `index_for_timestep` would be something like:

```python
def index_for_timestep(self, timestep):
    pos = -1 if self.state_in_first_order else 0
    return (self.timesteps == timestep).nonzero()[pos].item()
```

It's better to use a function because it's a bit non-trivial and `step` requires the index too.
An alternative would be to call scale_model_input from the scheduler, but that's an API breaking change. (Why didn't we do it like that?)
Agree with Pedro, `scale_model_input` in other schedulers like `LMSDiscreteScheduler` doesn't take `step_index` as input, so it would be nice to follow the same API.

Also @pcuenca:

> An alternative would be to call scale_model_input from the scheduler, but that's an API breaking change. (Why didn't we do it like that?)

`scale_model_input` is not called from the scheduler because the scaled input needs to be passed to the model, and in the first iteration the model call happens before the scheduler step.
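To make that ordering concrete, a stripped-down, hypothetical loop skeleton (not the actual pipeline code): the scaled input must exist before the model is called, and `scheduler.step` only runs afterwards, so in the very first iteration there is no scheduler step that could have done the scaling.

```python
# Hypothetical loop skeleton showing the call order that keeps scale_model_input
# on the pipeline side: scale -> model -> step.
def denoise(sample, model, scheduler):
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(sample, t)        # 1) scale for the model
        noise_pred = model(model_input, t)                           # 2) model call needs the scaled input
        sample = scheduler.step(noise_pred, t, sample).prev_sample  # 3) scheduler step runs afterwards
    return sample
```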
Sorry yeah this was a bug
```python
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma = 0
```
Undecided if it's better to just remove the gamma var (but leave the comment) or keep it for clarity.
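For context on what supporting a non-zero gamma would involve, here is a sketch following the Karras et al. (2022) / k-diffusion formulation (an editorial illustration, not code from this PR): the current sigma is inflated to `sigma_hat` and matching noise is injected, which is why the timestep passed to the model would also have to change.

```python
import torch

def add_churn(sample, sigma, gamma, generator=None):
    # Sketch of the Karras et al. (2022) "churn" step that a non-zero gamma would require
    # (illustration only; this PR hard-codes gamma = 0).
    sigma_hat = sigma * (gamma + 1)
    if gamma > 0:
        noise = torch.randn(sample.shape, generator=generator, dtype=sample.dtype)
        # add just enough noise to lift the sample from noise level sigma to sigma_hat
        sample = sample + noise * (sigma_hat**2 - sigma**2) ** 0.5
    # the model would then have to be evaluated at sigma_hat rather than sigma,
    # which is why the timestep passed to the model would need rescaling
    return sample, sigma_hat
```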
patil-suraj left a comment:
This is really nice! Pretty much the same comments as @pcuenca about the step_index.
Also, let's try to compare this with the alternate approach that @pcuenca proposed and then we can decide the final API. Will try this more and see if I have any other comments.
Also, my main comment is that the formula for prev_sample seems wrong here. Instead of `prev_sample = model_output + derivative * self.dt`, it should be `prev_sample = sample + derivative * self.dt`, as per the paper and k-diffusion.
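For reference, a sketch of the standard Heun update being discussed (hypothetical helper names, not the PR code): both the Euler prediction and the corrected update start from `sample`, and the correction averages the derivatives at the current and next sigma.

```python
# Sketch of one Heun step in sigma space. `d(x, sigma)` is assumed to return the
# ODE derivative, e.g. (x - denoised(x, sigma)) / sigma.
def heun_step(sample, sigma, sigma_next, d):
    dt = sigma_next - sigma

    # 1st order (Euler) prediction, starting from `sample`:
    d_cur = d(sample, sigma)
    sample_pred = sample + d_cur * dt

    if sigma_next == 0:
        return sample_pred  # final step stays first order

    # 2nd order (Heun) correction: average the two derivatives, still stepping from `sample`.
    d_next = d(sample_pred, sigma_next)
    return sample + 0.5 * (d_cur + d_next) * dt
```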
pcuenca left a comment:
Submitted changes to make it work (I think).
Thanks a lot for the corrections @pcuenca @patil-suraj @keturn. The PR as it is now is functional and gives 1-to-1 the same results as k-diffusion:

```python
#!/usr/bin/env python3
from diffusers import DiffusionPipeline, StableDiffusionPipeline, HeunDiscreteScheduler
import torch
seed = 33
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
pipe.set_sampler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image.save("./astronaut_heun_k_diffusion_comp.png")
pipe = StableDiffusionPipeline(**pipe.components)
pipe = pipe.to("cuda")
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image.save("./astronaut_heun_comp.png") |
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (review thread outdated and resolved)
* Add heun * Finish first version of heun * remove bogus * finish * finish * improve * up * up * fix more * change progress bar * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py * finish * up * up * up


This PR does two things:
1. Adds the 2nd order Heun scheduler to `diffusers`.
2. Shows that it works without changing the logic of the scheduler / model API: the pipeline simply keeps iterating over `self.scheduler.timesteps`, which now contains duplicated entries for the second-order step.

Note that this design might evolve in the future as discussed in #1308.
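To illustrate the second point, here is a toy denoising loop (an editorial sketch with a dummy stand-in for the UNet, assuming the `HeunDiscreteScheduler` default config; not the Stable Diffusion pipeline code): the loop body is the usual one, and the extra model evaluation of each Heun step happens only because the corresponding timestep appears twice in `scheduler.timesteps`.

```python
import torch
from diffusers import HeunDiscreteScheduler

# Toy setup: default scheduler config and a dummy "model" that predicts zero noise.
scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

def dummy_model(x, t):
    # Stand-in for the UNet.
    return torch.zeros_like(x)

# Standard loop: the second model call of every Heun step happens simply
# because the same timestep appears twice in scheduler.timesteps.
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = dummy_model(model_input, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample

print(sample.shape)  # torch.Size([1, 4, 64, 64])
```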