-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Deprecate predict_epsilon
#1393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
287a550
Adapt ddpm, ddpmsolver to prediction_type.
pcuenca 74d7273
Deprecate predict_epsilon in __init__.
pcuenca 006f8d3
Bring FlaxDDIMScheduler up to date with DDIMScheduler.
pcuenca 0fa4468
Set prediction_type as an ivar for consistency.
pcuenca bfe4e0e
Convert pipeline_ddpm
pcuenca a6709c3
Adapt tests.
pcuenca 20462d6
Adapt unconditional training script.
pcuenca e50a2f3
Adapt BitDiffusion example.
pcuenca e3adf8c
Add missing kwargs in dpmsolver_multistep
pcuenca 1a19afb
Ugly workaround to accept deprecated predict_epsilon when loading
pcuenca bc374f2
Merge remote-tracking branch 'origin/main' into deprecate-predict-eps…
pcuenca fc2828b
Merge remote-tracking branch 'origin/main' into deprecate-predict-eps…
pcuenca 86e7fb0
make style
pcuenca 0955073
Remove import no longer in use.
pcuenca 0d3e123
Apply suggestions from code review
pcuenca 56e9d10
Use config.prediction_type everywhere
pcuenca 01df570
Add a couple of Flax prediction type tests.
pcuenca 4a3cbe2
make style
pcuenca 760f2fb
Merge branch 'main' into deprecate-predict-epsilon
patrickvonplaten e6b1b29
fix register deprecated arg
patrickvonplaten File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,14 +70,14 @@ def __call__( | |
| generated images. | ||
| """ | ||
| message = ( | ||
| "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" | ||
| " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`." | ||
| "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" | ||
| " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`." | ||
| ) | ||
| predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) | ||
|
|
||
| if predict_epsilon is not None: | ||
| new_config = dict(self.scheduler.config) | ||
| new_config["predict_epsilon"] = predict_epsilon | ||
| new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice! |
||
| self.scheduler._internal_dict = FrozenDict(new_config) | ||
|
|
||
| if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": | ||
|
|
@@ -114,9 +114,7 @@ def __call__( | |
| model_output = self.unet(image, t).sample | ||
|
|
||
| # 2. compute previous image: x_t -> x_t-1 | ||
| image = self.scheduler.step( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch we don't need to pass it anymore! |
||
| model_output, t, image, generator=generator, predict_epsilon=predict_epsilon | ||
| ).prev_sample | ||
| image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample | ||
|
|
||
| image = (image / 2 + 0.5).clamp(0, 1) | ||
| image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's see if we can find a more general solution here