@@ -59,10 +59,12 @@ def test_from_pretrained_save_pretrained(self):
5959 schedular = DDPMScheduler (num_train_timesteps = 10 )
6060
6161 ddpm = DDPMPipeline (model , schedular )
62+ ddpm .to (torch_device )
6263
6364 with tempfile .TemporaryDirectory () as tmpdirname :
6465 ddpm .save_pretrained (tmpdirname )
6566 new_ddpm = DDPMPipeline .from_pretrained (tmpdirname )
67+ new_ddpm .to (torch_device )
6668
6769 generator = torch .manual_seed (0 )
6870
@@ -76,11 +78,12 @@ def test_from_pretrained_save_pretrained(self):
7678 def test_from_pretrained_hub (self ):
7779 model_path = "google/ddpm-cifar10-32"
7880
79- ddpm = DDPMPipeline .from_pretrained (model_path )
80- ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path )
81+ scheduler = DDPMScheduler (num_train_timesteps = 10 )
8182
82- ddpm .scheduler .num_timesteps = 10
83- ddpm_from_hub .scheduler .num_timesteps = 10
83+ ddpm = DDPMPipeline .from_pretrained (model_path , scheduler = scheduler )
84+ ddpm .to (torch_device )
85+ ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path , scheduler = scheduler )
86+ ddpm_from_hub .to (torch_device )
8487
8588 generator = torch .manual_seed (0 )
8689
@@ -94,14 +97,15 @@ def test_from_pretrained_hub(self):
9497 def test_from_pretrained_hub_pass_model (self ):
9598 model_path = "google/ddpm-cifar10-32"
9699
100+ scheduler = DDPMScheduler (num_train_timesteps = 10 )
101+
97102 # pass unet into DiffusionPipeline
98103 unet = UNet2DModel .from_pretrained (model_path )
99- ddpm_from_hub_custom_model = DiffusionPipeline .from_pretrained (model_path , unet = unet )
100-
101- ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path )
104+ ddpm_from_hub_custom_model = DiffusionPipeline .from_pretrained (model_path , unet = unet , scheduler = scheduler )
105+ ddpm_from_hub_custom_model .to (torch_device )
102106
103- ddpm_from_hub_custom_model . scheduler . num_timesteps = 10
104- ddpm_from_hub .scheduler . num_timesteps = 10
107+ ddpm_from_hub = DiffusionPipeline . from_pretrained ( model_path , scheduler = scheduler )
108+ ddpm_from_hub .to ( torch_device )
105109
106110 generator = torch .manual_seed (0 )
107111
@@ -116,6 +120,7 @@ def test_output_format(self):
116120 model_path = "google/ddpm-cifar10-32"
117121
118122 pipe = DDIMPipeline .from_pretrained (model_path )
123+ pipe .to (torch_device )
119124
120125 generator = torch .manual_seed (0 )
121126 images = pipe (generator = generator , output_type = "numpy" )["sample" ]
@@ -141,6 +146,7 @@ def test_ddpm_cifar10(self):
141146 scheduler = scheduler .set_format ("pt" )
142147
143148 ddpm = DDPMPipeline (unet = unet , scheduler = scheduler )
149+ ddpm .to (torch_device )
144150
145151 generator = torch .manual_seed (0 )
146152 image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -159,6 +165,7 @@ def test_ddim_lsun(self):
159165 scheduler = DDIMScheduler .from_config (model_id )
160166
161167 ddpm = DDIMPipeline (unet = unet , scheduler = scheduler )
168+ ddpm .to (torch_device )
162169
163170 generator = torch .manual_seed (0 )
164171 image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -177,6 +184,7 @@ def test_ddim_cifar10(self):
177184 scheduler = DDIMScheduler (tensor_format = "pt" )
178185
179186 ddim = DDIMPipeline (unet = unet , scheduler = scheduler )
187+ ddim .to (torch_device )
180188
181189 generator = torch .manual_seed (0 )
182190 image = ddim (generator = generator , eta = 0.0 , output_type = "numpy" )["sample" ]
@@ -195,6 +203,7 @@ def test_pndm_cifar10(self):
195203 scheduler = PNDMScheduler (tensor_format = "pt" )
196204
197205 pndm = PNDMPipeline (unet = unet , scheduler = scheduler )
206+ pndm .to (torch_device )
198207 generator = torch .manual_seed (0 )
199208 image = pndm (generator = generator , output_type = "numpy" )["sample" ]
200209
@@ -207,6 +216,7 @@ def test_pndm_cifar10(self):
207216 @slow
208217 def test_ldm_text2img (self ):
209218 ldm = LDMTextToImagePipeline .from_pretrained ("CompVis/ldm-text2im-large-256" )
219+ ldm .to (torch_device )
210220
211221 prompt = "A painting of a squirrel eating a burger"
212222 generator = torch .manual_seed (0 )
@@ -223,6 +233,7 @@ def test_ldm_text2img(self):
223233 @slow
224234 def test_ldm_text2img_fast (self ):
225235 ldm = LDMTextToImagePipeline .from_pretrained ("CompVis/ldm-text2im-large-256" )
236+ ldm .to (torch_device )
226237
227238 prompt = "A painting of a squirrel eating a burger"
228239 generator = torch .manual_seed (0 )
@@ -290,6 +301,7 @@ def test_score_sde_ve_pipeline(self):
290301 scheduler = ScoreSdeVeScheduler .from_config (model_id )
291302
292303 sde_ve = ScoreSdeVePipeline (unet = model , scheduler = scheduler )
304+ sde_ve .to (torch_device )
293305
294306 torch .manual_seed (0 )
295307 image = sde_ve (num_inference_steps = 300 , output_type = "numpy" )["sample" ]
@@ -304,6 +316,7 @@ def test_score_sde_ve_pipeline(self):
304316 @slow
305317 def test_ldm_uncond (self ):
306318 ldm = LDMPipeline .from_pretrained ("CompVis/ldm-celebahq-256" )
319+ ldm .to (torch_device )
307320
308321 generator = torch .manual_seed (0 )
309322 image = ldm (generator = generator , num_inference_steps = 5 , output_type = "numpy" )["sample" ]
@@ -323,7 +336,9 @@ def test_ddpm_ddim_equality(self):
323336 ddim_scheduler = DDIMScheduler (tensor_format = "pt" )
324337
325338 ddpm = DDPMPipeline (unet = unet , scheduler = ddpm_scheduler )
339+ ddpm .to (torch_device )
326340 ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
341+ ddim .to (torch_device )
327342
328343 generator = torch .manual_seed (0 )
329344 ddpm_image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -343,7 +358,10 @@ def test_ddpm_ddim_equality_batched(self):
343358 ddim_scheduler = DDIMScheduler (tensor_format = "pt" )
344359
345360 ddpm = DDPMPipeline (unet = unet , scheduler = ddpm_scheduler )
361+ ddpm .to (torch_device )
362+
346363 ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
364+ ddim .to (torch_device )
347365
348366 generator = torch .manual_seed (0 )
349367 ddpm_images = ddpm (batch_size = 4 , generator = generator , output_type = "numpy" )["sample" ]
@@ -363,6 +381,7 @@ def test_karras_ve_pipeline(self):
363381 scheduler = KarrasVeScheduler (tensor_format = "pt" )
364382
365383 pipe = KarrasVePipeline (unet = model , scheduler = scheduler )
384+ pipe .to (torch_device )
366385
367386 generator = torch .manual_seed (0 )
368387 image = pipe (num_inference_steps = 20 , generator = generator , output_type = "numpy" )["sample" ]
0 commit comments