@@ -424,6 +424,74 @@ def test_subclass(self) -> None:
424424 issubclass (images .SharedImage , images .AugmentedImageParameterization )
425425 )
426426
427+ def test_sharedimage_interpolate_bilinear (self ) -> None :
428+ shared_shapes = (128 // 2 , 128 // 2 )
429+ test_param = lambda : torch .ones (3 , 3 , 224 , 224 ) # noqa: E731
430+ image_param = images .SharedImage (
431+ shapes = shared_shapes , parameterization = test_param
432+ )
433+
434+ size = (224 , 128 )
435+ test_input = torch .randn (1 , 3 , 128 , 128 )
436+
437+ test_output = image_param ._interpolate_bilinear (test_input .clone (), size = size )
438+ expected_output = torch .nn .functional .interpolate (
439+ test_input .clone (), size = size , mode = "bilinear"
440+ )
441+ assertTensorAlmostEqual (self , test_output , expected_output , 0.0 )
442+
443+ size = (128 , 128 )
444+ test_input = torch .randn (1 , 3 , 224 , 224 )
445+
446+ test_output = image_param ._interpolate_bilinear (test_input .clone (), size = size )
447+ expected_output = torch .nn .functional .interpolate (
448+ test_input .clone (), size = size , mode = "bilinear"
449+ )
450+ assertTensorAlmostEqual (self , test_output , expected_output , 0.0 )
451+
452+ def test_sharedimage_interpolate_trilinear (self ) -> None :
453+ shared_shapes = (128 // 2 , 128 // 2 )
454+ test_param = lambda : torch .ones (3 , 3 , 224 , 224 ) # noqa: E731
455+ image_param = images .SharedImage (
456+ shapes = shared_shapes , parameterization = test_param
457+ )
458+
459+ size = (3 , 224 , 128 )
460+ test_input = torch .randn (1 , 1 , 128 , 128 )
461+
462+ test_output = image_param ._interpolate_trilinear (test_input .clone (), size = size )
463+ expected_output = torch .nn .functional .interpolate (
464+ test_input .clone ().unsqueeze (0 ), size = size , mode = "trilinear"
465+ ).squeeze (0 )
466+ assertTensorAlmostEqual (self , test_output , expected_output , 0.0 )
467+
468+ size = (2 , 128 , 128 )
469+ test_input = torch .randn (1 , 4 , 224 , 224 )
470+
471+ test_output = image_param ._interpolate_trilinear (test_input .clone (), size = size )
472+ expected_output = torch .nn .functional .interpolate (
473+ test_input .clone ().unsqueeze (0 ), size = size , mode = "trilinear"
474+ ).squeeze (0 )
475+ assertTensorAlmostEqual (self , test_output , expected_output , 0.0 )
476+
477+ def test_torch_version_check (self ) -> None :
478+ shared_shapes = (128 // 2 , 128 // 2 )
479+ test_param = lambda : torch .ones (3 , 3 , 224 , 224 ) # noqa: E731
480+ image_param = images .SharedImage (
481+ shapes = shared_shapes , parameterization = test_param
482+ )
483+
484+ has_align_corners = torch .__version__ >= "1.3.0"
485+ self .assertEqual (image_param ._has_align_corners , has_align_corners )
486+
487+ has_recompute_scale_factor = torch .__version__ >= "1.6.0"
488+ self .assertEqual (
489+ image_param ._has_recompute_scale_factor , has_recompute_scale_factor
490+ )
491+
492+ supports_is_scripting = torch .__version__ >= "1.6.0"
493+ self .assertEqual (image_param ._supports_is_scripting , supports_is_scripting )
494+
427495 def test_sharedimage_get_offset_single_number (self ) -> None :
428496 if torch .__version__ <= "1.2.0" :
429497 raise unittest .SkipTest (
@@ -772,6 +840,15 @@ def test_subclass(self) -> None:
772840 issubclass (images .StackImage , images .AugmentedImageParameterization )
773841 )
774842
843+ def test_stackimage_torch_version_check (self ) -> None :
844+ img_param_1 = images .SimpleTensorParameterization (torch .ones (1 , 3 , 4 , 4 ))
845+ img_param_2 = images .SimpleTensorParameterization (torch .ones (1 , 3 , 4 , 4 ))
846+ param_list = [img_param_1 , img_param_2 ]
847+ stack_param = images .StackImage (parameterizations = param_list )
848+
849+ supports_is_scripting = torch .__version__ >= "1.6.0"
850+ self .assertEqual (stack_param ._supports_is_scripting , supports_is_scripting )
851+
775852 def test_stackimage_init (self ) -> None :
776853 if torch .__version__ <= "1.2.0" :
777854 raise unittest .SkipTest (
0 commit comments