Skip to content

Commit 9b7ce93

Browse files
authored
Add JIT support for SharedImage._interpolate_tensor
* Added JIT support for SharedImage's interpolation operations. * Unfortunately, JIT support required me to separate SharedImage's bilinear and trilinear resizing into separate functions as Union's of tuples are currently broken. Union support was also a newer addition, so now SharedImage can support older PyTorch versions as well.
1 parent cf921ba commit 9b7ce93

File tree

2 files changed

+157
-12
lines changed

2 files changed

+157
-12
lines changed

captum/optim/_param/image/images.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,12 @@ class SharedImage(AugmentedImageParameterization):
454454
https://distill.pub/2018/differentiable-parameterizations/
455455
"""
456456

457-
__constants__ = ["offset", "_supports_is_scripting"]
457+
__constants__ = [
458+
"offset",
459+
"_supports_is_scripting",
460+
"_has_align_corners",
461+
"_has_recompute_scale_factor",
462+
]
458463

459464
def __init__(
460465
self,
@@ -491,6 +496,8 @@ def __init__(
491496

492497
# Check & store whether or not we can use torch.jit.is_scripting()
493498
self._supports_is_scripting = torch.__version__ >= "1.6.0"
499+
self._has_align_corners = torch.__version__ >= "1.3.0"
500+
self._has_recompute_scale_factor = torch.__version__ >= "1.6.0"
494501

495502
def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]:
496503
"""
@@ -551,7 +558,75 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
551558
A.append(x)
552559
return A
553560

