@@ -79,7 +79,7 @@ def test_export_and_open_local_image(self) -> None:
7979 self .assertTrue (torch .is_tensor (new_tensor ))
8080 assertTensorAlmostEqual (self , image_tensor , new_tensor )
8181
82- def test_natural_image_cuda (self ) -> None :
82+ def test_image_tensor_cuda (self ) -> None :
8383 if not torch .cuda .is_available ():
8484 raise unittest .SkipTest (
8585 "Skipping ImageTensor CUDA test due to not supporting CUDA."
@@ -88,7 +88,31 @@ def test_natural_image_cuda(self) -> None:
8888 self .assertTrue (image_t .is_cuda )
8989
9090
91+ class TestInputParameterization (BaseTest ):
92+ def test_subclass (self ) -> None :
93+ self .assertTrue (issubclass (images .InputParameterization , torch .nn .Module ))
94+
95+
96+ class TestImageParameterization (BaseTest ):
97+ def test_subclass (self ) -> None :
98+ self .assertTrue (
99+ issubclass (images .ImageParameterization , images .InputParameterization )
100+ )
101+
102+
103+ class TestAugmentedImageParameterization (BaseTest ):
104+ def test_subclass (self ) -> None :
105+ self .assertTrue (
106+ issubclass (
107+ images .AugmentedImageParameterization , images .ImageParameterization
108+ )
109+ )
110+
111+
91112class TestFFTImage (BaseTest ):
113+ def test_subclass (self ) -> None :
114+ self .assertTrue (issubclass (images .FFTImage , images .ImageParameterization ))
115+
92116 def test_pytorch_fftfreq (self ) -> None :
93117 image = images .FFTImage ((1 , 1 ))
94118 _ , _ , fftfreq = image .get_fft_funcs ()
@@ -219,6 +243,9 @@ def test_fftimage_forward_init_batch(self) -> None:
219243
220244
221245class TestPixelImage (BaseTest ):
246+ def test_subclass (self ) -> None :
247+ self .assertTrue (issubclass (images .PixelImage , images .ImageParameterization ))
248+
222249 def test_pixelimage_random (self ) -> None :
223250 if torch .__version__ <= "1.2.0" :
224251 raise unittest .SkipTest (
@@ -251,17 +278,6 @@ def test_pixelimage_init(self) -> None:
251278 self .assertEqual (image_param .image .size (3 ), size [1 ])
252279 assertTensorAlmostEqual (self , image_param .image , init_tensor , 0 )
253280
254- def test_pixelimage_init_error (self ) -> None :
255- if torch .__version__ <= "1.2.0" :
256- raise unittest .SkipTest (
257- "Skipping PixelImage init due to insufficient Torch version."
258- )
259- size = (224 , 224 )
260- channels = 2
261- init_tensor = torch .randn (channels , * size )
262- with self .assertRaises (AssertionError ):
263- images .PixelImage (size = size , channels = channels , init = init_tensor )
264-
265281 def test_pixelimage_random_forward (self ) -> None :
266282 if torch .__version__ <= "1.2.0" :
267283 raise unittest .SkipTest (
@@ -298,6 +314,9 @@ def test_pixelimage_init_forward(self) -> None:
298314
299315
300316class TestLaplacianImage (BaseTest ):
317+ def test_subclass (self ) -> None :
318+ self .assertTrue (issubclass (images .LaplacianImage , images .ImageParameterization ))
319+
301320 def test_laplacianimage_random_forward (self ) -> None :
302321 if torch .__version__ <= "1.2.0" :
303322 raise unittest .SkipTest (
@@ -326,6 +345,13 @@ def test_laplacianimage_init(self) -> None:
326345
327346
328347class TestSimpleTensorParameterization (BaseTest ):
348+ def test_subclass (self ) -> None :
349+ self .assertTrue (
350+ issubclass (
351+ images .SimpleTensorParameterization , images .ImageParameterization
352+ )
353+ )
354+
329355 def test_simple_tensor_parameterization_no_grad (self ) -> None :
330356 test_input = torch .randn (1 , 3 , 4 , 4 )
331357 image_param = images .SimpleTensorParameterization (test_input )
@@ -337,6 +363,11 @@ def test_simple_tensor_parameterization_no_grad(self) -> None:
337363 self .assertFalse (image_param .tensor .requires_grad )
338364
339365 def test_simple_tensor_parameterization_jit_module_no_grad (self ) -> None :
366+ if torch .__version__ <= "1.8.0" :
367+ raise unittest .SkipTest (
368+ "Skipping SimpleTensorParameterization JIT module test due to"
369+ + " insufficient Torch version."
370+ )
340371 test_input = torch .randn (1 , 3 , 4 , 4 )
341372 image_param = images .SimpleTensorParameterization (test_input )
342373 jit_image_param = torch .jit .script (image_param )
@@ -356,6 +387,11 @@ def test_simple_tensor_parameterization_with_grad(self) -> None:
356387 self .assertTrue (image_param .tensor .requires_grad )
357388
358389 def test_simple_tensor_parameterization_jit_module_with_grad (self ) -> None :
390+ if torch .__version__ <= "1.8.0" :
391+ raise unittest .SkipTest (
392+ "Skipping SimpleTensorParameterization JIT module test due to"
393+ + " insufficient Torch version."
394+ )
359395 test_input = torch .nn .Parameter (torch .randn (1 , 3 , 4 , 4 ))
360396 image_param = images .SimpleTensorParameterization (test_input )
361397 jit_image_param = torch .jit .script (image_param )
@@ -383,6 +419,11 @@ def test_simple_tensor_parameterization_cuda(self) -> None:
383419
384420
385421class TestSharedImage (BaseTest ):
422+ def test_subclass (self ) -> None :
423+ self .assertTrue (
424+ issubclass (images .SharedImage , images .AugmentedImageParameterization )
425+ )
426+
386427 def test_sharedimage_get_offset_single_number (self ) -> None :
387428 if torch .__version__ <= "1.2.0" :
388429 raise unittest .SkipTest (
@@ -726,6 +767,11 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None:
726767
727768
728769class TestStackImage (BaseTest ):
770+ def test_subclass (self ) -> None :
771+ self .assertTrue (
772+ issubclass (images .StackImage , images .AugmentedImageParameterization )
773+ )
774+
729775 def test_stackimage_init (self ) -> None :
730776 if torch .__version__ <= "1.2.0" :
731777 raise unittest .SkipTest (
@@ -940,6 +986,9 @@ def test_stackimage_forward_multi_device_cpu_gpu(self) -> None:
940986
941987
942988class TestNaturalImage (BaseTest ):
989+ def test_subclass (self ) -> None :
990+ self .assertTrue (issubclass (images .NaturalImage , images .ImageParameterization ))
991+
943992 def test_natural_image_0 (self ) -> None :
944993 if torch .__version__ <= "1.2.0" :
945994 raise unittest .SkipTest (
0 commit comments