@@ -165,17 +165,19 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)
165165 return model
166166
167167     def get_generator(self, seed=0):
168+         if torch_device == "mps":
169+             return torch.Generator().manual_seed(seed)
168170         return torch.Generator(device=torch_device).manual_seed(seed)
169171
170172     @parameterized.expand(
171173         [
172174             # fmt: off
173-             [33, [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]],
174-             [47, [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]],
175+             [33, [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
176+             [47, [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
175177             # fmt: on
176178         ]
177179     )
178-     def test_stable_diffusion(self, seed, expected_slice):
180+     def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
179181         model = self.get_sd_vae_model()
180182         image = self.get_sd_image(seed)
181183         generator = self.get_generator(seed)
@@ -186,7 +188,7 @@ def test_stable_diffusion(self, seed, expected_slice):
186188         assert sample.shape == image.shape
187189 
188190         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
189-         expected_output_slice = torch.tensor(expected_slice)
191+         expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
190192 
191193         assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
192194
@@ -217,12 +219,12 @@ def test_stable_diffusion_fp16(self, seed, expected_slice):
217219     @parameterized.expand(
218220         [
219221             # fmt: off
220-             [33, [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814]],
221-             [47, [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085]],
222+             [33, [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
223+             [47, [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
222224             # fmt: on
223225         ]
224226     )
225-     def test_stable_diffusion_mode(self, seed, expected_slice):
227+     def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
226228         model = self.get_sd_vae_model()
227229         image = self.get_sd_image(seed)
228230 
@@ -232,7 +234,7 @@ def test_stable_diffusion_mode(self, seed, expected_slice):
232234         assert sample.shape == image.shape
233235 
234236         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
235-         expected_output_slice = torch.tensor(expected_slice)
237+         expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
236238 
237239         assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
238240
@@ -267,6 +269,7 @@ def test_stable_diffusion_decode(self, seed, expected_slice):
267269             # fmt: on
268270         ]
269271     )
272+     @require_torch_gpu
270273     def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
271274         model = self.get_sd_vae_model(fp16=True)
272275         encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -303,4 +306,5 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
303306         output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
304307         expected_output_slice = torch.tensor(expected_slice)
305308 
306-         assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
309+         tolerance = 1e-3 if torch_device != "mps" else 1e-2
310+         assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
0 commit comments