
Conversation

@skirsten skirsten (Contributor) commented Dec 11, 2022

Hi, here are some fixes and improvements for the Flax schedulers. Let me know what you think!
Sorry that it's a huge PR with a single commit 😅

Refactor schedulers to be completely stateless (all the state is in the params)

  • No more state in the scheduler class
  • No more implicit transfers
  • Extracted the common state (it can also be reused by other schedulers)
  • The shapes and dtypes of the state returned by set_timesteps are now final and won't be changed by step (this reduces the number of JIT cache misses when jitting the scheduler separately; see the sketch below)
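
To make the new shape concrete, here is a minimal sketch of the stateless pattern (illustrative names and trivialized bodies, not the actual diffusers code):

```python
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class SchedulerState:
    # All mutable state lives in this immutable, jit-friendly container;
    # the scheduler object itself holds only static configuration.
    timesteps: jnp.ndarray
    num_inference_steps: int


class SketchScheduler:
    def create_state(self) -> SchedulerState:
        return SchedulerState(timesteps=jnp.arange(0), num_inference_steps=0)

    def set_timesteps(self, state: SchedulerState, num_inference_steps: int) -> SchedulerState:
        # Returns a new state; its shapes and dtypes are final from here on,
        # so a separately jitted `step` never triggers recompilation.
        return state.replace(
            timesteps=jnp.arange(num_inference_steps)[::-1],
            num_inference_steps=num_inference_steps,
        )

    def step(self, state: SchedulerState, model_output: jnp.ndarray, sample: jnp.ndarray):
        # Purely functional: takes the state and returns (sample, new state);
        # nothing on `self` is ever mutated. (Body trivialized for brevity.)
        return sample - model_output, state
```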

Added dtype param to schedulers

Leave it at fp32, though, unless you want to lose all detail in the image.
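
For example (a sketch assuming the post-PR constructor accepts the new parameter):

```python
import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler

# Keeping the scheduler state in fp32 preserves fine image detail even
# when the UNet itself runs in bf16/fp16.
scheduler = FlaxDDIMScheduler(num_train_timesteps=1000, dtype=jnp.float32)
state = scheduler.create_state()
```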

Fixed a copy-paste error in the add_noise function

  • The add_noise function, and thus img2img, was not working in DDIM and DPMSolverMultistep
  • Extracted the common logic so this can't happen again (a sketch follows below)
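
A sketch of what the extracted helper computes (illustrative name and layout, not the exact diffusers code):

```python
import jax.numpy as jnp


def add_noise_common(alphas_cumprod, original_samples, noise, timesteps):
    # Forward diffusion q(x_t | x_0): blend the clean samples with noise
    # according to the cumulative alpha schedule at the given timesteps.
    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5

    # Broadcast the per-timestep scalars over the sample dimensions.
    broadcast_shape = (-1,) + (1,) * (original_samples.ndim - 1)
    sqrt_alpha_prod = sqrt_alpha_prod.reshape(broadcast_shape)
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.reshape(broadcast_shape)

    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
```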

Removed all JAX conditionals to fix a performance bottleneck

  • Using jax.lax.cond and jax.lax.switch causes the CPU to have to wait for the pred (even when jitted), which stalls the GPU pipeline (not enough kernels scheduled). More info here.
  • If you noticed that PNDM and DPMSolverMultistep were slower than DDIM, this was the reason.
  • This is usually most noticeable with a fast GPU + slow CPU combo, or when running a split kernel (a separately jitted scheduler instead of the megakernel as in this repo).
  • Evaluating all branches instead has no noticeable performance impact (see the sketch below).
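
A toy before/after of the pattern (hypothetical functions, not the actual scheduler code):

```python
import jax
import jax.numpy as jnp


# Before: `lax.cond` only executes one branch, but the host must resolve
# `pred` before it can keep dispatching kernels, which can stall the GPU
# pipeline.
def step_with_cond(pred, a, b):
    return jax.lax.cond(pred, lambda ab: ab[0] * 2.0, lambda ab: ab[1] + 1.0, (a, b))


# After: evaluate both branches and select element-wise. Both sides are
# computed, but the dispatch queue stays full, and the extra arithmetic is
# negligible next to the UNet.
def step_with_select(pred, a, b):
    return jax.lax.select(pred, a * 2.0, b + 1.0)
```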

Other small bug fixes and improvements

  • Added v_prediction where it was missing (see the sketch after this list)
  • Made DDPM jittable, though I'm not sure it works correctly.
  • Fixed DPMSolverMultistep not being able to start in the middle of a schedule. This caused img2img not to work.
  • Made LMSDiscrete run. It's not jittable, and I always get back a black image, though.
  • Probably some other stuff that I forgot about
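
For reference, the standard v-prediction conversion from Salimans & Ho's progressive distillation paper, roughly what the added branches compute (a sketch, not the exact diffusers code):

```python
import jax.numpy as jnp


def pred_original_from_v(sample: jnp.ndarray, v: jnp.ndarray, alpha_prod_t) -> jnp.ndarray:
    # x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
    return (alpha_prod_t ** 0.5) * sample - ((1 - alpha_prod_t) ** 0.5) * v
```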

Validation (outdated)

I messed up, so the PyTorch images are fp16 and the Flax ones are bf16.

| name | PyTorch | Flax v0.10.2 | Flax this PR |
| --- | --- | --- | --- |
| DDIM | torch_ddim_0.10.2 (image) | flax_ddim_0.10.2 (image) | flax_ddim_0.11.0.dev0 (image) |
| DPMSolverMultistep | torch_dpmsolver_multistep_0.10.2 (image) | flax_dpmsolver_multistep_0.10.2 (image) | flax_dpmsolver_multistep_0.11.0.dev0 (image) |
| PNDM | torch_pndm_0.10.2 (image) | flax_pndm_0.10.2 (image) | flax_pndm_0.11.0.dev0 (image) |

HuggingFaceDocBuilderDev commented Dec 11, 2022

The documentation is not available anymore as the PR was closed or merged.

@skirsten skirsten force-pushed the flax/stateless-schedulers-and-improvements branch 2 times, most recently from 3a64fbf to 0108066 on December 12, 2022 at 00:37
@pcuenca pcuenca (Member) commented Dec 13, 2022

Hi @skirsten, this looks amazing! I see you are tweaking stuff, let me know when you want a review :)

@skirsten skirsten force-pushed the flax/stateless-schedulers-and-improvements branch from 3f48e8e to c6357ed on December 13, 2022 at 21:13
@skirsten skirsten (Contributor, Author) commented

Hi @pcuenca, it should be ready for review now 😅

@pcuenca pcuenca (Member) commented Dec 13, 2022

Awesome, will do this week!

@@ -0,0 +1,106 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
Contributor commented:

Can we maybe add this to scheduling_utils_flax.py instead? :-) We usually don't have _common_ files in src/diffusers

Contributor (Author) commented:

Done. I moved it to scheduling_utils_flax.py

@patrickvonplaten patrickvonplaten (Contributor) left a comment

This looks super nice! Thanks a lot for working on this @skirsten :-)

@pcuenca it'd be amazing if you could give this a try on a TPU

@skirsten skirsten force-pushed the flax/stateless-schedulers-and-improvements branch from c6357ed to 9315088 on December 19, 2022 at 12:57
@pcuenca pcuenca (Member) left a comment

This is great! Much clearer code (and more efficient, according to the previous comments).

I tested PNDM, DDPM, DPM Solver and LMS Discrete on TPU v3-8 and these were my results:

  • DPM Solver produces results identical to the previous version.
  • There are some minor visual differences in PNDM. I haven't found the reason: I tried forcing the dtype to float32 when computing the betas, but it didn't make a difference.
  • LMS does not work in either version.
  • DDPM produces noise in the new version. It crashed in the previous one.

I think we should merge this as it's so much better. My approach would be:

  • Deal with DDPM and LMS in a follow-up PR, as they didn't work in the previous implementation anyway.
  • If we can easily find a reason for the minor discrepancies in PNDM, let's try to fix it. I already spent a couple of hours and couldn't find it, so I wouldn't spend much more time on it.

What do you think?

Comment on lines +330 to +334
```python
model_output = jax.lax.select(
    (state.counter % 4) != 3,
    model_output,  # remainder 0, 1, 2
    state.cur_model_output + 1 / 6 * model_output,  # remainder 3
)
```
Member commented:

These changes are all much cleaner this way. Thanks a lot!

Comment on lines +337 to +343
```python
cur_model_output=jax.lax.select_n(
    state.counter % 4,
    state.cur_model_output + 1 / 6 * model_output,  # remainder 0
    state.cur_model_output + 1 / 3 * model_output,  # remainder 1
    state.cur_model_output + 1 / 3 * model_output,  # remainder 2
    jnp.zeros_like(state.cur_model_output),  # remainder 3
),
```
Member commented:

Nice!

@pcuenca pcuenca (Member) commented Dec 19, 2022

Maybe @patil-suraj wants to take a quick look too.

@skirsten skirsten force-pushed the flax/stateless-schedulers-and-improvements branch from 9315088 to 752fb76 on December 19, 2022 at 19:37
@patrickvonplaten patrickvonplaten (Contributor) commented

Cool, let's merge, as this is a clear improvement over what we had previously. More than happy to fix the schedulers one by one in the future.

@patrickvonplaten patrickvonplaten merged commit f106ab4 into huggingface:main Dec 20, 2022
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* [Flax] Stateless schedulers, fixes and refactors

* Remove scheduling_common_flax and some renames

* Update src/diffusers/schedulers/scheduling_pndm_flax.py

Co-authored-by: Pedro Cuenca <[email protected]>

Co-authored-by: Pedro Cuenca <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* [Flax] Stateless schedulers, fixes and refactors

* Remove scheduling_common_flax and some renames

* Update src/diffusers/schedulers/scheduling_pndm_flax.py

Co-authored-by: Pedro Cuenca <[email protected]>

Co-authored-by: Pedro Cuenca <[email protected]>