Skip to content

Commit 3cf1675

Browse files
authored
Merge pull request #527 from ProGamerGov/optim-wip
Optim wip - Move & restructure loss objectives
2 parents 2d14e55 + 95dbca3 commit 3cf1675

File tree

10 files changed

+2086
-459
lines changed

10 files changed

+2086
-459
lines changed

captum/optim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""optim submodule."""
22

3+
from captum.optim._core import loss # noqa: F401
34
from captum.optim._core import objectives # noqa: F401
45
from captum.optim._core.objectives import InputOptimization # noqa: F401
56
from captum.optim._param.image import images # noqa: F401

captum/optim/_core/loss.py

Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
from captum.optim._utils.images import get_neuron_pos
8+
from captum.optim._utils.typing import ModuleOutputMapping
9+
10+
11+
class Loss(ABC):
    """
    Abstract base class for loss objectives.

    A loss is bound to a target module; concrete subclasses implement
    ``__call__`` to compute a value from a mapping indexed by module.
    """

    def __init__(self, target: nn.Module) -> None:
        super().__init__()
        self.target = target

    @abstractmethod
    def __call__(self, targets_to_values: ModuleOutputMapping):
        """Compute the loss from ``targets_to_values``; must be overridden."""
        ...
23+
24+
25+
class LayerActivation(Loss):
    """
    Maximize activations at the target layer.

    Calling the loss returns the entry stored for the target layer in the
    given mapping, unchanged.
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        layer_output = targets_to_values[self.target]
        return layer_output
32+
33+
34+
class ChannelActivation(Loss):
    """
    Maximize activations at the target layer and target channel.

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        channel_index (int): Index along dimension 1 of the layer output.
    """

    def __init__(self, target: nn.Module, channel_index: int) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it) so the base class owns the ``target`` attribute.
        super().__init__(target)
        self.channel_index = channel_index

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the slice of the target's activations at ``channel_index``.

        Raises:
            AssertionError: If no activations are available for the target or
                ``channel_index`` is not smaller than the size of dimension 1.
        """
        activations = targets_to_values[self.target]
        assert activations is not None
        # ensure channel_index is valid
        assert self.channel_index < activations.shape[1]
        # Channels are read from dim 1 (NCHW-style). The rank is deliberately
        # not asserted so that non-4D outputs, e.g. from Linear layers, work.
        return activations[:, self.channel_index, ...]
53+
54+
55+
class NeuronActivation(Loss):
    """
    Maximize the activation of a single spatial position in a target channel.

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        channel_index (int): Index along dimension 1 of the layer output.
        x (int, optional): Position forwarded to ``get_neuron_pos``; how
            ``None`` is resolved is defined by that helper.
        y (int, optional): Position forwarded to ``get_neuron_pos``; how
            ``None`` is resolved is defined by that helper.
    """

    def __init__(
        self,
        target: nn.Module,
        channel_index: int,
        x: Optional[int] = None,
        y: Optional[int] = None,
    ) -> None:
        super().__init__(target)
        self.channel_index = channel_index
        self.x = x
        self.y = y

        # Validate early when the target advertises its channel count. Modules
        # without ``out_channels`` are validated at call time instead of
        # failing here with an AttributeError.
        out_channels = getattr(self.target, "out_channels", None)
        if out_channels is not None:
            assert self.channel_index < out_channels

    # NOTE: this method was previously misspelled ``_call__``, which left the
    # abstract ``Loss.__call__`` unimplemented and made the class impossible
    # to instantiate.
    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return ``activations[:, channel_index, x, y]`` for the target layer.

        Raises:
            AssertionError: If no activations are available, the activations
                are not 4-dimensional, or ``channel_index`` is out of range.
        """
        activations = targets_to_values[self.target]
        assert activations is not None
        assert len(activations.shape) == 4  # assume NCHW
        assert self.channel_index < activations.shape[1]
        _x, _y = get_neuron_pos(
            activations.size(2), activations.size(3), self.x, self.y
        )

        return activations[:, self.channel_index, _x, _y]
81+
82+
83+
class DeepDream(Loss):
    """
    Maximize 'interestingness' at the target layer.
    Mordvintsev et al., 2015.

    The returned value is the element-wise square of the target layer's
    activations.
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        layer_output = targets_to_values[self.target]
        squared = layer_output ** 2
        return squared
92+
93+
94+
class TotalVariation(Loss):
    """
    Total variation denoising penalty for activations.
    See Mahendran, V. 2014. Understanding Deep Image Representations by Inverting Them.
    https://arxiv.org/abs/1412.0035
    """

    # NOTE: this method was previously misspelled ``_call__``, which left the
    # abstract ``Loss.__call__`` unimplemented and made the class impossible
    # to instantiate.
    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the summed absolute differences between neighbouring elements
        along the last two dimensions of the target's activations.
        """
        activations = targets_to_values[self.target]
        # Differences between adjacent entries along the second-to-last dim.
        x_diff = activations[..., 1:, :] - activations[..., :-1, :]
        # Differences between adjacent entries along the last dim.
        y_diff = activations[..., :, 1:] - activations[..., :, :-1]
        return torch.sum(torch.abs(x_diff)) + torch.sum(torch.abs(y_diff))
106+
107+
108+
class L1(Loss):
    """
    L1 norm of the target layer, generally used as a penalty.

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        constant (float): Value subtracted from the activations before the
            absolute value is taken. Default: 0.0
    """

    def __init__(self, target: nn.Module, constant: float = 0.0) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it).
        super().__init__(target)
        self.constant = constant

    # NOTE: this method was previously misspelled ``_call__``, which left the
    # abstract ``Loss.__call__`` unimplemented and made the class impossible
    # to instantiate.
    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """Return ``sum(|activations - constant|)`` for the target layer."""
        activations = targets_to_values[self.target]
        return torch.abs(activations - self.constant).sum()
121+
122+
123+
class L2(Loss):
    """
    L2 norm of the target layer, generally used as a penalty.

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        constant (float): Value subtracted from the activations before
            squaring. Default: 0.0
        epsilon (float): Small value added under the square root.
            Default: 1e-6
    """

    def __init__(
        self, target: nn.Module, constant: float = 0.0, epsilon: float = 1e-6
    ) -> None:
        super().__init__(target)
        self.constant = constant
        self.epsilon = epsilon

    # NOTE: this method was previously misspelled ``_call__``, which left the
    # abstract ``Loss.__call__`` unimplemented and made the class impossible
    # to instantiate.
    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """Return ``sqrt(epsilon + sum((activations - constant) ** 2))``."""
        activations = targets_to_values[self.target]
        # The differences must be squared before summing: a plain sum is not a
        # norm and can be negative, which would make the square root NaN.
        squared_sum = ((activations - self.constant) ** 2).sum()
        return torch.sqrt(self.epsilon + squared_sum)
139+
140+
141+
class Diversity(Loss):
    """
    Use a cosine similarity penalty to extract features from a polysemantic neuron.
    Olah, Mordvintsev & Schubert, 2017.
    https://distill.pub/2017/feature-visualization/#diversity
    """

    # NOTE: this method was previously misspelled ``_call__``, which left the
    # abstract ``Loss.__call__`` unimplemented and made the class impossible
    # to instantiate.
    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the negated sum of cosine similarities over all ordered pairs
        of distinct batch elements, divided by the batch size.
        """
        activations = targets_to_values[self.target]
        batch_size = activations.size(0)
        # One flattened row per batch element; ``reshape`` also handles
        # non-contiguous activations, which ``view`` would reject.
        flattened = activations.reshape(batch_size, -1)

        # Start from a tensor so a batch of one still yields a tensor rather
        # than a Python number.
        total = torch.zeros((), dtype=flattened.dtype, device=flattened.device)
        for j in range(batch_size):
            for i in range(batch_size):
                if i == j:
                    continue
                total = total + torch.cosine_similarity(
                    flattened[j : j + 1], flattened[i : i + 1]
                ).sum()
        return -total / batch_size
166+
167+
168+
class ActivationInterpolation(Loss):
    """
    Interpolate between two different layers & channels.
    Olah, Mordvintsev & Schubert, 2017.
    https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons

    Args:
        target1 (nn.Module, optional): First layer to read activations from.
        channel_index1 (int): Channel of the first layer; a value of -1 keeps
            all channels. Default: -1
        target2 (nn.Module, optional): Second layer to read activations from.
        channel_index2 (int): Channel of the second layer; a value of -1 keeps
            all channels. Default: -1
    """

    def __init__(
        self,
        target1: Optional[nn.Module] = None,
        channel_index1: int = -1,
        target2: Optional[nn.Module] = None,
        channel_index2: int = -1,
    ) -> None:
        # ``target`` on the base class mirrors the first target; the original
        # code never set it at all because ``super(Loss, self)`` skipped
        # Loss.__init__.
        super().__init__(target1)
        self.target_one = target1
        self.channel_index_one = channel_index1
        self.target_two = target2
        self.channel_index_two = channel_index2

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return a one-element tensor accumulating, per batch element ``n``,
        ``mean((1 - w[n]) * a1[n]) + mean(w[n] * a2[n])`` where ``w`` runs
        linearly from 0 (first element) to 1 (last element).

        Raises:
            AssertionError: If activations are missing, a channel index is out
                of range, or the two batch sizes differ.
        """
        activations_one = targets_to_values[self.target_one]
        activations_two = targets_to_values[self.target_two]

        assert activations_one is not None and activations_two is not None
        # ensure channel indices are valid
        assert (
            self.channel_index_one < activations_one.shape[1]
            and self.channel_index_two < activations_two.shape[1]
        )
        assert activations_one.size(0) == activations_two.size(0)

        if self.channel_index_one > -1:
            activations_one = activations_one[:, self.channel_index_one]
        if self.channel_index_two > -1:
            activations_two = activations_two[:, self.channel_index_two]
        B = activations_one.size(0)

        # ``linspace`` gives the same 0..1 ramp as ``arange(B) / (B - 1)`` but
        # avoids the division by zero for B == 1 and never falls back to
        # integer division.
        batch_weights = torch.linspace(0.0, 1.0, B, device=activations_one.device)
        sum_tensor = torch.zeros(1, device=activations_one.device)
        for n in range(B):
            weight = batch_weights[n]
            sum_tensor = sum_tensor + ((1 - weight) * activations_one[n]).mean()
            sum_tensor = sum_tensor + (weight * activations_two[n]).mean()
        return sum_tensor
