Skip to content

Conversation

@pcuenca
Copy link
Member

@pcuenca pcuenca commented Nov 24, 2022

No description provided.

@pcuenca pcuenca marked this pull request as draft November 24, 2022 12:35
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Nov 24, 2022

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 119 to 128
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
prediction_type = "epsilon" if predict_epsilon else "sample"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work with from_pretrained:

ddpm = DDPMScheduler.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    subfolder="scheduler",
    predict_epsilon=False,
)
assert ddpm.prediction_type == "sample"        # AssertionError

The reason is that from_config extracts predict_epsilon as one unused_kwargs: https://github.com/huggingface/diffusers/blob/main/src/diffusers/configuration_utils.py#L195

How could we handle this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Let's correct this in a follow up PR ok?

@pcuenca pcuenca marked this pull request as ready for review November 24, 2022 20:48
if predict_epsilon is not None:
new_config = dict(self.scheduler.config)
new_config["predict_epsilon"] = predict_epsilon
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

model_output = self.unet(image, t).sample

# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch we don't need to pass it anymore!

if predict_epsilon is not None:
prediction_type = "epsilon" if predict_epsilon else "sample"

self.prediction_type = prediction_type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we need to modify the frozen dict as shown here:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten I don't think there's an internal dict yet in the __init__ function. I wrote this test to verify that it works as expected: 01df570#diff-bf88051e4d8be9ab9d5a5f24a0daac59c898ddb1fca8505728f49b3809cd3666R611-R620

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah you're right good point! Fixing this on main as well

def test_inference_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is now failing because DDPMScheduler is initialized with the deprecated kwarg and then it's replaced with the default prediction_type in register_to_config. @patrickvonplaten should we address this in a followup PR?

@patrickvonplaten
Copy link
Contributor

Cool fixing the failing test now on "main"

@patrickvonplaten patrickvonplaten merged commit d52388f into main Nov 25, 2022
@patrickvonplaten patrickvonplaten deleted the deprecate-predict-epsilon branch November 25, 2022 13:02
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")

if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's see if we can find a more general solution here

sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* Adapt ddpm, ddpmsolver to prediction_type.

* Deprecate predict_epsilon in __init__.

* Bring FlaxDDIMScheduler up to date with DDIMScheduler.

* Set prediction_type as an ivar for consistency.

* Convert pipeline_ddpm

* Adapt tests.

* Adapt unconditional training script.

* Adapt BitDiffusion example.

* Add missing kwargs in dpmsolver_multistep

* Ugly workaround to accept deprecated predict_epsilon when loading
schedulers using from_pretrained.

* make style

* Remove import no longer in use.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Use config.prediction_type everywhere

* Add a couple of Flax prediction type tests.

* make style

* fix register deprecated arg

Co-authored-by: Patrick von Platen <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Adapt ddpm, ddpmsolver to prediction_type.

* Deprecate predict_epsilon in __init__.

* Bring FlaxDDIMScheduler up to date with DDIMScheduler.

* Set prediction_type as an ivar for consistency.

* Convert pipeline_ddpm

* Adapt tests.

* Adapt unconditional training script.

* Adapt BitDiffusion example.

* Add missing kwargs in dpmsolver_multistep

* Ugly workaround to accept deprecated predict_epsilon when loading
schedulers using from_pretrained.

* make style

* Remove import no longer in use.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Use config.prediction_type everywhere

* Add a couple of Flax prediction type tests.

* make style

* fix register deprecated arg

Co-authored-by: Patrick von Platen <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants