Skip to content

Commit f8da94d

Browse files
authored
Replace some image tests with NumPy comparison tests
1 parent 9a79c78 commit f8da94d

File tree

2 files changed

+227
-11
lines changed

2 files changed

+227
-11
lines changed

tests/optim/helpers/numpy_image.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import numpy as np
2+
import torch
3+
4+
5+
def setup_batch(x: np.ndarray, batch: int = 1, dim: int = 3) -> np.ndarray:
6+
assert batch > 0
7+
x = x[None, :] if x.ndim == dim and batch == 1 else x
8+
x = (
9+
np.stack([np.copy(x) for b in range(batch)])
10+
if x.ndim == dim and batch > 1
11+
else x
12+
)
13+
return x
14+
15+
16+
class FFTImage(object):
17+
"""Parameterize an image using inverse real 2D FFT"""
18+
19+
def __init__(
20+
self,
21+
size=None,
22+
channels: int = 3,
23+
batch: int = 1,
24+
init=None,
25+
) -> None:
26+
super().__init__()
27+
if init is None:
28+
assert len(size) == 2
29+
self.size = size
30+
else:
31+
assert init.ndim == 3 or init.ndim == 4
32+
self.size = (
33+
(init.shape[1], init.shape[2])
34+
if init.ndim == 3
35+
else (init.shape[2], init.shape[3])
36+
)
37+
38+
frequencies = FFTImage.rfft2d_freqs(*self.size)
39+
scale = 1.0 / np.maximum(
40+
frequencies,
41+
np.full_like(frequencies, 1.0 / (max(self.size[0], self.size[1]))),
42+
)
43+
scale = scale * ((self.size[0] * self.size[1]) ** (1 / 2))
44+
spectrum_scale = scale[None, :, :, None]
45+
self.spectrum_scale = spectrum_scale
46+
47+
if init is None:
48+
coeffs_shape = (channels, self.size[0], self.size[1] // 2 + 1, 2)
49+
random_coeffs = np.random.randn(
50+
*coeffs_shape
51+
) # names=["C", "H_f", "W_f", "complex"]
52+
fourier_coeffs = random_coeffs / 50
53+
else:
54+
fourier_coeffs = (
55+
torch.rfft(torch.from_numpy(init), signal_ndim=2).numpy()
56+
/ spectrum_scale
57+
)
58+
fourier_coeffs = fourier_coeffs / spectrum_scale
59+
60+
fourier_coeffs = setup_batch(fourier_coeffs, batch, 4)
61+
self.fourier_coeffs = fourier_coeffs
62+
63+
@staticmethod
64+
def rfft2d_freqs(height: int, width: int) -> np.ndarray:
65+
"""Computes 2D spectrum frequencies."""
66+
fy = np.fft.fftfreq(height)[:, None]
67+
# on odd input dimensions we need to keep one additional frequency
68+
wadd = 2 if width % 2 == 1 else 1
69+
fx = np.fft.fftfreq(width)[: width // 2 + wadd]
70+
return np.sqrt((fx * fx) + (fy * fy))
71+
72+
def set_image(self, correlated_image: np.ndarray) -> None:
73+
coeffs = torch.rfft(torch.from_numpy(correlated_image), signal_ndim=2).numpy()
74+
self.fourier_coeffs = coeffs / self.spectrum_scale
75+
76+
def forward(self) -> np.ndarray:
77+
h, w = self.size
78+
scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
79+
output = torch.irfft(torch.from_numpy(scaled_spectrum), signal_ndim=2)[
80+
:, :, :h, :w
81+
]
82+
return output.numpy()

tests/optim/param/test_images.py

Lines changed: 145 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,163 @@
11
#!/usr/bin/env python3
22
import unittest
33

4+
import numpy as np
45
import torch
56

67
from captum.optim._param.image import images
7-
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
8+
from tests.helpers.basic import (
9+
BaseTest,
10+
assertArraysAlmostEqual,
11+
assertTensorAlmostEqual,
12+
)
13+
from tests.optim.helpers import numpy_image
14+
15+
16+
class TestSetupBatch(BaseTest):
17+
def test_setup_batch_chw(self) -> None:
18+
init = torch.randn(3, 4, 4)
19+
20+
batch_test = images.ImageParameterization()
21+
tensor_wbatch = batch_test.setup_batch(init)
22+
array_wbatch = numpy_image.setup_batch(init.numpy())
23+
24+
assertArraysAlmostEqual(tensor_wbatch.numpy(), array_wbatch)
25+
26+
def test_setup_batch_chwr(self) -> None:
27+
init = torch.randn(3, 4, 4, 2)
28+
29+
batch_test = images.ImageParameterization()
30+
tensor_wbatch = batch_test.setup_batch(init, dim=4)
31+
array_wbatch = numpy_image.setup_batch(init.numpy(), dim=4)
32+
33+
assertArraysAlmostEqual(tensor_wbatch.numpy(), array_wbatch)
34+
35+
def test_setup_batch_init(self) -> None:
36+
init = torch.randn(5, 3, 4, 4)
37+
38+
batch_test = images.ImageParameterization()
39+
tensor_wbatch = batch_test.setup_batch(init, dim=3)
40+
array_wbatch = numpy_image.setup_batch(init.numpy(), dim=3)
41+
42+
assertArraysAlmostEqual(tensor_wbatch.numpy(), array_wbatch)
843

944

1045
class TestFFTImage(BaseTest):
1146
def test_pytorch_fftfreq(self) -> None:
12-
assertTensorAlmostEqual(
13-
self,
14-
images.FFTImage.pytorch_fftfreq(4, 4),
15-
torch.tensor([0.0000, 0.0625, -0.1250, -0.0625]),
16-
0,
47+
assertArraysAlmostEqual(
48+
images.FFTImage.pytorch_fftfreq(4, 4).numpy(), np.fft.fftfreq(4, 4)
1749
)
1850

1951
def test_rfft2d_freqs(self) -> None:
20-
assertTensorAlmostEqual(
21-
self,
22-
images.FFTImage.rfft2d_freqs(height=2, width=3),
23-
torch.tensor([[0.0000, 0.3333, 0.3333], [0.5000, 0.6009, 0.6009]]),
24-
delta=0.0002,
52+
height = 2
53+
width = 3
54+
assertArraysAlmostEqual(
55+
images.FFTImage.rfft2d_freqs(height, width).numpy(),
56+
numpy_image.FFTImage.rfft2d_freqs(height, width),
2557
)
2658

59+
def test_fftimage_forward_randn_init(self) -> None:
60+
if torch.__version__ == "1.2.0":
61+
raise unittest.SkipTest(
62+
"Skipping FFTImage test due to insufficient Torch version."
63+
)
64+
size = (224, 224)
65+
66+
fftimage = images.FFTImage(size=size)
67+
fftimage_np = numpy_image.FFTImage(size=size)
68+
69+
fftimage_tensor = fftimage.forward()
70+
fftimage_array = fftimage_np.forward()
71+
72+
self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape)
73+
74+
def test_fftimage_forward_init_randn_batch(self) -> None:
75+
if torch.__version__ == "1.2.0":
76+
raise unittest.SkipTest(
77+
"Skipping FFTImage test due to insufficient Torch version."
78+
)
79+
size = (224, 224)
80+
batch = 5
81+
82+
fftimage = images.FFTImage(size=size, batch=batch)
83+
fftimage_np = numpy_image.FFTImage(size=size, batch=batch)
84+
85+
fftimage_tensor = fftimage.forward()
86+
fftimage_array = fftimage_np.forward()
87+
88+
self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape)
89+
90+
def test_fftimage_forward_init_randn_channels(self) -> None:
91+
if torch.__version__ == "1.2.0":
92+
raise unittest.SkipTest(
93+
"Skipping FFTImage test due to insufficient Torch version."
94+
)
95+
size = (224, 224)
96+
channels = 4
97+
98+
fftimage = images.FFTImage(size=size, channels=channels)
99+
fftimage_np = numpy_image.FFTImage(size=size, channels=channels)
100+
101+
fftimage_tensor = fftimage.forward()
102+
fftimage_array = fftimage_np.forward()
103+
104+
self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape)
105+
106+
def test_fftimage_forward_init_chw(self) -> None:
107+
if torch.__version__ == "1.2.0":
108+
raise unittest.SkipTest(
109+
"Skipping FFTImage test due to insufficient Torch version."
110+
)
111+
size = (224, 224)
112+
init_tensor = torch.randn(3, 224, 224)
113+
init_array = init_tensor.numpy()
114+
115+
fftimage = images.FFTImage(size=size, init=init_tensor)
116+
fftimage_np = numpy_image.FFTImage(size=size, init=init_array)
117+
118+
fftimage_tensor = fftimage.forward()
119+
fftimage_array = fftimage_np.forward()
120+
121+
self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape)
122+
assertArraysAlmostEqual(fftimage_tensor.detach().numpy(), fftimage_array)
123+
124+
def test_fftimage_forward_init_bchw(self) -> None:
125+
if torch.__version__ == "1.2.0":
126+
raise unittest.SkipTest(
127+
"Skipping FFTImage test due to insufficient Torch version."
128+
)
129+
size = (224, 224)
130+
init_tensor = torch.randn(1, 3, 224, 224)
131+
init_array = init_tensor.numpy()
132+
133+
fftimage = images.FFTImage(size=size, init=init_tensor)
134+
fftimage_np = numpy_image.FFTImage(size=size, init=init_array)
135+
136+
fftimage_tensor = fftimage.forward()
137+
fftimage_array = fftimage_np.forward()
138+
139+
self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape)
140+
assertArraysAlmostEqual(fftimage_tensor.detach().numpy(), fftimage_array)
141+
142+
def test_fftimage_forward_init_batch(self) -> None:
143+
if torch.__version__ == "1.2.0":
144+
raise unittest.SkipTest(
145+
"Skipping FFTImage test due to insufficient Torch version."
146+
)
147+
size = (224, 224)
148+
batch = 5
149+
init_tensor = torch.randn(1, 3, 224, 224)
150+
init_array = init_tensor.numpy()
151+
152+
fftimage = images.FFTImage(size=size, batch=batch, init=init_tensor)
153+
fftimage_np = numpy_image.FFTImage(size=size, batch=batch, init=init_array)
154+
155+
fftimage_tensor = fftimage.forward()
156+
fftimage_array = fftimage_np.forward()
157+
158+
self.assertEqual(fftimage_tensor.detach().numpy().shape, fftimage_array.shape)
159+
assertArraysAlmostEqual(fftimage_tensor.detach().numpy(), fftimage_array)
160+
27161

28162
class TestPixelImage(BaseTest):
29163
def test_pixelimage_random(self) -> None:

0 commit comments

Comments
 (0)