214+
215+
216+
class Alignment(Loss):
    """
    Penalize the L2 distance between tensors in the batch to encourage visual
    similarity between them.
    Olah, Mordvintsev & Schubert, 2017.
    https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        decay_ratio (float): Base of the divisor applied per batch distance
            ``d``; each term is divided by ``decay_ratio ** d``. Default: 2.0
    """

    def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it).
        super().__init__(target)
        self.decay_ratio = decay_ratio

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return a one-element tensor summing the mean squared difference of
        batch elements that are 1 to 4 positions apart, each term divided by
        ``decay_ratio ** distance``.
        """
        activations = targets_to_values[self.target]
        B = activations.size(0)

        sum_tensor = torch.zeros(1, device=activations.device)
        for d in (1, 2, 3, 4):
            # Hoisted: the divisor only depends on the distance ``d``.
            decay = self.decay_ratio ** float(d)
            for i in range(B - d):
                activ_a, activ_b = activations[i], activations[i + d]
                sum_tensor = sum_tensor + ((activ_a - activ_b) ** 2).mean() / decay

        return sum_tensor
243+
244+
245+
class Direction(Loss):
    """
    Visualize a general direction vector.
    Carter, et al., "Activation Atlas", Distill, 2019.
    https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        vec (torch.Tensor): Direction vector; it is reshaped to
            ``(1, -1, 1, 1)`` so its elements lie along dimension 1.
    """

    def __init__(self, target: nn.Module, vec: torch.Tensor) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it).
        super().__init__(target)
        self.direction = vec.reshape((1, -1, 1, 1))

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the cosine similarity between the direction vector and the
        target's activations, computed along dimension 1.

        Raises:
            AssertionError: If dimension 1 of the activations does not match
                the length of the direction vector.
        """
        activations = targets_to_values[self.target]
        assert activations.size(1) == self.direction.size(1)
        return torch.cosine_similarity(self.direction, activations)
