Skip to content

Commit 5c92e09

Browse files
authored
Optim-wip: Add descriptions and argument documentation to losses (#831)
* Add explanations to losses * Add argument documentation for losses * Lint fix
1 parent ba84783 commit 5c92e09

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

captum/optim/_core/loss.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ def wrapper(*args, **kwargs) -> object:
193193
class LayerActivation(BaseLoss):
194194
"""
195195
Maximize activations at the target layer.
196+
This is the most basic loss available and it simply returns the activations in
197+
their original form.
198+
199+
Args:
200+
target (nn.Module): The layer to optimize for.
201+
batch_index (int, optional): The index of the image to optimize if we
202+
are optimizing a batch of images. If unspecified, defaults to all images
203+
in the batch.
196204
"""
197205

198206
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -205,6 +213,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
205213
class ChannelActivation(BaseLoss):
206214
"""
207215
Maximize activations at the target layer and target channel.
216+
This loss maximizes the activations of a target channel in a specified target
217+
layer, and can be useful to determine what features the channel is excited by.
218+
219+
Args:
220+
target (nn.Module): The layer containing the channel to optimize for.
221+
channel_index (int): The index of the channel to optimize for.
222+
batch_index (int, optional): The index of the image to optimize if we
223+
are optimizing a batch of images. If unspecified, defaults to all images
224+
in the batch.
208225
"""
209226

210227
def __init__(
@@ -228,6 +245,26 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
228245

229246
@loss_wrapper
230247
class NeuronActivation(BaseLoss):
248+
"""
249+
This loss maximizes the activations of a target neuron in the specified channel
250+
from the specified layer. This loss is useful for determining the type of features
251+
that excite a neuron, and thus is often used for circuits and neuron related
252+
research.
253+
254+
Args:
255+
target (nn.Module): The layer containing the channel to optimize for.
256+
channel_index (int): The index of the channel to optimize for.
257+
x (int, optional): The x coordinate of the neuron to optimize for. If
258+
unspecified, defaults to center, or one unit left of center for even
259+
lengths.
260+
y (int, optional): The y coordinate of the neuron to optimize for. If
261+
unspecified, defaults to center, or one unit above center for even
262+
heights.
263+
batch_index (int, optional): The index of the image to optimize if we
264+
are optimizing a batch of images. If unspecified, defaults to all images
265+
in the batch.
266+
"""
267+
231268
def __init__(
232269
self,
233270
target: nn.Module,
@@ -262,6 +299,16 @@ class DeepDream(BaseLoss):
262299
"""
263300
Maximize 'interestingness' at the target layer.
264301
Mordvintsev et al., 2015.
302+
https://github.com/google/deepdream
303+
This loss returns the squared layer activations. When combined with a negative
304+
mean loss summarization, this loss will create hallucinogenic visuals commonly
305+
referred to as 'Deep Dream'.
306+
307+
Args:
308+
target (nn.Module): The layer to optimize for.
309+
batch_index (int, optional): The index of the image to optimize if we
310+
are optimizing a batch of images. If unspecified, defaults to all images
311+
in the batch.
265312
"""
266313

267314
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -276,6 +323,15 @@ class TotalVariation(BaseLoss):
276323
Total variation denoising penalty for activations.
277324
See Mahendran, V. 2014. Understanding Deep Image Representations by Inverting Them.
278325
https://arxiv.org/abs/1412.0035
326+
This loss attempts to smooth / denoise the target by performing total variation
327+
denoising. The target is most often the image that’s being optimized. This loss is
328+
often used to remove unwanted visual artifacts.
329+
330+
Args:
331+
target (nn.Module): The layer to optimize for.
332+
batch_index (int, optional): The index of the image to optimize if we
333+
are optimizing a batch of images. If unspecified, defaults to all images
334+
in the batch.
279335
"""
280336

281337
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -290,6 +346,14 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
290346
class L1(BaseLoss):
291347
"""
292348
L1 norm of the target layer, generally used as a penalty.
349+
350+
Args:
351+
target (nn.Module): The layer to optimize for.
352+
constant (float): Constant threshold to deduct from the activations.
353+
Defaults to 0.
354+
batch_index (int, optional): The index of the image to optimize if we
355+
are optimizing a batch of images. If unspecified, defaults to all images
356+
in the batch.
293357
"""
294358

295359
def __init__(
@@ -311,6 +375,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
311375
class L2(BaseLoss):
312376
"""
313377
L2 norm of the target layer, generally used as a penalty.
378+
379+
Args:
380+
target (nn.Module): The layer to optimize for.
381+
constant (float): Constant threshold to deduct from the activations.
382+
Defaults to 0.
383+
epsilon (float): Small value to add to L2 prior to sqrt. Defaults to 1e-6.
384+
batch_index (int, optional): The index of the image to optimize if we
385+
are optimizing a batch of images. If unspecified, defaults to all images
386+
in the batch.
314387
"""
315388

316389
def __init__(
@@ -338,6 +411,14 @@ class Diversity(BaseLoss):
338411
Use a cosine similarity penalty to extract features from a polysemantic neuron.
339412
Olah, Mordvintsev & Schubert, 2017.
340413
https://distill.pub/2017/feature-visualization/#diversity
414+
This loss helps break up polysemantic layers, channels, and neurons by encouraging
415+
diversity across the different batches. This loss is to be used along with a main
416+
loss.
417+
418+
Args:
419+
target (nn.Module): The layer to optimize for.
420+
batch_index (int, optional): Unused here since we are optimizing for diversity
421+
across the batch.
341422
"""
342423

343424
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
@@ -363,6 +444,16 @@ class ActivationInterpolation(BaseLoss):
363444
Interpolate between two different layers & channels.
364445
Olah, Mordvintsev & Schubert, 2017.
365446
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
447+
This loss helps to interpolate or mix visualizations from two activations (layer or
448+
channel) by interpolating a linear sum between the two activations.
449+
450+
Args:
451+
target1 (nn.Module): The first layer to optimize for.
452+
channel_index1 (int): Index of channel in first layer to optimize. Defaults to
453+
all channels.
454+
target2 (nn.Module): The second layer to optimize for.
455+
channel_index2 (int): Index of channel in second layer to optimize. Defaults to
456+
all channels.
366457
"""
367458

368459
def __init__(
@@ -414,6 +505,14 @@ class Alignment(BaseLoss):
414505
similarity between them.
415506
Olah, Mordvintsev & Schubert, 2017.
416507
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
508+
When interpolating between activations, it may be desirable to keep image landmarks
509+
in the same position for visual comparison. This loss helps to minimize L2 distance
510+
between neighbouring images.
511+
512+
Args:
513+
target (nn.Module): The layer to optimize for.
514+
decay_ratio (float): How much to decay penalty as images move apart in batch.
515+
Defaults to 2.
417516
"""
418517

419518
def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
@@ -442,6 +541,18 @@ class Direction(BaseLoss):
442541
Visualize a general direction vector.
443542
Carter, et al., "Activation Atlas", Distill, 2019.
444543
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
544+
This loss helps to visualize a specific vector direction in a layer, by maximizing
545+
the alignment between the input vector and the layer’s activation vector. The
546+
dimensionality of the vector should correspond to the number of channels in the
547+
layer.
548+
549+
Args:
550+
target (nn.Module): The layer to optimize for.
551+
vec (torch.Tensor): Vector representing direction to align to.
552+
cossim_pow (float, optional): The desired cosine similarity power to use.
553+
batch_index (int, optional): The index of the image to optimize if we
554+
are optimizing a batch of images. If unspecified, defaults to all images
555+
in the batch.
445556
"""
446557

447558
def __init__(
@@ -468,6 +579,23 @@ class NeuronDirection(BaseLoss):
468579
Visualize a single (x, y) position for a direction vector.
469580
Carter, et al., "Activation Atlas", Distill, 2019.
470581
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
582+
Extends Direction loss by focusing on visualizing a single neuron within the
583+
kernel.
584+
585+
Args:
586+
target (nn.Module): The layer to optimize for.
587+
vec (torch.Tensor): Vector representing direction to align to.
588+
x (int, optional): The x coordinate of the neuron to optimize for. If
589+
unspecified, defaults to center, or one unit left of center for even
590+
lengths.
591+
y (int, optional): The y coordinate of the neuron to optimize for. If
592+
unspecified, defaults to center, or one unit up of center for even
593+
heights.
594+
channel_index (int): The index of the channel to optimize for.
595+
cossim_pow (float, optional): The desired cosine similarity power to use.
596+
batch_index (int, optional): The index of the image to optimize if we
597+
are optimizing a batch of images. If unspecified, defaults to all images
598+
in the batch.
471599
"""
472600

473601
def __init__(
@@ -597,6 +725,15 @@ class TensorDirection(BaseLoss):
597725
Visualize a tensor direction vector.
598726
Carter, et al., "Activation Atlas", Distill, 2019.
599727
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
728+
Extends Direction loss by allowing batch-wise direction visualization.
729+
730+
Args:
731+
target (nn.Module): The layer to optimize for.
732+
vec (torch.Tensor): Vector representing direction to align to.
733+
cossim_pow (float, optional): The desired cosine similarity power to use.
734+
batch_index (int, optional): The index of the image to optimize if we
735+
are optimizing a batch of images. If unspecified, defaults to all images
736+
in the batch.
600737
"""
601738

602739
def __init__(
@@ -635,6 +772,23 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
635772
class ActivationWeights(BaseLoss):
636773
"""
637774
Apply weights to channels, neurons, or spots in the target.
775+
This loss weighs specific channels or neurons in a given layer, via a weight
776+
vector.
777+
778+
Args:
779+
target (nn.Module): The layer to optimize for.
780+
weights (torch.Tensor): Weights to apply to targets.
781+
neuron (bool): Whether target is a neuron. Defaults to False.
782+
x (int, optional): The x coordinate of the neuron to optimize for. If
783+
unspecified, defaults to center, or one unit left of center for even
784+
lengths.
785+
y (int, optional): The y coordinate of the neuron to optimize for. If
786+
unspecified, defaults to center, or one unit up of center for even
787+
heights.
788+
wx (int, optional): Length of neurons to apply the weights to, along the
789+
x-axis.
790+
wy (int, optional): Length of neurons to apply the weights to, along the
791+
y-axis.
638792
"""
639793

640794
def __init__(

0 commit comments

Comments
 (0)