
Commit f6bcece ("tests")
1 parent: cf33acb

File tree: 3 files changed, +171 -31 lines changed


src/diffusers/models/consistency_decoder_vae.py

Lines changed: 5 additions & 3 deletions
@@ -58,6 +58,7 @@ class ConsistencyDecoderVae(ModelMixin, ConfigMixin):
         >>> import torch
         >>> from diffusers import DiffusionPipeline, ConsistencyDecoderVae

+        >>> # TODO - is this going to be where the model is uploaded?
         >>> vae = ConsistencyDecoderVae.from_pretrained("openai/consistency-decoder", torch_dtype=pipe.torch_dtype)
         >>> pipe = StableDiffusionPipeline.from_pretrained(
         ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
@@ -242,7 +243,9 @@ def decode(
         num_inference_steps=2,
     ) -> Union[DecoderOutput, torch.FloatTensor]:
         z = (z - self.means) / self.stds
-        z = F.interpolate(z, mode="nearest", scale_factor=8)
+
+        scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
+        z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)

         batch_size, _, height, width = z.shape
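Note on the decode() change above: the upsampling factor is now derived from the VAE config rather than hardcoded to 8, so the same code path also works for reduced test configs. A quick illustration (the four-entry config is an assumption matching the standard SD VAE layout; the two-entry one matches the tiny test config added in tests/models/test_models_vae.py below):

    # the scale factor follows the number of resolution levels in the config
    block_out_channels = [128, 256, 512, 512]      # assumed full-size config
    print(2 ** (len(block_out_channels) - 1))      # -> 8, the previously hardcoded value
    block_out_channels = [32, 64]                  # tiny test config
    print(2 ** (len(block_out_channels) - 1))      # -> 2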

@@ -334,7 +337,6 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Consis

         return ConsistencyDecoderVaeOutput(latent_dist=posterior)

-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.forward
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -356,7 +358,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z).sample
+        dec = self.decode(z, generator=generator).sample

         if not return_dict:
             return (dec,)
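Note on the forward() change above: the consistency decoder draws random noise while decoding, so forward() now threads its generator argument into decode(). A hedged sketch of the effect (assumes `vae` is a ConsistencyDecoderVae instance, e.g. built from the tiny config in the tests below, and that it accepts `sample` and `generator` the way the common tests call it):

    sample = torch.randn(1, 3, 32, 32)
    out_a = vae(sample, generator=torch.Generator("cpu").manual_seed(0))[0]
    out_b = vae(sample, generator=torch.Generator("cpu").manual_seed(0))[0]
    # before this change the two decodes drew uncontrolled noise; with it, out_a and out_b match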

tests/models/test_modeling_common.py

Lines changed: 91 additions & 27 deletions
@@ -196,11 +196,15 @@ def test_forward_with_norm_groups(self):
 class ModelTesterMixin:
     main_input_name = None  # overwrite in model specific tester class
     base_precision = 1e-3
+    forward_requires_fresh_args = False

     def test_from_save_pretrained(self, expected_max_diff=5e-5):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)

-        model = self.model_class(**init_dict)
         if hasattr(model, "set_default_attn_processor"):
             model.set_default_attn_processor()
         model.to(torch_device)
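The new forward_requires_fresh_args flag lets a tester bypass the shared prepare_init_args_and_inputs_for_common() tuple: when it is True, the mixin reads self.init_dict as a property and calls self.inputs_dict(seed) right before every forward pass, so stateful inputs such as a torch.Generator are never reused across calls. A minimal sketch of a tester opting in (the class and config values are illustrative, not part of the commit; the real example is ConsistencyDecoderVaeTests further down):

    import unittest

    import torch

    from .test_modeling_common import ModelTesterMixin


    class StochasticModelTests(ModelTesterMixin, unittest.TestCase):
        # model_class = <model under test> would be set here as usual
        forward_requires_fresh_args = True

        @property
        def init_dict(self):
            # constructor kwargs for a tiny config of the model under test
            return {"block_out_channels": [32, 64]}

        def inputs_dict(self, seed=None):
            # rebuilt on every call so each forward pass gets an unconsumed generator
            generator = torch.Generator("cpu")
            if seed is not None:
                generator.manual_seed(seed)
            sample = torch.randn((4, 3, 32, 32), generator=generator)
            return {"sample": sample, "generator": generator}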
@@ -214,11 +218,18 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5):
         new_model.to(torch_device)

         with torch.no_grad():
-            image = model(**inputs_dict)
+            if self.forward_requires_fresh_args:
+                image = model(**self.inputs_dict(0))
+            else:
+                image = model(**inputs_dict)
+
             if isinstance(image, dict):
                 image = image.to_tuple()[0]

-            new_image = new_model(**inputs_dict)
+            if self.forward_requires_fresh_args:
+                new_image = new_model(**self.inputs_dict(0))
+            else:
+                new_image = new_model(**inputs_dict)

             if isinstance(new_image, dict):
                 new_image = new_image.to_tuple()[0]
@@ -275,8 +286,11 @@ def test_getattr_is_correct(self):
     )
     def test_set_xformers_attn_processor_for_determinism(self):
         torch.use_deterministic_algorithms(False)
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
         model.to(torch_device)

         if not hasattr(model, "set_attn_processor"):
@@ -286,17 +300,26 @@ def test_set_xformers_attn_processor_for_determinism(self):
         model.set_default_attn_processor()
         assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
         with torch.no_grad():
-            output = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output = model(**self.inputs_dict(0))[0]
+            else:
+                output = model(**inputs_dict)[0]

         model.enable_xformers_memory_efficient_attention()
         assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
         with torch.no_grad():
-            output_2 = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output_2 = model(**self.inputs_dict(0))[0]
+            else:
+                output_2 = model(**inputs_dict)[0]

         model.set_attn_processor(XFormersAttnProcessor())
         assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
         with torch.no_grad():
-            output_3 = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output_3 = model(**self.inputs_dict(0))[0]
+            else:
+                output_3 = model(**inputs_dict)[0]

         torch.use_deterministic_algorithms(True)

@@ -307,8 +330,12 @@ def test_set_xformers_attn_processor_for_determinism(self):
     @require_torch_gpu
     def test_set_attn_processor_for_determinism(self):
         torch.use_deterministic_algorithms(False)
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
+
         model.to(torch_device)

         if not hasattr(model, "set_attn_processor"):
@@ -317,22 +344,34 @@ def test_set_attn_processor_for_determinism(self):

         assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
         with torch.no_grad():
-            output_1 = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output_1 = model(**self.inputs_dict(0))[0]
+            else:
+                output_1 = model(**inputs_dict)[0]

         model.set_default_attn_processor()
         assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
         with torch.no_grad():
-            output_2 = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output_2 = model(**self.inputs_dict(0))[0]
+            else:
+                output_2 = model(**inputs_dict)[0]

         model.set_attn_processor(AttnProcessor2_0())
         assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
         with torch.no_grad():
-            output_4 = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output_4 = model(**self.inputs_dict(0))[0]
+            else:
+                output_4 = model(**inputs_dict)[0]

         model.set_attn_processor(AttnProcessor())
         assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
         with torch.no_grad():
-            output_5 = model(**inputs_dict)[0]
+            if self.forward_requires_fresh_args:
+                output_5 = model(**self.inputs_dict(0))[0]
+            else:
+                output_5 = model(**inputs_dict)[0]

         torch.use_deterministic_algorithms(True)

@@ -342,9 +381,12 @@ def test_set_attn_processor_for_determinism(self):
         assert torch.allclose(output_2, output_5, atol=self.base_precision)

     def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)

-        model = self.model_class(**init_dict)
         if hasattr(model, "set_default_attn_processor"):
             model.set_default_attn_processor()

@@ -367,11 +409,17 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
         new_model.to(torch_device)

         with torch.no_grad():
-            image = model(**inputs_dict)
+            if self.forward_requires_fresh_args:
+                image = model(**self.inputs_dict(0))
+            else:
+                image = model(**inputs_dict)
             if isinstance(image, dict):
                 image = image.to_tuple()[0]

-            new_image = new_model(**inputs_dict)
+            if self.forward_requires_fresh_args:
+                new_image = new_model(**self.inputs_dict(0))
+            else:
+                new_image = new_model(**inputs_dict)

             if isinstance(new_image, dict):
                 new_image = new_image.to_tuple()[0]
@@ -405,17 +453,26 @@ def test_from_save_pretrained_dtype(self):
         assert new_model.dtype == dtype

     def test_determinism(self, expected_max_diff=1e-5):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
         model.to(torch_device)
         model.eval()

         with torch.no_grad():
-            first = model(**inputs_dict)
+            if self.forward_requires_fresh_args:
+                first = model(**self.inputs_dict(0))
+            else:
+                first = model(**inputs_dict)
             if isinstance(first, dict):
                 first = first.to_tuple()[0]

-            second = model(**inputs_dict)
+            if self.forward_requires_fresh_args:
+                second = model(**self.inputs_dict(0))
+            else:
+                second = model(**inputs_dict)
             if isinstance(second, dict):
                 second = second.to_tuple()[0]

@@ -548,15 +605,22 @@ def recursive_check(tuple_object, dict_object):
                 ),
             )

-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)

-        model = self.model_class(**init_dict)
         model.to(torch_device)
         model.eval()

         with torch.no_grad():
-            outputs_dict = model(**inputs_dict)
-            outputs_tuple = model(**inputs_dict, return_dict=False)
+            if self.forward_requires_fresh_args:
+                outputs_dict = model(**self.inputs_dict(0))
+                outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
+            else:
+                outputs_dict = model(**inputs_dict)
+                outputs_tuple = model(**inputs_dict, return_dict=False)

         recursive_check(outputs_tuple, outputs_dict)

tests/models/test_models_vae.py

Lines changed: 75 additions & 1 deletion
@@ -19,7 +19,7 @@
 import torch
 from parameterized import parameterized

-from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
+from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVae
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -30,6 +30,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.torch_utils import randn_tensor

 from .test_modeling_common import ModelTesterMixin, UNetTesterMixin

@@ -269,6 +270,79 @@ def test_outputs_equivalence(self):
         pass


+class ConsistencyDecoderVaeTests(ModelTesterMixin, unittest.TestCase):
+    model_class = ConsistencyDecoderVae
+    main_input_name = "sample"
+    base_precision = 1e-2
+    forward_requires_fresh_args = True
+
+    def inputs_dict(self, seed=None):
+        generator = torch.Generator("cpu")
+        if seed is not None:
+            generator.manual_seed(0)
+        image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
+
+        return {"sample": image, "generator": generator}
+
+    @property
+    def input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def init_dict(self):
+        return {
+            "encoder_args": {
+                "block_out_channels": [32, 64],
+                "in_channels": 3,
+                "out_channels": 4,
+                "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            },
+            "decoder_args": {
+                "act_fn": "silu",
+                "add_attention": False,
+                "block_out_channels": [32, 64],
+                "down_block_types": [
+                    "ResnetDownsampleBlock2D",
+                    "ResnetDownsampleBlock2D",
+                ],
+                "downsample_padding": 1,
+                "downsample_type": "conv",
+                "dropout": 0.0,
+                "in_channels": 7,
+                "layers_per_block": 1,
+                "norm_eps": 1e-05,
+                "norm_num_groups": 32,
+                "num_train_timesteps": 1024,
+                "out_channels": 6,
+                "resnet_time_scale_shift": "scale_shift",
+                "time_embedding_type": "learned",
+                "up_block_types": [
+                    "ResnetUpsampleBlock2D",
+                    "ResnetUpsampleBlock2D",
+                ],
+                "upsample_type": "conv",
+            },
+            "scaling_factor": 1,
+            "block_out_channels": [32, 64],
+            "latent_channels": 4,
+        }
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return self.init_dict, self.inputs_dict()
+
+    @unittest.skip
+    def test_training(self):
+        ...
+
+    @unittest.skip
+    def test_ema_training(self):
+        ...
+
+
 @slow
 class AutoencoderTinyIntegrationTests(unittest.TestCase):
     def tearDown(self):
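One detail worth noting in inputs_dict above: it creates a new generator (and a sample drawn from it) on every call, and as committed it always seeds with 0 whenever a seed is passed, which matches the seed=0 the mixin supplies. Because the common tests call self.inputs_dict(0) separately before each forward pass, both passes start from a generator in the same state, which is what makes comparisons such as test_determinism meaningful for a model whose decoding consumes random noise. A small standalone sketch of the idea (the helper name is illustrative, not part of the commit):

    import torch

    def fresh_inputs(seed=0):
        generator = torch.Generator("cpu").manual_seed(seed)
        sample = torch.randn((4, 3, 32, 32), generator=generator)
        return {"sample": sample, "generator": generator}

    a, b = fresh_inputs(0), fresh_inputs(0)
    assert torch.equal(a["sample"], b["sample"])  # identical inputs and generator state for both passes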
