@@ -215,6 +215,47 @@ def test_stable_diffusion_inpaint(self):
215215 assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
216216 assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
217217
218+ def test_stable_diffusion_inpaint_with_num_images_per_prompt (self ):
219+ device = "cpu"
220+ unet = self .dummy_cond_unet_inpaint
221+ scheduler = PNDMScheduler (skip_prk_steps = True )
222+ vae = self .dummy_vae
223+ bert = self .dummy_text_encoder
224+ tokenizer = CLIPTokenizer .from_pretrained ("hf-internal-testing/tiny-random-clip" )
225+
226+ image = self .dummy_image .cpu ().permute (0 , 2 , 3 , 1 )[0 ]
227+ init_image = Image .fromarray (np .uint8 (image )).convert ("RGB" ).resize ((128 , 128 ))
228+ mask_image = Image .fromarray (np .uint8 (image + 4 )).convert ("RGB" ).resize ((128 , 128 ))
229+
230+ # make sure here that pndm scheduler skips prk
231+ sd_pipe = StableDiffusionInpaintPipeline (
232+ unet = unet ,
233+ scheduler = scheduler ,
234+ vae = vae ,
235+ text_encoder = bert ,
236+ tokenizer = tokenizer ,
237+ safety_checker = None ,
238+ feature_extractor = None ,
239+ )
240+ sd_pipe = sd_pipe .to (device )
241+ sd_pipe .set_progress_bar_config (disable = None )
242+
243+ prompt = "A painting of a squirrel eating a burger"
244+ generator = torch .Generator (device = device ).manual_seed (0 )
245+ images = sd_pipe (
246+ [prompt ],
247+ generator = generator ,
248+ guidance_scale = 6.0 ,
249+ num_inference_steps = 2 ,
250+ output_type = "np" ,
251+ image = init_image ,
252+ mask_image = mask_image ,
253+ num_images_per_prompt = 2 ,
254+ ).images
255+
256+ # check if the output is a list of 2 images
257+ assert len (images ) == 2
258+
218259 @unittest .skipIf (torch_device != "cuda" , "This test requires a GPU" )
219260 def test_stable_diffusion_inpaint_fp16 (self ):
220261 """Test that stable diffusion inpaint_legacy works with fp16"""
0 commit comments