[Flax] Stateless schedulers, fixes and refactors #1661
Conversation
*Force-pushed from `3a64fbf` to `0108066`*
Hi @skirsten, this looks amazing! I see you are tweaking stuff, let me know when you want a review :)
*Force-pushed from `3f48e8e` to `c6357ed`*
Hi @pcuenca, it should be ready for review now 😅
Awesome, will do this week!
```diff
@@ -0,0 +1,106 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
```
Can we maybe add this to `scheduling_utils_flax.py` instead? :-) We usually don't have _common_ files in `src/diffusers`.
Done. I moved it to `scheduling_utils_flax.py`.
**patrickvonplaten** left a comment
*Force-pushed from `c6357ed` to `9315088`*
**pcuenca** left a comment
This is great! Much clearer code (and more efficient, according to the previous comments).
I tested PNDM, DDPM, DPM Solver and LMS Discrete on TPU v3-8 and these were my results:
- DPM Solver produces identical results to the previous version.
- There are some minor visual differences in PNDM. I haven't found the reason: I tried forcing `dtype` to `float32` when computing the betas, but it didn't make a difference.
- LMS does not work in either version.
- DDPM produces noise in the new version. It crashed in the previous one.
I think we should merge this as it's so much better. My approach would be:
- Deal with DDPM and LMS in a followup PR, as they didn't work anyway in the previous implementation.
- If we can easily find a reason for the minor discrepancies in PNDM, let's try to apply it. I already spent a couple hours and couldn't find it, so I wouldn't spend much more time.
What do you think?
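For reference, the `float32`-betas experiment mentioned above can be sketched like this (illustrative names and schedule parameters, not the actual diffusers code): do the schedule math in `float32` regardless of the scheduler's working dtype, and only cast at the end.

```python
import jax.numpy as jnp

def make_betas(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
               dtype=jnp.bfloat16):
    # "scaled linear" schedule computed in float32 for precision,
    # then cast to the scheduler's working dtype at the very end.
    betas = jnp.linspace(beta_start**0.5, beta_end**0.5,
                         num_train_timesteps, dtype=jnp.float32) ** 2
    return betas.astype(dtype)
```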
```python
model_output = jax.lax.select(
    (state.counter % 4) != 3,
    model_output,  # remainder 0, 1, 2
    state.cur_model_output + 1 / 6 * model_output,  # remainder 3
)
```
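As a standalone illustration of the pattern (toy `counter` and arrays, not the real scheduler state): `jax.lax.select` evaluates both candidates unconditionally and merely picks one, so there is no data-dependent branch for the device to wait on, unlike `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp

def accumulate(counter, cur_model_output, model_output):
    # Both operands are computed eagerly; select only picks between them.
    return jax.lax.select(
        (counter % 4) != 3,
        model_output,                             # remainder 0, 1, 2
        cur_model_output + 1 / 6 * model_output,  # remainder 3
    )

x = jnp.ones((2, 2))
r0 = accumulate(jnp.int32(0), x, x)  # remainder 0: passes model_output through
r3 = accumulate(jnp.int32(3), x, x)  # remainder 3: 1 + 1/6
```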
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes are all much cleaner this way. Thanks a lot!
```python
cur_model_output=jax.lax.select_n(
    state.counter % 4,
    state.cur_model_output + 1 / 6 * model_output,  # remainder 0
    state.cur_model_output + 1 / 3 * model_output,  # remainder 1
    state.cur_model_output + 1 / 3 * model_output,  # remainder 2
    jnp.zeros_like(state.cur_model_output),  # remainder 3
),
```
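A standalone demo of the same idea (toy values, not the real scheduler state): `jax.lax.select_n` indexes into pre-computed cases with an integer, replacing a `jax.lax.switch` without introducing data-dependent control flow.

```python
import jax
import jax.numpy as jnp

def next_accumulator(counter, cur_model_output, model_output):
    # All four cases are computed; the integer index picks one.
    return jax.lax.select_n(
        counter % 4,
        cur_model_output + 1 / 6 * model_output,  # remainder 0
        cur_model_output + 1 / 3 * model_output,  # remainder 1
        cur_model_output + 1 / 3 * model_output,  # remainder 2
        jnp.zeros_like(cur_model_output),         # remainder 3
    )

x = jnp.ones((2,))
r1 = next_accumulator(jnp.int32(1), x, x)  # remainder 1: 1 + 1/3
r3 = next_accumulator(jnp.int32(3), x, x)  # remainder 3: zeros
```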
Nice!
Maybe @patil-suraj wants to take a quick look too.
*Force-pushed from `9315088` to `752fb76`*
Co-authored-by: Pedro Cuenca <[email protected]>
Cool, let's merge as this is a clear improvement over what we had previously. More than happy to fix schedulers one by one in the future.
* [Flax] Stateless schedulers, fixes and refactors
* Remove scheduling_common_flax and some renames
* Update src/diffusers/schedulers/scheduling_pndm_flax.py

Co-authored-by: Pedro Cuenca <[email protected]>
Hi, here are some fixes and improvements for the Flax schedulers. Let me know what you think!
Sorry that it's a huge PR with a single commit 😅
- **Refactor schedulers to be completely stateless** (all the state is in the params)
  - `set_timesteps` is now final and won't be changed by `step` (this reduces the amount of jit misses when jitting the scheduler separately)
- **Added a `dtype` param to the schedulers**
  - Leave it at fp32 though, unless you want to lose all detail in the image.
- **Fixed a copy-paste error in the `add_noise` function**
  - The `add_noise` function, and thus img2img, was not working in `DDIM` and `DPMSolverMultistep`.
- **Removed all jax conditionals to fix a performance bottleneck**
  - `jax.lax.cond` and `jax.lax.switch` cause the CPU to wait for the `pred` (even when jitted), stalling the GPU pipeline (not enough kernels scheduled). More info here. `PNDM` and `DPMSolverMultistep` were slower than `DDIM`; this was the reason.
- **Fixed small bugs and made improvements**
  - Added `v_prediction` where it was missing.
  - Made `DDPM` jitable, though I'm not sure if it works correctly.
  - Fixed `DPMSolverMultistep` not being able to start in the middle of a schedule. This caused img2img not to work.
  - Made `LMSDiscrete` run. It's not jitable and I always get back a black image though.

**Validation** (outdated; I messed up, so the PyTorch runs are fp16 and the Flax runs are bf16):
- DDIM
- DPMSolverMultistep
- PNDM
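The stateless design described above can be sketched as follows (illustrative names, not the actual diffusers API): all mutable state lives in a pytree that is passed in and returned, so `step` is a pure function and safe to `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class SchedulerState(NamedTuple):
    # All mutable scheduler state lives here, threaded through every call.
    counter: jnp.ndarray
    cur_sample: jnp.ndarray

def create_state(sample):
    return SchedulerState(counter=jnp.int32(0), cur_sample=sample)

@jax.jit
def step(state, model_output):
    # Toy update rule standing in for a real scheduler step: the function
    # mutates nothing, it just returns the next state.
    new_sample = state.cur_sample - 0.1 * model_output
    return SchedulerState(counter=state.counter + 1, cur_sample=new_sample)

state = create_state(jnp.ones((2,)))
state = step(state, jnp.ones((2,)))
```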