Skip to content

Commit ba84783

Browse files
authored
Optim-wip: Miscellaneous Changes & Fixes (#827)
* Miscellaneous Changes & Fixes
* Add missing docs
* `get_model_layers`, `collect_activations`, `Conv2dSame`, & `get_neuron_pos` were all missing documentation.
* Fix `utils/image` tests and add missing `_dot_cossim` tests
* Fixed `image_cov` and the dataset tests.
* Renamed `utils/image/dataset.py` to `utils/image/test_dataset.py` as the lack of a `test_` prefix was causing the tests not to be run.
* Renamed `utils/image/common.py` to `utils/image/test_common.py` as the lack of a `test_` prefix was causing the tests not to be run.
* Added missing `_dot_cossim` tests.
* Fix `nchannels_to_rgb` & `Direction` assert
* Moved the `hue_to_rgb` function outside of `nchannels_to_rgb` for JIT support.
* Fixed `nchannels_to_rgb` and `hue_to_rgb` functions.
* Fixed `Direction` loss objective assert.
* Fix `image_cov` & related tests
* Fix conflicts for common -> test_common.py
* Merge updates from optim-wip branch
* Fix test error
* Fix cossim test
1 parent cf67101 commit ba84783

File tree

11 files changed

+736
-334
lines changed

11 files changed

+736
-334
lines changed

captum/optim/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from captum.optim._utils import circuits, reducer # noqa: F401
99
from captum.optim._utils.image import atlas # noqa: F401
1010
from captum.optim._utils.image.common import ( # noqa: F401
11+
hue_to_rgb,
1112
nchannels_to_rgb,
1213
save_tensor_as_image,
1314
show,
@@ -25,6 +26,7 @@
2526
"models",
2627
"reducer",
2728
"atlas",
29+
"hue_to_rgb",
2830
"nchannels_to_rgb",
2931
"save_tensor_as_image",
3032
"show",

captum/optim/_core/loss.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,14 @@ def __init__(
452452
batch_index: Optional[int] = None,
453453
) -> None:
454454
BaseLoss.__init__(self, target, batch_index)
455-
self.direction = vec.reshape((1, -1, 1, 1))
455+
self.vec = vec.reshape((1, -1, 1, 1))
456456
self.cossim_pow = cossim_pow
457457

458458
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
459459
activations = targets_to_values[self.target]
460-
assert activations.size(1) == self.direction.size(1)
460+
assert activations.size(1) == self.vec.size(1)
461461
activations = activations[self.batch_index[0] : self.batch_index[1]]
462-
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
462+
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)
463463

464464

465465
@loss_wrapper
@@ -481,7 +481,7 @@ def __init__(
481481
batch_index: Optional[int] = None,
482482
) -> None:
483483
BaseLoss.__init__(self, target, batch_index)
484-
self.direction = vec.reshape((1, -1, 1, 1))
484+
self.vec = vec.reshape((1, -1, 1, 1))
485485
self.x = x
486486
self.y = y
487487
self.channel_index = channel_index
@@ -500,7 +500,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
500500
]
501501
if self.channel_index is not None:
502502
activations = activations[:, self.channel_index, ...][:, None, ...]
503-
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
503+
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)
504504

505505

506506
@loss_wrapper
@@ -607,16 +607,17 @@ def __init__(
607607
batch_index: Optional[int] = None,
608608
) -> None:
609609
BaseLoss.__init__(self, target, batch_index)
610-
self.direction = vec
610+
assert vec.dim() == 4
611+
self.vec = vec
611612
self.cossim_pow = cossim_pow
612613

613614
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
614615
activations = targets_to_values[self.target]
615616

616617
assert activations.dim() == 4
617618

618-
H_direction, W_direction = self.direction.size(2), self.direction.size(3)
619-
H_activ, W_activ = activations.size(2), activations.size(3)
619+
H_direction, W_direction = self.vec.shape[2:]
620+
H_activ, W_activ = activations.shape[2:]
620621

621622
H = (H_activ - H_direction) // 2
622623
W = (W_activ - W_direction) // 2
@@ -627,7 +628,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
627628
H : H + H_direction,
628629
W : W + W_direction,
629630
]
630-
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
631+
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)
631632

632633

633634
@loss_wrapper

captum/optim/_param/image/transforms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,16 @@ def klt_transform() -> torch.Tensor:
8989
**transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on
9090
the ImageNet dataset.
9191
"""
92+
# Handle older versions of PyTorch
93+
torch_norm = (
94+
torch.linalg.norm
95+
if version.parse(torch.__version__) >= version.parse("1.7.0")
96+
else torch.norm
97+
)
98+
9299
KLT = [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]]
93100
transform = torch.Tensor(KLT).float()
94-
transform = transform / torch.max(torch.norm(transform, dim=0))
101+
transform = transform / torch.max(torch_norm(transform, dim=0))
95102
return transform
96103

97104
@staticmethod

captum/optim/_utils/image/common.py

Lines changed: 83 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import torch
77
from captum.optim._utils.reducer import posneg
8+
from packaging import version
89

910
try:
1011
from PIL import Image
@@ -64,6 +65,21 @@ def save_tensor_as_image(x: torch.Tensor, filename: str, scale: float = 255.0) -
6465
def get_neuron_pos(
6566
H: int, W: int, x: Optional[int] = None, y: Optional[int] = None
6667
) -> Tuple[int, int]:
68+
"""
69+
Args:
70+
71+
H (int): The height.
72+
W (int): The width.
73+
x (int, optional): Optionally specify an exact x location of the neuron. If
74+
set to None, then the center x location will be used.
75+
Default: None
76+
y (int, optional): Optionally specify an exact y location of the neuron. If
77+
set to None, then the center y location will be used.
78+
Default: None
79+
80+
Returns:
81+
Tuple[_x, _y] (Tuple[int, int]): The x and y coordinates of the neuron.
82+
"""
6783
if x is None:
6884
_x = W // 2
6985
else:
@@ -109,66 +125,93 @@ def _dot_cossim(
109125
return dot * torch.clamp(torch.cosine_similarity(x, y, eps=eps), 0.1) ** cossim_pow
110126

111127

112-
@torch.jit.ignore
113-
def nchannels_to_rgb(x: torch.Tensor, warp: bool = True) -> torch.Tensor:
114-
"""
115-
Convert an NCHW image with n channels into a 3 channel RGB image.
128+
# Handle older versions of PyTorch
129+
# Defined outside of function in order to support JIT
130+
_torch_norm = (
131+
torch.linalg.norm
132+
if version.parse(torch.__version__) >= version.parse("1.7.0")
133+
else torch.norm
134+
)
116135

136+
137+
def hue_to_rgb(
138+
angle: float, device: torch.device = torch.device("cpu"), warp: bool = True
139+
) -> torch.Tensor:
140+
"""
141+
Create an RGB unit vector based on a hue of the input angle.
117142
Args:
118-
x (torch.Tensor): Image tensor to transform into RGB image.
119-
warp (bool, optional): Whether or not to make colors more distinguishable.
143+
angle (float): The hue angle to create an RGB color for.
144+
device (torch.device, optional): The device to create the angle color tensor
145+
on.
146+
Default: torch.device("cpu")
147+
warp (bool, optional): Whether or not to make colors more distinguishable.
120148
Default: True
121149
Returns:
122-
*tensor* RGB image
150+
color_vec (torch.Tensor): A unit-length RGB color vector.
123151
"""
124152

125-
def hue_to_rgb(angle: float) -> torch.Tensor:
126-
"""
127-
Create an RGB unit vector based on a hue of the input angle.
128-
"""
129-
130-
angle = angle - 360 * (angle // 360)
131-
colors = torch.tensor(
132-
[
133-
[1.0, 0.0, 0.0],
134-
[0.7071, 0.7071, 0.0],
135-
[0.0, 1.0, 0.0],
136-
[0.0, 0.7071, 0.7071],
137-
[0.0, 0.0, 1.0],
138-
[0.7071, 0.0, 0.7071],
139-
]
153+
angle = angle - 360 * (angle // 360)
154+
colors = torch.tensor(
155+
[
156+
[1.0, 0.0, 0.0],
157+
[0.7071, 0.7071, 0.0],
158+
[0.0, 1.0, 0.0],
159+
[0.0, 0.7071, 0.7071],
160+
[0.0, 0.0, 1.0],
161+
[0.7071, 0.0, 0.7071],
162+
],
163+
device=device,
164+
)
165+
166+
idx = math.floor(angle / 60)
167+
d = (angle - idx * 60) / 60
168+
169+
if warp:
170+
# Idea from: https://github.com/tensorflow/lucid/pull/193
171+
d = (
172+
math.sin(d * math.pi / 2)
173+
if idx % 2 == 0
174+
else 1 - math.sin((1 - d) * math.pi / 2)
140175
)
141176

142-
idx = math.floor(angle / 60)
143-
d = (angle - idx * 60) / 60
177+
vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6]
178+
return vec / _torch_norm(vec)
144179

145-
if warp:
146180

147-
def adj(x: float) -> float:
148-
return math.sin(x * math.pi / 2)
181+
def nchannels_to_rgb(
182+
x: torch.Tensor, warp: bool = True, eps: float = 1e-4
183+
) -> torch.Tensor:
184+
"""
185+
Convert an NCHW image with n channels into a 3 channel RGB image.
149186
150-
d = adj(d) if idx % 2 == 0 else 1 - adj(1 - d)
187+
Args:
151188
152-
vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6]
153-
return vec / torch.norm(vec)
189+
x (torch.Tensor): NCHW image tensor to transform into RGB image.
190+
warp (bool, optional): Whether or not to make colors more distinguishable.
191+
Default: True
192+
eps (float, optional): A small value added to the RGB norm to avoid division by zero.
193+
Default: 1e-4
194+
Returns:
195+
tensor (torch.Tensor): An NCHW RGB image tensor.
196+
"""
154197

155198
assert x.dim() == 4
156199

157200
if (x < 0).any():
158201
x = posneg(x.permute(0, 2, 3, 1), -1).permute(0, 3, 1, 2)
159202

160203
rgb = torch.zeros(1, 3, x.size(2), x.size(3), device=x.device)
161-
nc = x.size(1)
162-
for i in range(nc):
163-
rgb = rgb + x[:, i][:, None, :, :]
164-
rgb = rgb * hue_to_rgb(360 * i / nc).to(device=x.device)[None, :, None, None]
165-
166-
rgb = rgb + torch.ones(x.size(2), x.size(3))[None, None, :, :] * (
167-
torch.sum(x, 1)[:, None] - torch.max(x, 1)[0][:, None]
168-
)
169-
return (rgb / (1e-4 + torch.norm(rgb, dim=1, keepdim=True))) * torch.norm(
170-
x, dim=1, keepdim=True
204+
num_channels = x.size(1)
205+
for i in range(num_channels):
206+
rgb_angle = hue_to_rgb(360 * i / num_channels, device=x.device, warp=warp)
207+
rgb = rgb + (x[:, i][:, None, :, :] * rgb_angle[None, :, None, None])
208+
209+
rgb = rgb + (
210+
torch.ones(1, 1, x.size(2), x.size(3), device=x.device)
211+
* (torch.sum(x, 1) - torch.max(x, 1)[0])[:, None]
171212
)
213+
rgb = rgb / (eps + _torch_norm(rgb, dim=1, keepdim=True))
214+
return rgb * _torch_norm(x, dim=1, keepdim=True)
172215

173216

174217
def weights_to_heatmap_2d(

0 commit comments

Comments
 (0)