Conversation

@kashif (Contributor) commented Nov 25, 2022

Fixes for MPS device.

I believe this is a better fix than #1410.

Testing on the main branch:

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler


model_id = "stabilityai/stable-diffusion-2"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler).to("mps")

prompt = "beautiful gaze"
image = pipe(prompt, height=768, width=768).images[0]    
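
The change itself is adding a .contiguous() call right after a permute(); the diff isn't quoted in this thread, so the snippet below is only an illustrative sketch of the pattern rather than the actual patched code:

import torch

# Illustrative sketch only -- the real location of the change is not shown here.
# permute() returns a non-contiguous view; materializing it with .contiguous()
# right after the permute is the pattern the PR title describes.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(1, 4, 96, 96, device=device)
y = x.permute(0, 2, 3, 1).contiguous()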

@HuggingFaceDocBuilderDev commented Nov 25, 2022

The documentation is not available anymore as the PR was closed or merged.

@kashif (Contributor, Author) commented Nov 25, 2022

@pcuenca I can also fix the warning:

diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py:128: UserWarning: The operator 'aten::nonzero' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
  step_index = (self.timesteps == timestep).nonzero().item()

By replacing it with:

step_index = self.timesteps.tolist().index(timestep)

what do you think?
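
For concreteness, a small self-contained comparison of the two lookups (the timesteps tensor below is made up for illustration, not taken from the scheduler):

import torch

# Made-up timesteps, just to show both lookups return the same index.
timesteps = torch.linspace(999.0, 0.0, steps=50)
timestep = timesteps[10].item()

# current: uses aten::nonzero, which falls back to the CPU on MPS and warns
step_index_old = (timesteps == timestep).nonzero().item()

# proposed: plain Python lookup, no nonzero involved
step_index_new = timesteps.tolist().index(timestep)

assert step_index_old == step_index_new == 10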

@pcuenca (Member) commented Nov 25, 2022

> @pcuenca I can also fix the warning:
>
> diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py:128: UserWarning: The operator 'aten::nonzero' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
>   step_index = (self.timesteps == timestep).nonzero().item()
>
> By replacing it with:
>
> step_index = self.timesteps.tolist().index(timestep)
>
> what do you think?

I think it should be ok, let's do it and see what other people think.

It's nice to remove the warning, but that operation will still move the tensor to the CPU (just as the fallback implementation of nonzero would do).
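
A quick illustration of that point (assumes an MPS-enabled PyTorch build; it is a no-op elsewhere):

import torch

# .tolist() synchronizes and copies the whole tensor to host memory, so the
# device-to-CPU transfer is still there -- it just no longer triggers a warning.
if torch.backends.mps.is_available():
    timesteps = torch.linspace(999.0, 0.0, steps=50, device="mps")
    host_values = timesteps.tolist()  # plain Python floats on the CPU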

@pcuenca (Member) left a comment

Thanks a lot for working on this!

@kashif (Contributor, Author) commented Nov 25, 2022

@pcuenca at the moment this scheduler is also working for me with fp16 on mps:

import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

model_id = "stabilityai/stable-diffusion-2"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    revision="fp16",
    scheduler=scheduler,
).to("mps")

@patrickvonplaten (Contributor) commented

@pcuenca I leave it up to you to merge :-)

@kashif changed the title from "[MPS] call contiguous after permute" to "[MPS] call contiguous after permute and fix nonzero warning" on Nov 25, 2022
@kashif (Contributor, Author) commented Nov 25, 2022

@pcuenca let me double check this now on my cuda box

@kashif (Contributor, Author) commented Nov 25, 2022

@pcuenca also works fine on "cuda"

@kashif changed the title from "[MPS] call contiguous after permute and fix nonzero warning" to "[MPS] call contiguous after permute" on Nov 25, 2022
@pcuenca merged commit babfb8a into huggingface:main on Nov 25, 2022
@kashif deleted the mps-fix branch on November 25, 2022 at 13:00
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* call contiguous after permute

Fixes for MPS device

* Fix MPS UserWarning

* make style

* Revert "Fix MPS UserWarning"

This reverts commit b46c328.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* call contiguous after permute

Fixes for MPS device

* Fix MPS UserWarning

* make style

* Revert "Fix MPS UserWarning"

This reverts commit b46c328.