Skip to content

Conversation

lawrence-cj
Copy link
Contributor

What does this PR do?

Add Sa-Solver into diffusers lib.
Cc: @patil-suraj. Really need your help to test it.

@sayakpaul
Copy link
Member

Looping in @yiyixuxu here.

@sayakpaul sayakpaul requested a review from yiyixuxu November 29, 2023 10:12
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sayakpaul
Copy link
Member

@lawrence-cj let's also add a testing suite here. You can refer to a test suite for schedulers here:
https://github.com/huggingface/diffusers/tree/main/tests/schedulers

@yiyixuxu
Copy link
Collaborator

awesome :)
Can we see an example of how to use it with SD1.5 and SDXL?

@lawrence-cj
Copy link
Contributor Author

No problem, I will add a test for you.


if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a step_index counter instead of searching for the index at each step?
Like how it is done here for DPM

Essentially we only need to search for the step_index once in the beginning, and then just increase the counter after each step

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We implemented sa-solver according to DPMSolverMultistep in diffusers==0.21.2. Could you please help us to change it to the newest implementation?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As stated by @yiyixuxu, we need to change the logic to use a self.step_index counter or else this scheduler won't function correctly when using karras sigmas or when doing img2img translation. Could you take a look at

and refactor the code accordingly @lawrence-cj ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cc: @scxue

result_mean = torch.mean(torch.abs(sample))

if torch_device in ["mps"]:
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have no idea how to get the specific value here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we can remove the mps check for now as well

assert abs(result_sum.item() - 167.47821044921875) < 1e-2
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about this line? Wondering how to get the specific value 171.59352111816406, etc.

@sayakpaul
Copy link
Member

There are some tests failing here. Could we please fix them?

@sayakpaul sayakpaul requested a review from yiyixuxu December 7, 2023 04:00
@sayakpaul
Copy link
Member

I will let @yiyixuxu review this once the test failures are resolved.

@patrickvonplaten
Copy link
Contributor

Gentle ping here @yiyixuxu

@lawrence-cj
Copy link
Contributor Author

We add use_karras_sigmas in the sa-solver in the newest commit. Could you please help to check the correctness. @yiyixuxu

@lawrence-cj
Copy link
Contributor Author

lawrence-cj commented Jan 15, 2024

Gentle ping here. @yiyixuxu

@patrickvonplaten
Copy link
Contributor

Happy to merge whenever @yiyixuxu is happy here!

@lawrence-cj
Copy link
Contributor Author

lawrence-cj commented Jan 21, 2024

Thx so much for your review. I'm wondering how you make this style automatically. : ) Any tools for it? @yiyixuxu

@yiyixuxu
Copy link
Collaborator

@lawrence-cj if you installed our dev environment, you could run make style command :)
see our guide here https://huggingface.co/docs/diffusers/conceptual/contribution#how-to-open-a-pr

@lawrence-cj
Copy link
Contributor Author

Oh, thank you so much for the information. @yiyixuxu

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu
Copy link
Collaborator

@lawrence-cj
thanks for adding this scheduler! I'm going to merge this in now.
we recently merged in this PR #6477 that improves the generation quality for DPM schedulers + SDXL - would same fix works for Sa-Solver? we can open an PR later to add same fix if this is the case!

@yiyixuxu yiyixuxu merged commit c7df846 into huggingface:main Jan 22, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add Sa-Solver



---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: scxue <[email protected]>
Co-authored-by: jschen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants