-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Deprecate predict_epsilon
#1393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
| prediction_type: str = "epsilon", | ||
| **kwargs, | ||
| ): | ||
| message = ( | ||
| "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" | ||
| " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`." | ||
| ) | ||
| predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) | ||
| if predict_epsilon is not None: | ||
| prediction_type = "epsilon" if predict_epsilon else "sample" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't work with from_pretrained:
ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
)
assert ddpm.prediction_type == "sample" # AssertionErrorThe reason is that from_config extracts predict_epsilon as one unused_kwargs: https://github.com/huggingface/diffusers/blob/main/src/diffusers/configuration_utils.py#L195
How could we handle this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Let's correct this in a follow up PR ok?
schedulers using from_pretrained.
| if predict_epsilon is not None: | ||
| new_config = dict(self.scheduler.config) | ||
| new_config["predict_epsilon"] = predict_epsilon | ||
| new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
| model_output = self.unet(image, t).sample | ||
|
|
||
| # 2. compute previous image: x_t -> x_t-1 | ||
| image = self.scheduler.step( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch we don't need to pass it anymore!
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
Outdated
Show resolved
Hide resolved
| if predict_epsilon is not None: | ||
| prediction_type = "epsilon" if predict_epsilon else "sample" | ||
|
|
||
| self.prediction_type = prediction_type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we need to modify the frozen dict as shown here:
| new_config = dict(scheduler.config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrickvonplaten I don't think there's an internal dict yet in the __init__ function. I wrote this test to verify that it works as expected: 01df570#diff-bf88051e4d8be9ab9d5a5f24a0daac59c898ddb1fca8505728f49b3809cd3666R611-R620
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yeah you're right good point! Fixing this on main as well
Co-authored-by: Patrick von Platen <[email protected]>
| def test_inference_deprecated_predict_epsilon(self): | ||
| deprecate("remove this test", "0.10.0", "remove") | ||
| unet = self.dummy_uncond_unet | ||
| scheduler = DDPMScheduler(predict_epsilon=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is now failing because DDPMScheduler is initialized with the deprecated kwarg and then it's replaced with the default prediction_type in register_to_config. @patrickvonplaten should we address this in a followup PR?
|
Cool fixing the failing test now on "main" |
| if "dtype" in unused_kwargs: | ||
| init_dict["dtype"] = unused_kwargs.pop("dtype") | ||
|
|
||
| if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's see if we can find a more general solution here
* Adapt ddpm, ddpmsolver to prediction_type. * Deprecate predict_epsilon in __init__. * Bring FlaxDDIMScheduler up to date with DDIMScheduler. * Set prediction_type as an ivar for consistency. * Convert pipeline_ddpm * Adapt tests. * Adapt unconditional training script. * Adapt BitDiffusion example. * Add missing kwargs in dpmsolver_multistep * Ugly workaround to accept deprecated predict_epsilon when loading schedulers using from_pretrained. * make style * Remove import no longer in use. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * Use config.prediction_type everywhere * Add a couple of Flax prediction type tests. * make style * fix register deprecated arg Co-authored-by: Patrick von Platen <[email protected]>
* Adapt ddpm, ddpmsolver to prediction_type. * Deprecate predict_epsilon in __init__. * Bring FlaxDDIMScheduler up to date with DDIMScheduler. * Set prediction_type as an ivar for consistency. * Convert pipeline_ddpm * Adapt tests. * Adapt unconditional training script. * Adapt BitDiffusion example. * Add missing kwargs in dpmsolver_multistep * Ugly workaround to accept deprecated predict_epsilon when loading schedulers using from_pretrained. * make style * Remove import no longer in use. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * Use config.prediction_type everywhere * Add a couple of Flax prediction type tests. * make style * fix register deprecated arg Co-authored-by: Patrick von Platen <[email protected]>
No description provided.