Skip to content

Commit cf921ba

Browse files
authored
More tests & new AugmentedImageParameterization base class
* Added `AugmentedImageParameterization` class to use a base for `SharedImage` and `StackImage`. * Removed `PixelImage`'s 3 channel assert, as there was no reason for limitation. * Added tests for `InputParameterization`, `ImageParameterization`, & `AugmentedImageParameterization`.
1 parent 26d9519 commit cf921ba

File tree

2 files changed

+80
-17
lines changed

2 files changed

+80
-17
lines changed

captum/optim/_param/image/images.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class ImageParameterization(InputParameterization):
133133
pass
134134

135135

136+
class AugmentedImageParameterization(ImageParameterization):
137+
pass
138+
139+
136140
class FFTImage(ImageParameterization):
137141
"""
138142
Parameterize an image using inverse real 2D FFT
@@ -305,8 +309,6 @@ def __init__(
305309
assert init.dim() == 3 or init.dim() == 4
306310
if init.dim() == 3:
307311
init = init.unsqueeze(0)
308-
assert init.shape[1] == 3, "PixelImage init should have 3 channels, "
309-
f"input has {init.shape[1]} channels."
310312
self.image = nn.Parameter(init)
311313

312314
def forward(self) -> torch.Tensor:
@@ -432,10 +434,14 @@ def __init__(self, tensor: torch.Tensor = None) -> None:
432434
self.tensor = tensor
433435

434436
def forward(self) -> torch.Tensor:
437+
"""
438+
Returns:
439+
tensor (torch.Tensor): The tensor stored during initialization.
440+
"""
435441
return self.tensor
436442

437443

438-
class SharedImage(ImageParameterization):
444+
class SharedImage(AugmentedImageParameterization):
439445
"""
440446
Share some image parameters across the batch to increase spatial alignment,
441447
by using interpolated lower resolution tensors.
@@ -585,6 +591,10 @@ def _interpolate_tensor(
585591
return x
586592

587593
def forward(self) -> torch.Tensor:
594+
"""
595+
Returns:
596+
output (torch.Tensor): An NCHW image parameterization output.
597+
"""
588598
image = self.parameterization()
589599
x = [
590600
self._interpolate_tensor(
@@ -606,7 +616,7 @@ def forward(self) -> torch.Tensor:
606616
return output.refine_names("B", "C", "H", "W")
607617

608618

609-
class StackImage(ImageParameterization):
619+
class StackImage(AugmentedImageParameterization):
610620
"""
611621
Stack multiple NCHW image parameterizations along their batch dimensions.
612622
"""
@@ -691,7 +701,9 @@ def __init__(
691701
channels: int = 3,
692702
batch: int = 1,
693703
init: Optional[torch.Tensor] = None,
694-
parameterization: ImageParameterization = FFTImage,
704+
parameterization: Union[
705+
ImageParameterization, AugmentedImageParameterization
706+
] = FFTImage,
695707
squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
696708
decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
697709
decorrelate_init: bool = True,
@@ -758,9 +770,11 @@ def forward(self) -> torch.Tensor:
758770
"ImageTensor",
759771
"InputParameterization",
760772
"ImageParameterization",
773+
"AugmentedImageParameterization",
761774
"FFTImage",
762775
"PixelImage",
763776
"LaplacianImage",
764777
"SharedImage",
778+
"StackImage",
765779
"NaturalImage",
766780
]

tests/optim/param/test_images.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_export_and_open_local_image(self) -> None:
7979
self.assertTrue(torch.is_tensor(new_tensor))
8080
assertTensorAlmostEqual(self, image_tensor, new_tensor)
8181

82-
def test_natural_image_cuda(self) -> None:
82+
def test_image_tensor_cuda(self) -> None:
8383
if not torch.cuda.is_available():
8484
raise unittest.SkipTest(
8585
"Skipping ImageTensor CUDA test due to not supporting CUDA."
@@ -88,7 +88,31 @@ def test_natural_image_cuda(self) -> None:
8888
self.assertTrue(image_t.is_cuda)
8989

9090

91+
class TestInputParameterization(BaseTest):
92+
def test_subclass(self) -> None:
93+
self.assertTrue(issubclass(images.InputParameterization, torch.nn.Module))
94+
95+
96+
class TestImageParameterization(BaseTest):
97+
def test_subclass(self) -> None:
98+
self.assertTrue(
99+
issubclass(images.ImageParameterization, images.InputParameterization)
100+
)
101+
102+
103+
class TestAugmentedImageParameterization(BaseTest):
104+
def test_subclass(self) -> None:
105+
self.assertTrue(
106+
issubclass(
107+
images.AugmentedImageParameterization, images.ImageParameterization
108+
)
109+
)
110+
111+
91112
class TestFFTImage(BaseTest):
113+
def test_subclass(self) -> None:
114+
self.assertTrue(issubclass(images.FFTImage, images.ImageParameterization))
115+
92116
def test_pytorch_fftfreq(self) -> None:
93117
image = images.FFTImage((1, 1))
94118
_, _, fftfreq = image.get_fft_funcs()
@@ -219,6 +243,9 @@ def test_fftimage_forward_init_batch(self) -> None:
219243

220244

221245
class TestPixelImage(BaseTest):
246+
def test_subclass(self) -> None:
247+
self.assertTrue(issubclass(images.PixelImage, images.ImageParameterization))
248+
222249
def test_pixelimage_random(self) -> None:
223250
if torch.__version__ <= "1.2.0":
224251
raise unittest.SkipTest(
@@ -251,17 +278,6 @@ def test_pixelimage_init(self) -> None:
251278
self.assertEqual(image_param.image.size(3), size[1])
252279
assertTensorAlmostEqual(self, image_param.image, init_tensor, 0)
253280

254-
def test_pixelimage_init_error(self) -> None:
255-
if torch.__version__ <= "1.2.0":
256-
raise unittest.SkipTest(
257-
"Skipping PixelImage init due to insufficient Torch version."
258-
)
259-
size = (224, 224)
260-
channels = 2
261-
init_tensor = torch.randn(channels, *size)
262-
with self.assertRaises(AssertionError):
263-
images.PixelImage(size=size, channels=channels, init=init_tensor)
264-
265281
def test_pixelimage_random_forward(self) -> None:
266282
if torch.__version__ <= "1.2.0":
267283
raise unittest.SkipTest(
@@ -298,6 +314,9 @@ def test_pixelimage_init_forward(self) -> None:
298314

299315

300316
class TestLaplacianImage(BaseTest):
317+
def test_subclass(self) -> None:
318+
self.assertTrue(issubclass(images.LaplacianImage, images.ImageParameterization))
319+
301320
def test_laplacianimage_random_forward(self) -> None:
302321
if torch.__version__ <= "1.2.0":
303322
raise unittest.SkipTest(
@@ -326,6 +345,13 @@ def test_laplacianimage_init(self) -> None:
326345

327346

328347
class TestSimpleTensorParameterization(BaseTest):
348+
def test_subclass(self) -> None:
349+
self.assertTrue(
350+
issubclass(
351+
images.SimpleTensorParameterization, images.ImageParameterization
352+
)
353+
)
354+
329355
def test_simple_tensor_parameterization_no_grad(self) -> None:
330356
test_input = torch.randn(1, 3, 4, 4)
331357
image_param = images.SimpleTensorParameterization(test_input)
@@ -337,6 +363,11 @@ def test_simple_tensor_parameterization_no_grad(self) -> None:
337363
self.assertFalse(image_param.tensor.requires_grad)
338364

339365
def test_simple_tensor_parameterization_jit_module_no_grad(self) -> None:
366+
if torch.__version__ <= "1.8.0":
367+
raise unittest.SkipTest(
368+
"Skipping SimpleTensorParameterization JIT module test due to"
369+
+ " insufficient Torch version."
370+
)
340371
test_input = torch.randn(1, 3, 4, 4)
341372
image_param = images.SimpleTensorParameterization(test_input)
342373
jit_image_param = torch.jit.script(image_param)
@@ -356,6 +387,11 @@ def test_simple_tensor_parameterization_with_grad(self) -> None:
356387
self.assertTrue(image_param.tensor.requires_grad)
357388

358389
def test_simple_tensor_parameterization_jit_module_with_grad(self) -> None:
390+
if torch.__version__ <= "1.8.0":
391+
raise unittest.SkipTest(
392+
"Skipping SimpleTensorParameterization JIT module test due to"
393+
+ " insufficient Torch version."
394+
)
359395
test_input = torch.nn.Parameter(torch.randn(1, 3, 4, 4))
360396
image_param = images.SimpleTensorParameterization(test_input)
361397
jit_image_param = torch.jit.script(image_param)
@@ -383,6 +419,11 @@ def test_simple_tensor_parameterization_cuda(self) -> None:
383419

384420

385421
class TestSharedImage(BaseTest):
422+
def test_subclass(self) -> None:
423+
self.assertTrue(
424+
issubclass(images.SharedImage, images.AugmentedImageParameterization)
425+
)
426+
386427
def test_sharedimage_get_offset_single_number(self) -> None:
387428
if torch.__version__ <= "1.2.0":
388429
raise unittest.SkipTest(
@@ -726,6 +767,11 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None:
726767

727768

728769
class TestStackImage(BaseTest):
770+
def test_subclass(self) -> None:
771+
self.assertTrue(
772+
issubclass(images.StackImage, images.AugmentedImageParameterization)
773+
)
774+
729775
def test_stackimage_init(self) -> None:
730776
if torch.__version__ <= "1.2.0":
731777
raise unittest.SkipTest(
@@ -940,6 +986,9 @@ def test_stackimage_forward_multi_device_cpu_gpu(self) -> None:
940986

941987

942988
class TestNaturalImage(BaseTest):
989+
def test_subclass(self) -> None:
990+
self.assertTrue(issubclass(images.NaturalImage, images.ImageParameterization))
991+
943992
def test_natural_image_0(self) -> None:
944993
if torch.__version__ <= "1.2.0":
945994
raise unittest.SkipTest(

0 commit comments

Comments
 (0)