MPS schedulers: don't use float64 #1169
Conversation
The documentation is not available anymore as the PR was closed or merged.
```python
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps
```
This could be simplified if we returned the timesteps tensor from the scheduler.
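A minimal sketch of that suggestion, assuming `set_timesteps()` also returned the tensor it stores (`SchedulerSketch` and its linear schedule are made-up stand-ins, not the actual diffusers scheduler):

```python
import torch

class SchedulerSketch:
    def set_timesteps(self, num_inference_steps, device=None):
        # Placeholder schedule; real schedulers compute sigmas/timesteps here.
        self.timesteps = torch.linspace(999.0, 0.0, num_inference_steps, device=device)
        # Returning the tensor lets callers chain the call.
        return self.timesteps

# The two pipeline lines above would then collapse into one:
timesteps_tensor = SchedulerSketch().set_timesteps(50, device="cpu")
```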
```python
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
if str(device).startswith("mps"):
```
Can it be `mps:1`? Is that why there's a `startswith`? 🤯
Exactly, I found mps:0 when testing on my computer :)
In other places we use `device.type == "mps"`, but the method signature allows strings too.
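For reference, a small sketch of how the check could be normalized to cover both cases; `uses_mps` is a hypothetical helper, not part of the PR:

```python
import torch

def uses_mps(device) -> bool:
    # torch.device() accepts either a str or a torch.device, and .type
    # drops the index, so "mps:0" and "mps:1" both normalize to "mps".
    return torch.device(device).type == "mps"

assert uses_mps("mps")
assert uses_mps("mps:0")
assert uses_mps(torch.device("mps", 1))
assert not uses_mps("cuda:0")
```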
patrickvonplaten left a comment:
Thanks a lot for fixing!
Just curious, did you try moving everything to float32? The weights of the models are always in float32, so I'd assume everything is downcast to float32 during the forward pass anyway.
Let's maybe leave it as is for now; once we have finished migrating our whole test suite to high-precision, few-step pipelines, we could revisit this PR and possibly just always use float32?
I thought about it but didn't want to change anything more than necessary. I agree we should verify it :)
Agreed!
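For context on why the cast is needed at all: the MPS backend has no float64 support, while NumPy defaults to float64, so tensors derived from NumPy arrays have to be downcast before they land on the device. A minimal sketch of the pattern (illustrative, not the exact PR diff):

```python
import numpy as np
import torch

timesteps = np.linspace(0, 999, 50)  # NumPy defaults to float64

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
if device.type == "mps":
    # MPS cannot hold float64 tensors, so cast while moving to the device.
    t = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
    t = torch.from_numpy(timesteps).to(device=device)
```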
patil-suraj left a comment:
Looks good, thanks a lot for fixing this!
Commits:
* Schedulers: don't use float64 on mps
* Test set_timesteps() on device (float schedulers).
* SD pipeline: use device in set_timesteps.
* SD in-painting pipeline: use device in set_timesteps.
* Tests: fix mps crashes.
* Skip test_load_pipeline_from_git on mps. Not compatible with float16.
* Use device.type instead of str in Euler schedulers.
No description provided.