261+
262+
263+
class DirectionNeuron(Loss):
    """
    Visualize a single (x, y) position for a direction vector.
    Carter, et al., "Activation Atlas", Distill, 2019.
    https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        vec (torch.Tensor): Direction vector; it is reshaped to
            ``(1, -1, 1, 1)`` so its elements lie along dimension 1.
        channel_index (int): Index along dimension 1 of the layer output.
        x (int, optional): Position forwarded to ``get_neuron_pos``; how
            ``None`` is resolved is defined by that helper.
        y (int, optional): Position forwarded to ``get_neuron_pos``; how
            ``None`` is resolved is defined by that helper.
    """

    def __init__(
        self,
        target: nn.Module,
        vec: torch.Tensor,
        channel_index: int,
        x: Optional[int] = None,
        y: Optional[int] = None,
    ) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it).
        super().__init__(target)
        self.direction = vec.reshape((1, -1, 1, 1))
        self.channel_index = channel_index
        self.x = x
        self.y = y

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the cosine similarity between the direction vector and the
        selected activations.

        Raises:
            AssertionError: If the activations are not 4-dimensional.
        """
        activations = targets_to_values[self.target]

        assert activations.dim() == 4

        _x, _y = get_neuron_pos(
            activations.size(2), activations.size(3), self.x, self.y
        )
        # NOTE(review): selecting a single channel leaves one value per batch
        # element, which is then broadcast against the whole direction vector
        # along dim 1. Comparing the direction with all channels at (x, y) may
        # have been intended - confirm before relying on this objective.
        activations = activations[:, self.channel_index, _x, _y]
        return torch.cosine_similarity(self.direction, activations[None, None, None])
295+
296+
297+
class TensorDirection(Loss):
    """
    Visualize a tensor direction vector.
    Carter, et al., "Activation Atlas", Distill, 2019.
    https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        vec (torch.Tensor): Direction tensor; its dimensions 2 and 3 are read
            as the height and width of the window compared with the
            activations.
    """

    def __init__(self, target: nn.Module, vec: torch.Tensor) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it).
        super().__init__(target)
        self.direction = vec

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the cosine similarity between the direction tensor and a
        centered window of the target's activations of the same height and
        width.

        Raises:
            AssertionError: If the activations are not 4-dimensional.
        """
        activations = targets_to_values[self.target]

        assert activations.dim() == 4

        H_direction, W_direction = self.direction.size(2), self.direction.size(3)
        H_activ, W_activ = activations.size(2), activations.size(3)

        # Offsets that center the direction-sized window inside the activations.
        H = (H_activ - H_direction) // 2
        W = (W_activ - W_direction) // 2

        activations = activations[:, :, H : H + H_direction, W : W + W_direction]
        return torch.cosine_similarity(self.direction, activations)
322+
323+
324+
class ActivationWeights(Loss):
    """
    Apply weights to channels, neurons, or spots in the target.

    Args:
        target (nn.Module): Layer whose output is looked up in the mapping.
        weights (torch.Tensor, optional): Weights multiplied with the
            activations. Must be provided before the loss is called.
        neuron (bool): Force the neuron/spot code path even when neither
            ``x`` nor ``y`` is given. Default: False
        x (int, optional): Start of the slice along the last dimension in spot
            mode; otherwise forwarded to ``get_neuron_pos``.
        y (int, optional): Start of the slice along the second-to-last
            dimension in spot mode; otherwise forwarded to ``get_neuron_pos``.
        wx (int, optional): Spot extent along the last dimension.
        wy (int, optional): Spot extent along the second-to-last dimension.
    """

    def __init__(
        self,
        target: nn.Module,
        weights: Optional[torch.Tensor] = None,
        neuron: bool = False,
        x: Optional[int] = None,
        y: Optional[int] = None,
        wx: Optional[int] = None,
        wy: Optional[int] = None,
    ) -> None:
        # Initialize through Loss.__init__ (the previous ``super(Loss, self)``
        # call skipped it).
        super().__init__(target)
        self.x = x
        self.y = y
        self.wx = wx
        self.wy = wy
        self.weights = weights
        self.neuron = x is not None or y is not None or neuron

        # Either no spot extent is given at all, or a complete spot
        # (x, y, wx, wy) is specified.
        no_spot = wx is None and wy is None
        full_spot = (
            wx is not None and wy is not None and x is not None and y is not None
        )
        assert no_spot or full_spot

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        """
        Return the target's activations multiplied by the weights, restricted
        to a single position or a spot when configured that way.

        Raises:
            AssertionError: If a neuron or spot is requested and the
                activations are not 4-dimensional.
        """
        activations = targets_to_values[self.target]

        if not self.neuron:
            # Channel mode: one weight per entry of dim 1.
            return activations * self.weights.view(1, -1, 1, 1)

        assert activations.dim() == 4
        if self.wx is None and self.wy is None:
            # Single position; the weights are applied without reshaping.
            _x, _y = get_neuron_pos(
                activations.size(2), activations.size(3), self.x, self.y
            )
            return activations[..., _x, _y].squeeze() * self.weights

        # Spot mode: a (wy, wx) window starting at (y, x).
        spot = activations[..., self.y : self.y + self.wy, self.x : self.x + self.wx]
        return spot * self.weights.view(1, -1, 1, 1)

0 commit comments

Comments
 (0)