554-
@torch.jit.ignore
561+
def _interpolate_bilinear(
562+
self,
563+
x: torch.Tensor,
564+
size: Tuple[int, int],
565+
) -> torch.Tensor:
566+
"""
567+
Perform interpolation without any warnings.
568+
569+
Args:
570+
571+
x (torch.Tensor): The NCHW tensor to resize.
572+
size (tuple of int): The desired output size to resize the input
573+
to, with a format of: [height, width].
574+
575+
Returns:
576+
x (torch.Tensor): A resized NCHW tensor.
577+
"""
578+
assert x.dim() == 4
579+
assert len(size) == 2
580+
581+
if self._has_align_corners:
582+
if self._has_recompute_scale_factor:
583+
x = F.interpolate(
584+
x,
585+
size=size,
586+
mode="bilinear",
587+
align_corners=False,
588+
recompute_scale_factor=False,
589+
)
590+
else:
591+
x = F.interpolate(x, size=size, mode="bilinear", align_corners=False)
592+
else:
593+
x = F.interpolate(x, size=size, mode="bilinear")
594+
return x
595+
596+
def _interpolate_trilinear(
597+
self,
598+
x: torch.Tensor,
599+
size: Tuple[int, int, int],
600+
) -> torch.Tensor:
601+
"""
602+
Perform interpolation without any warnings.
603+
604+
Args:
605+
606+
x (torch.Tensor): The NCHW tensor to resize.
607+
size (tuple of int): The desired output size to resize the input
608+
to, with a format of: [channels, height, width].
609+
610+
Returns:
611+
x (torch.Tensor): A resized NCHW tensor.
612+
"""
613+
x = x.unsqueeze(0)
614+
assert x.dim() == 5
615+
if self._has_align_corners:
616+
if self._has_recompute_scale_factor:
617+
x = F.interpolate(
618+
x,
619+
size=size,
620+
mode="trilinear",
621+
align_corners=False,
622+
recompute_scale_factor=False,
623+
)
624+
else:
625+
x = F.interpolate(x, size=size, mode="trilinear", align_corners=False)
626+
else:
627+
x = F.interpolate(x, size=size, mode="trilinear")
628+
return x.squeeze(0)
629+
555630
def _interpolate_tensor(
556631
self, x: torch.Tensor, batch: int, channels: int, height: int, width: int
557632
) -> torch.Tensor:
@@ -572,21 +647,14 @@ def _interpolate_tensor(
572647
"""
573648

574649
if x.size(1) == channels:
575-
mode = "bilinear"
576650
size = (height, width)
651+
x = self._interpolate_bilinear(x, size=size)
577652
else:
578-
mode = "trilinear"
579-
x = x.unsqueeze(0)
580653
size = (channels, height, width)
581-
x = F.interpolate(x, size=size, mode=mode)
582-
x = x.squeeze(0) if len(size) == 3 else x
654+
x = self._interpolate_trilinear(x, size=size)
583655
if x.size(0) != batch:
584656
x = x.permute(1, 0, 2, 3)
585-
x = F.interpolate(
586-
x.unsqueeze(0),
587-
size=(batch, x.size(2), x.size(3)),
588-
mode="trilinear",
589-
).squeeze(0)
657+
x = self._interpolate_trilinear(x, size=(batch, x.size(2), x.size(3)))
590658
x = x.permute(1, 0, 2, 3)
591659
return x
592660

tests/optim/param/test_images.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,74 @@ def test_subclass(self) -> None:
424424
issubclass(images.SharedImage, images.AugmentedImageParameterization)
425425
)
426426

427+
def test_sharedimage_interpolate_bilinear(self) -> None:
428+
shared_shapes = (128 // 2, 128 // 2)
429+
test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731
430+
image_param = images.SharedImage(
431+
shapes=shared_shapes, parameterization=test_param
432+
)
433+
434+
size = (224, 128)
435+
test_input = torch.randn(1, 3, 128, 128)
436+
437+
test_output = image_param._interpolate_bilinear(test_input.clone(), size=size)
438+
expected_output = torch.nn.functional.interpolate(
439+
test_input.clone(), size=size, mode="bilinear"
440+
)
441+
assertTensorAlmostEqual(self, test_output, expected_output, 0.0)
442+
443+
size = (128, 128)
444+
test_input = torch.randn(1, 3, 224, 224)
445+
446+
test_output = image_param._interpolate_bilinear(test_input.clone(), size=size)
447+
expected_output = torch.nn.functional.interpolate(
448+
test_input.clone(), size=size, mode="bilinear"
449+
)
450+
assertTensorAlmostEqual(self, test_output, expected_output, 0.0)
451+
452+
def test_sharedimage_interpolate_trilinear(self) -> None:
453+
shared_shapes = (128 // 2, 128 // 2)
454+
test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731
455+
image_param = images.SharedImage(
456+
shapes=shared_shapes, parameterization=test_param
457+
)
458+
459+
size = (3, 224, 128)
460+
test_input = torch.randn(1, 1, 128, 128)
461+
462+
test_output = image_param._interpolate_trilinear(test_input.clone(), size=size)
463+
expected_output = torch.nn.functional.interpolate(
464+
test_input.clone().unsqueeze(0), size=size, mode="trilinear"
465+
).squeeze(0)
466+
assertTensorAlmostEqual(self, test_output, expected_output, 0.0)
467+
468+
size = (2, 128, 128)
469+
test_input = torch.randn(1, 4, 224, 224)
470+
471+
test_output = image_param._interpolate_trilinear(test_input.clone(), size=size)
472+
expected_output = torch.nn.functional.interpolate(
473+
test_input.clone().unsqueeze(0), size=size, mode="trilinear"
474+
).squeeze(0)
475+
assertTensorAlmostEqual(self, test_output, expected_output, 0.0)
476+
477+
def test_torch_version_check(self) -> None:
478+
shared_shapes = (128 // 2, 128 // 2)
479+
test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731
480+
image_param = images.SharedImage(
481+
shapes=shared_shapes, parameterization=test_param
482+
)
483+
484+
has_align_corners = torch.__version__ >= "1.3.0"
485+
self.assertEqual(image_param._has_align_corners, has_align_corners)
486+
487+
has_recompute_scale_factor = torch.__version__ >= "1.6.0"
488+
self.assertEqual(
489+
image_param._has_recompute_scale_factor, has_recompute_scale_factor
490+
)
491+
492+
supports_is_scripting = torch.__version__ >= "1.6.0"
493+
self.assertEqual(image_param._supports_is_scripting, supports_is_scripting)
494+
427495
def test_sharedimage_get_offset_single_number(self) -> None:
428496
if torch.__version__ <= "1.2.0":
429497
raise unittest.SkipTest(
@@ -772,6 +840,15 @@ def test_subclass(self) -> None:
772840
issubclass(images.StackImage, images.AugmentedImageParameterization)
773841
)
774842

843+
def test_stackimage_torch_version_check(self) -> None:
844+
img_param_1 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4))
845+
img_param_2 = images.SimpleTensorParameterization(torch.ones(1, 3, 4, 4))
846+
param_list = [img_param_1, img_param_2]
847+
stack_param = images.StackImage(parameterizations=param_list)
848+
849+
supports_is_scripting = torch.__version__ >= "1.6.0"
850+
self.assertEqual(stack_param._supports_is_scripting, supports_is_scripting)
851+
775852
def test_stackimage_init(self) -> None:
776853
if torch.__version__ <= "1.2.0":
777854
raise unittest.SkipTest(

0 commit comments

Comments
 (0)