@@ -196,11 +196,15 @@ def test_forward_with_norm_groups(self):
196196class ModelTesterMixin :
197197 main_input_name = None # overwrite in model specific tester class
198198 base_precision = 1e-3
199+ forward_requires_fresh_args = False
199200
200201 def test_from_save_pretrained (self , expected_max_diff = 5e-5 ):
201- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
202+ if self .forward_requires_fresh_args :
203+ model = self .model_class (** self .init_dict )
204+ else :
205+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
206+ model = self .model_class (** init_dict )
202207
203- model = self .model_class (** init_dict )
204208 if hasattr (model , "set_default_attn_processor" ):
205209 model .set_default_attn_processor ()
206210 model .to (torch_device )
@@ -214,11 +218,18 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5):
214218 new_model .to (torch_device )
215219
216220 with torch .no_grad ():
217- image = model (** inputs_dict )
221+ if self .forward_requires_fresh_args :
222+ image = model (** self .inputs_dict (0 ))
223+ else :
224+ image = model (** inputs_dict )
225+
218226 if isinstance (image , dict ):
219227 image = image .to_tuple ()[0 ]
220228
221- new_image = new_model (** inputs_dict )
229+ if self .forward_requires_fresh_args :
230+ new_image = new_model (** self .inputs_dict (0 ))
231+ else :
232+ new_image = new_model (** inputs_dict )
222233
223234 if isinstance (new_image , dict ):
224235 new_image = new_image .to_tuple ()[0 ]
@@ -275,8 +286,11 @@ def test_getattr_is_correct(self):
275286 )
276287 def test_set_xformers_attn_processor_for_determinism (self ):
277288 torch .use_deterministic_algorithms (False )
278- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
279- model = self .model_class (** init_dict )
289+ if self .forward_requires_fresh_args :
290+ model = self .model_class (** self .init_dict )
291+ else :
292+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
293+ model = self .model_class (** init_dict )
280294 model .to (torch_device )
281295
282296 if not hasattr (model , "set_attn_processor" ):
@@ -286,17 +300,26 @@ def test_set_xformers_attn_processor_for_determinism(self):
286300 model .set_default_attn_processor ()
287301 assert all (type (proc ) == AttnProcessor for proc in model .attn_processors .values ())
288302 with torch .no_grad ():
289- output = model (** inputs_dict )[0 ]
303+ if self .forward_requires_fresh_args :
304+ output = model (** self .inputs_dict (0 ))[0 ]
305+ else :
306+ output = model (** inputs_dict )[0 ]
290307
291308 model .enable_xformers_memory_efficient_attention ()
292309 assert all (type (proc ) == XFormersAttnProcessor for proc in model .attn_processors .values ())
293310 with torch .no_grad ():
294- output_2 = model (** inputs_dict )[0 ]
311+ if self .forward_requires_fresh_args :
312+ output_2 = model (** self .inputs_dict (0 ))[0 ]
313+ else :
314+ output_2 = model (** inputs_dict )[0 ]
295315
296316 model .set_attn_processor (XFormersAttnProcessor ())
297317 assert all (type (proc ) == XFormersAttnProcessor for proc in model .attn_processors .values ())
298318 with torch .no_grad ():
299- output_3 = model (** inputs_dict )[0 ]
319+ if self .forward_requires_fresh_args :
320+ output_3 = model (** self .inputs_dict (0 ))[0 ]
321+ else :
322+ output_3 = model (** inputs_dict )[0 ]
300323
301324 torch .use_deterministic_algorithms (True )
302325
@@ -307,8 +330,12 @@ def test_set_xformers_attn_processor_for_determinism(self):
307330 @require_torch_gpu
308331 def test_set_attn_processor_for_determinism (self ):
309332 torch .use_deterministic_algorithms (False )
310- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
311- model = self .model_class (** init_dict )
333+ if self .forward_requires_fresh_args :
334+ model = self .model_class (** self .init_dict )
335+ else :
336+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
337+ model = self .model_class (** init_dict )
338+
312339 model .to (torch_device )
313340
314341 if not hasattr (model , "set_attn_processor" ):
@@ -317,22 +344,34 @@ def test_set_attn_processor_for_determinism(self):
317344
318345 assert all (type (proc ) == AttnProcessor2_0 for proc in model .attn_processors .values ())
319346 with torch .no_grad ():
320- output_1 = model (** inputs_dict )[0 ]
347+ if self .forward_requires_fresh_args :
348+ output_1 = model (** self .inputs_dict (0 ))[0 ]
349+ else :
350+ output_1 = model (** inputs_dict )[0 ]
321351
322352 model .set_default_attn_processor ()
323353 assert all (type (proc ) == AttnProcessor for proc in model .attn_processors .values ())
324354 with torch .no_grad ():
325- output_2 = model (** inputs_dict )[0 ]
355+ if self .forward_requires_fresh_args :
356+ output_2 = model (** self .inputs_dict (0 ))[0 ]
357+ else :
358+ output_2 = model (** inputs_dict )[0 ]
326359
327360 model .set_attn_processor (AttnProcessor2_0 ())
328361 assert all (type (proc ) == AttnProcessor2_0 for proc in model .attn_processors .values ())
329362 with torch .no_grad ():
330- output_4 = model (** inputs_dict )[0 ]
363+ if self .forward_requires_fresh_args :
364+ output_4 = model (** self .inputs_dict (0 ))[0 ]
365+ else :
366+ output_4 = model (** inputs_dict )[0 ]
331367
332368 model .set_attn_processor (AttnProcessor ())
333369 assert all (type (proc ) == AttnProcessor for proc in model .attn_processors .values ())
334370 with torch .no_grad ():
335- output_5 = model (** inputs_dict )[0 ]
371+ if self .forward_requires_fresh_args :
372+ output_5 = model (** self .inputs_dict (0 ))[0 ]
373+ else :
374+ output_5 = model (** inputs_dict )[0 ]
336375
337376 torch .use_deterministic_algorithms (True )
338377
@@ -342,9 +381,12 @@ def test_set_attn_processor_for_determinism(self):
342381 assert torch .allclose (output_2 , output_5 , atol = self .base_precision )
343382
344383 def test_from_save_pretrained_variant (self , expected_max_diff = 5e-5 ):
345- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
384+ if self .forward_requires_fresh_args :
385+ model = self .model_class (** self .init_dict )
386+ else :
387+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
388+ model = self .model_class (** init_dict )
346389
347- model = self .model_class (** init_dict )
348390 if hasattr (model , "set_default_attn_processor" ):
349391 model .set_default_attn_processor ()
350392
@@ -367,11 +409,17 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
367409 new_model .to (torch_device )
368410
369411 with torch .no_grad ():
370- image = model (** inputs_dict )
412+ if self .forward_requires_fresh_args :
413+ image = model (** self .inputs_dict (0 ))
414+ else :
415+ image = model (** inputs_dict )
371416 if isinstance (image , dict ):
372417 image = image .to_tuple ()[0 ]
373418
374- new_image = new_model (** inputs_dict )
419+ if self .forward_requires_fresh_args :
420+ new_image = new_model (** self .inputs_dict (0 ))
421+ else :
422+ new_image = new_model (** inputs_dict )
375423
376424 if isinstance (new_image , dict ):
377425 new_image = new_image .to_tuple ()[0 ]
@@ -405,17 +453,26 @@ def test_from_save_pretrained_dtype(self):
405453 assert new_model .dtype == dtype
406454
407455 def test_determinism (self , expected_max_diff = 1e-5 ):
408- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
409- model = self .model_class (** init_dict )
456+ if self .forward_requires_fresh_args :
457+ model = self .model_class (** self .init_dict )
458+ else :
459+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
460+ model = self .model_class (** init_dict )
410461 model .to (torch_device )
411462 model .eval ()
412463
413464 with torch .no_grad ():
414- first = model (** inputs_dict )
465+ if self .forward_requires_fresh_args :
466+ first = model (** self .inputs_dict (0 ))
467+ else :
468+ first = model (** inputs_dict )
415469 if isinstance (first , dict ):
416470 first = first .to_tuple ()[0 ]
417471
418- second = model (** inputs_dict )
472+ if self .forward_requires_fresh_args :
473+ second = model (** self .inputs_dict (0 ))
474+ else :
475+ second = model (** inputs_dict )
419476 if isinstance (second , dict ):
420477 second = second .to_tuple ()[0 ]
421478
@@ -548,15 +605,22 @@ def recursive_check(tuple_object, dict_object):
548605 ),
549606 )
550607
551- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
608+ if self .forward_requires_fresh_args :
609+ model = self .model_class (** self .init_dict )
610+ else :
611+ init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
612+ model = self .model_class (** init_dict )
552613
553- model = self .model_class (** init_dict )
554614 model .to (torch_device )
555615 model .eval ()
556616
557617 with torch .no_grad ():
558- outputs_dict = model (** inputs_dict )
559- outputs_tuple = model (** inputs_dict , return_dict = False )
618+ if self .forward_requires_fresh_args :
619+ outputs_dict = model (** self .inputs_dict (0 ))
620+ outputs_tuple = model (** self .inputs_dict (0 ), return_dict = False )
621+ else :
622+ outputs_dict = model (** inputs_dict )
623+ outputs_tuple = model (** inputs_dict , return_dict = False )
560624
561625 recursive_check (outputs_tuple , outputs_dict )
562626
0 commit comments