@@ -360,12 +360,9 @@ def __call__(
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)

-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+        # set timesteps and move to the correct device
+        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        timesteps_tensor = self.scheduler.timesteps
Comment on lines +364 to +365 (Member Author):
> This could be simplified if we returned the timesteps tensor from the scheduler.

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
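A hypothetical sketch of the simplification suggested in the comment above; this return value does not exist in the PR, it would just collapse the two pipeline lines into one:

    # Hypothetical API, not part of this PR: set_timesteps returns the timesteps tensor it stores.
    timesteps_tensor = self.scheduler.set_timesteps(num_inference_steps, device=self.device)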
@@ -416,12 +416,9 @@ def __call__(
                 " `pipeline.unet` or your `mask_image` or `image` input."
             )

-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+        # set timesteps and move to the correct device
+        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        timesteps_tensor = self.scheduler.timesteps

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
10 changes: 7 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -151,7 +151,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
@anton-l (Member), Nov 7, 2022:
> Can it be mps:1, is that why there's a startswith? 🤯

Member Author:
> Exactly, I found mps:0 when testing on my computer :)

Member Author:
> In other places we use device.type == "mps", but the method signature allows strings too.

+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)

     def step(
         self,
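A minimal sketch of the device check discussed in the thread above; the helper is written here for illustration only and is not part of the PR. Assuming torch.device accepts both strings and device objects, normalizing through it handles plain "mps", indexed forms like "mps:0", and torch.device instances alike:

    import torch

    def _is_mps(device):
        # torch.device("mps:0").type == "mps", so .type covers indexed
        # devices as well as plain strings and torch.device objects.
        return torch.device(device).type == "mps"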
@@ -217,8 +221,8 @@ def step(

         prev_sample = sample + derivative * dt

-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        if str(device) == "mps":
+        device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
+        if device.type == "mps":
             # randn does not work reproducibly on mps
             noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
                 device
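For reference, a standalone sketch of the CPU-fallback pattern visible in the context lines (shape and seed are illustrative): sample on CPU with a seeded generator, then move the result, so runs stay reproducible on mps.

    import torch

    generator = torch.Generator(device="cpu").manual_seed(0)  # illustrative seed
    noise = torch.randn((1, 4, 64, 64), generator=generator)  # sampled on CPU for reproducibility
    if torch.backends.mps.is_available():
        noise = noise.to("mps")  # moved to the target device afterwards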
10 changes: 7 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -152,7 +152,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)

     def step(
         self,
@@ -214,8 +218,8 @@ def step(

         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        if str(device) == "mps":
+        device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
+        if device.type == "mps":
             # randn does not work reproducibly on mps
             noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
                 device
7 changes: 6 additions & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -173,8 +173,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)

         self.derivatives = []

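A quick illustration of the limitation behind the "mps does not support float64" comment, assuming an mps-enabled build (values illustrative): np.linspace produces float64 arrays, and mps tensors cannot hold float64, hence the explicit float32 cast in all three schedulers.

    import numpy as np
    import torch

    timesteps = np.linspace(0, 999, 50)[::-1].copy()  # float64 by default
    t = torch.from_numpy(timesteps)                   # a torch.float64 tensor
    if torch.backends.mps.is_available():
        t = t.to("mps", dtype=torch.float32)  # a plain .to("mps") would raise a TypeError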
3 changes: 3 additions & 0 deletions tests/models/test_models_unet_2d.py
@@ -456,6 +456,7 @@ def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
         latents = self.get_latents(seed)
@@ -507,6 +508,7 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
         latents = self.get_latents(seed)
@@ -558,6 +560,7 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
         latents = self.get_latents(seed, shape=(4, 9, 64, 64))
4 changes: 2 additions & 2 deletions tests/test_pipelines.py
@@ -41,7 +41,7 @@
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
 from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -124,7 +124,7 @@ def test_local_custom_pipeline(self):
         assert output_str == "This is a local test"

     @slow
-    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    @require_torch_gpu
     def test_load_pipeline_from_git(self):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
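For context, a minimal sketch of what a decorator like require_torch_gpu typically does; this is an assumption about its shape, not the actual diffusers implementation:

    import unittest

    import torch

    def require_torch_gpu(test_case):
        # Skip the decorated test unless a CUDA device is available.
        return unittest.skipUnless(torch.cuda.is_available(), "test requires a GPU")(test_case)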

87 changes: 86 additions & 1 deletion tests/test_scheduler.py
@@ -83,8 +83,8 @@ def check_over_configs(self, time_step=0, **config):

         num_inference_steps = kwargs.pop("num_inference_steps", None)

-        # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
         for scheduler_class in self.scheduler_classes:
+            # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                 time_step = float(time_step)
@@ -1192,6 +1192,31 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 1006.388) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3

+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 1006.388) < 1e-2
+        assert abs(result_mean.item() - 1.31) < 1e-3
+

 class EulerDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (EulerDiscreteScheduler,)
@@ -1248,6 +1273,34 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 10.0807) < 1e-2
         assert abs(result_mean.item() - 0.0131) < 1e-3

+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        generator = torch.Generator().manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        print(result_sum, result_mean)
+
+        assert abs(result_sum.item() - 10.0807) < 1e-2
+        assert abs(result_mean.item() - 0.0131) < 1e-3
+

 class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (EulerAncestralDiscreteScheduler,)
@@ -1303,6 +1356,38 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 152.3192) < 1e-2
         assert abs(result_mean.item() - 0.1983) < 1e-3

+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        generator = torch.Generator().manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        print(result_sum, result_mean)
+        if not str(torch_device).startswith("mps"):
+            # The following sum varies between 148 and 156 on mps. Why?
+            assert abs(result_sum.item() - 152.3192) < 1e-2
+            assert abs(result_mean.item() - 0.1983) < 1e-3
+        else:
+            # Larger tolerance on mps
+            assert abs(result_mean.item() - 0.1983) < 1e-2
+

 class IPNDMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (IPNDMScheduler,)
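Taken together, the changes let callers place scheduler state on the target device up front instead of moving timesteps afterwards. A minimal usage sketch, with the scheduler choice, config values, and step count picked for illustration:

    import torch
    from diffusers import LMSDiscreteScheduler

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    scheduler.set_timesteps(50, device=device)  # timesteps and sigmas now live on `device`
    print(scheduler.timesteps.device)           # no extra .to(device) needed in the denoising loop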