MPS schedulers: don't use float64 #1169
Changes from all commits: 4f7ad41, 2ef622d, eb52838, e246b67, c44f1b9, 995b865, e808fd1, 377d711, a218dc9
```diff
@@ -151,7 +151,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)

     def step(
         self,
```

Review comments on this hunk:

> **Member:** Can it be

> **Member (Author):** Exactly, I found

> **Member (Author):** In other places we use
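The guard above exists because NumPy produces `float64` arrays by default, and Apple's MPS backend cannot hold `float64` tensors. A minimal sketch of the same pattern, using a hypothetical `prepare_timesteps` helper (not part of the PR) and plain NumPy so it runs anywhere:

```python
import numpy as np

def prepare_timesteps(num_inference_steps: int, device: str) -> np.ndarray:
    # np.linspace returns float64 by default; MPS tensors cannot hold
    # float64, so downcast before the array would be moved to an "mps"
    # device. (Hypothetical helper mirroring the PR's guard.)
    timesteps = np.linspace(999.0, 0.0, num_inference_steps)
    if str(device).startswith("mps"):
        timesteps = timesteps.astype(np.float32)
    return timesteps

print(prepare_timesteps(4, "cpu").dtype)    # float64, unchanged
print(prepare_timesteps(4, "mps:0").dtype)  # float32, downcast for MPS
```

Matching on `str(device).startswith("mps")` covers both the bare `"mps"` string and indexed forms like `"mps:0"`, as well as `torch.device` objects passed through `str()`.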
```diff
@@ -217,8 +221,8 @@ def step(
         prev_sample = sample + derivative * dt

-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        if str(device) == "mps":
+        device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
+        if device.type == "mps":
             # randn does not work reproducibly on mps
             noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
                 device
```
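The workaround in this hunk samples the noise on the CPU with an explicit generator and only then moves it to the target device, because `torch.randn` is not reproducible on the MPS backend. A self-contained sketch of the pattern (the `sample_noise` helper is illustrative, not from the PR):

```python
import torch

def sample_noise(shape, device, seed: int) -> torch.Tensor:
    # Draw on the CPU with a seeded generator for determinism,
    # then transfer the result to the requested device.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    noise = torch.randn(shape, dtype=torch.float32, device="cpu", generator=generator)
    return noise.to(device)

a = sample_noise((2, 3), "cpu", seed=0)
b = sample_noise((2, 3), "cpu", seed=0)
print(torch.equal(a, b))  # True: same seed yields identical noise
```

Note also the cleanup in the diff: normalizing the fallback to `torch.device("cpu")` means `device` is always a `torch.device`, so the check can use the typed `device.type == "mps"` instead of string comparison, which would miss indexed devices such as `mps:0`.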
> This could be simplified if we returned the timesteps tensor from the scheduler.
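A rough sketch of that suggestion, under the assumption that `set_timesteps` would return the prepared timesteps so callers get the device/dtype handling in one place rather than reading scheduler state afterwards (the `TinyScheduler` class is purely illustrative, using NumPy in place of tensors):

```python
import numpy as np

class TinyScheduler:
    # Hypothetical sketch of the reviewer's suggestion: set_timesteps
    # both stores and returns the timesteps, so every caller receives
    # the already-downcast array without re-checking the device.
    def set_timesteps(self, num_inference_steps: int, device: str = "cpu"):
        timesteps = np.linspace(999.0, 0.0, num_inference_steps)
        if str(device).startswith("mps"):  # mps does not support float64
            timesteps = timesteps.astype(np.float32)
        self.timesteps = timesteps
        return self.timesteps

ts = TinyScheduler().set_timesteps(5, device="mps:0")
print(ts.dtype)  # float32
```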