@@ -181,11 +181,76 @@ def test_inference(self):
181181 max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
182182 self .assertLessEqual (max_diff , 1e-3 )
183183
184+ def test_inference_with_embeddings_and_multiple_images (self ):
185+ components = self .get_dummy_components ()
186+ pipe = self .pipeline_class (** components )
187+ pipe .to (torch_device )
188+ pipe .set_progress_bar_config (disable = None )
189+
190+ inputs = self .get_dummy_inputs (torch_device )
191+
192+ prompt = inputs ["prompt" ]
193+ generator = inputs ["generator" ]
194+ num_inference_steps = inputs ["num_inference_steps" ]
195+ output_type = inputs ["output_type" ]
196+
197+ prompt_embeds , negative_prompt_embeds = pipe .encode_prompt (prompt )
198+
199+ # inputs with prompt converted to embeddings
200+ inputs = {
201+ "prompt_embeds" : prompt_embeds ,
202+ "negative_prompt" : None ,
203+ "negative_prompt_embeds" : negative_prompt_embeds ,
204+ "generator" : generator ,
205+ "num_inference_steps" : num_inference_steps ,
206+ "output_type" : output_type ,
207+ "num_images_per_prompt" : 2 ,
208+ }
209+
210+ # set all optional components to None
211+ for optional_component in pipe ._optional_components :
212+ setattr (pipe , optional_component , None )
213+
214+ output = pipe (** inputs )[0 ]
215+
216+ with tempfile .TemporaryDirectory () as tmpdir :
217+ pipe .save_pretrained (tmpdir )
218+ pipe_loaded = self .pipeline_class .from_pretrained (tmpdir )
219+ pipe_loaded .to (torch_device )
220+ pipe_loaded .set_progress_bar_config (disable = None )
221+
222+ for optional_component in pipe ._optional_components :
223+ self .assertTrue (
224+ getattr (pipe_loaded , optional_component ) is None ,
225+ f"`{ optional_component } ` did not stay set to None after loading." ,
226+ )
227+
228+ inputs = self .get_dummy_inputs (torch_device )
229+
230+ generator = inputs ["generator" ]
231+ num_inference_steps = inputs ["num_inference_steps" ]
232+ output_type = inputs ["output_type" ]
233+
234+ # inputs with prompt converted to embeddings
235+ inputs = {
236+ "prompt_embeds" : prompt_embeds ,
237+ "negative_prompt" : None ,
238+ "negative_prompt_embeds" : negative_prompt_embeds ,
239+ "generator" : generator ,
240+ "num_inference_steps" : num_inference_steps ,
241+ "output_type" : output_type ,
242+ "num_images_per_prompt" : 2 ,
243+ }
244+
245+ output_loaded = pipe_loaded (** inputs )[0 ]
246+
247+ max_diff = np .abs (to_np (output ) - to_np (output_loaded )).max ()
248+ self .assertLess (max_diff , 1e-4 )
249+
184250 def test_inference_batch_single_identical (self ):
185251 self ._test_inference_batch_single_identical (expected_max_diff = 1e-3 )
186252
187253
188- # TODO: needs to be updated.
189254@slow
190255@require_torch_gpu
191256class PixArtAlphaPipelineIntegrationTests (unittest .TestCase ):
0 commit comments