@@ -193,6 +193,14 @@ def wrapper(*args, **kwargs) -> object:
193193class LayerActivation (BaseLoss ):
194194 """
195195 Maximize activations at the target layer.
196+ This is the most basic loss available and it simply returns the activations in
197+ their original form.
198+
199+ Args:
200+ target (nn.Module): The layer to optimize for.
201+ batch_index (int, optional): The index of the image to optimize if we
202+ are optimizing a batch of images. If unspecified, defaults to all images
203+ in the batch.
196204 """
197205
198206 def __call__ (self , targets_to_values : ModuleOutputMapping ) -> torch .Tensor :
@@ -205,6 +213,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
205213class ChannelActivation (BaseLoss ):
206214 """
207215 Maximize activations at the target layer and target channel.
216+ This loss maximizes the activations of a target channel in a specified target
217+ layer, and can be useful to determine what features the channel is excited by.
218+
219+ Args:
220+ target (nn.Module): The layer containing the channel to optimize for.
221+ channel_index (int): The index of the channel to optimize for.
222+ batch_index (int, optional): The index of the image to optimize if we
223+ are optimizing a batch of images. If unspecified, defaults to all images
224+ in the batch.
208225 """
209226
210227 def __init__ (
@@ -228,6 +245,26 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
228245
229246@loss_wrapper
230247class NeuronActivation (BaseLoss ):
248+ """
249+ This loss maximizes the activations of a target neuron in the specified channel
250+ from the specified layer. This loss is useful for determining the type of features
251+ that excite a neuron, and thus is often used for circuits and neuron related
252+ research.
253+
254+ Args:
255+ target (nn.Module): The layer containing the channel to optimize for.
256+ channel_index (int): The index of the channel to optimize for.
257+ x (int, optional): The x coordinate of the neuron to optimize for. If
258+ unspecified, defaults to center, or one unit left of center for even
259+ lengths.
260+ y (int, optional): The y coordinate of the neuron to optimize for. If
261+ unspecified, defaults to center, or one unit up from center for even
262+ heights.
263+ batch_index (int, optional): The index of the image to optimize if we
264+ are optimizing a batch of images. If unspecified, defaults to all images
265+ in the batch.
266+ """
267+
231268 def __init__ (
232269 self ,
233270 target : nn .Module ,
@@ -262,6 +299,16 @@ class DeepDream(BaseLoss):
262299 """
263300 Maximize 'interestingness' at the target layer.
264301 Mordvintsev et al., 2015.
302+ https://github.com/google/deepdream
303+ This loss returns the squared layer activations. When combined with a negative
304+ mean loss summarization, this loss will create hallucinogenic visuals commonly
305+ referred to as 'Deep Dream'.
306+
307+ Args:
308+ target (nn.Module): The layer to optimize for.
309+ batch_index (int, optional): The index of the image to optimize if we
310+ are optimizing a batch of images. If unspecified, defaults to all images
311+ in the batch.
265312 """
266313
267314 def __call__ (self , targets_to_values : ModuleOutputMapping ) -> torch .Tensor :
@@ -276,6 +323,15 @@ class TotalVariation(BaseLoss):
276323 Total variation denoising penalty for activations.
277324 See Mahendran, V. 2014. Understanding Deep Image Representations by Inverting Them.
278325 https://arxiv.org/abs/1412.0035
326+ This loss attempts to smooth / denoise the target by performing total variance
327+ denoising. The target is most often the image that’s being optimized. This loss is
328+ often used to remove unwanted visual artifacts.
329+
330+ Args:
331+ target (nn.Module): The layer to optimize for.
332+ batch_index (int, optional): The index of the image to optimize if we
333+ are optimizing a batch of images. If unspecified, defaults to all images
334+ in the batch.
279335 """
280336
281337 def __call__ (self , targets_to_values : ModuleOutputMapping ) -> torch .Tensor :
@@ -290,6 +346,14 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
290346class L1 (BaseLoss ):
291347 """
292348 L1 norm of the target layer, generally used as a penalty.
349+
350+ Args:
351+ target (nn.Module): The layer to optimize for.
352+ constant (float): Constant threshold to deduct from the activations.
353+ Defaults to 0.
354+ batch_index (int, optional): The index of the image to optimize if we
355+ are optimizing a batch of images. If unspecified, defaults to all images
356+ in the batch.
293357 """
294358
295359 def __init__ (
@@ -311,6 +375,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
311375class L2 (BaseLoss ):
312376 """
313377 L2 norm of the target layer, generally used as a penalty.
378+
379+ Args:
380+ target (nn.Module): The layer to optimize for.
381+ constant (float): Constant threshold to deduct from the activations.
382+ Defaults to 0.
383+ epsilon (float): Small value to add to L2 prior to sqrt. Defaults to 1e-6.
384+ batch_index (int, optional): The index of the image to optimize if we
385+ are optimizing a batch of images. If unspecified, defaults to all images
386+ in the batch.
314387 """
315388
316389 def __init__ (
@@ -338,6 +411,14 @@ class Diversity(BaseLoss):
338411 Use a cosine similarity penalty to extract features from a polysemantic neuron.
339412 Olah, Mordvintsev & Schubert, 2017.
340413 https://distill.pub/2017/feature-visualization/#diversity
414+ This loss helps break up polysemantic layers, channels, and neurons by encouraging
415+ diversity across the different batches. This loss is to be used along with a main
416+ loss.
417+
418+ Args:
419+ target (nn.Module): The layer to optimize for.
420+ batch_index (int, optional): Unused here since we are optimizing for diversity
421+ across the batch.
341422 """
342423
343424 def __call__ (self , targets_to_values : ModuleOutputMapping ) -> torch .Tensor :
@@ -363,6 +444,16 @@ class ActivationInterpolation(BaseLoss):
363444 Interpolate between two different layers & channels.
364445 Olah, Mordvintsev & Schubert, 2017.
365446 https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
447+ This loss helps to interpolate or mix visualizations from two activations (layer or
448+ channel) by interpolating a linear sum between the two activations.
449+
450+ Args:
451+ target1 (nn.Module): The first layer to optimize for.
452+ channel_index1 (int): Index of channel in first layer to optimize. Defaults to
453+ all channels.
454+ target2 (nn.Module): The second layer to optimize for.
455+ channel_index2 (int): Index of channel in the second layer to optimize.
456+ Defaults to all channels.
366457 """
367458
368459 def __init__ (
@@ -414,6 +505,14 @@ class Alignment(BaseLoss):
414505 similarity between them.
415506 Olah, Mordvintsev & Schubert, 2017.
416507 https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
508+ When interpolating between activations, it may be desirable to keep image landmarks
509+ in the same position for visual comparison. This loss helps to minimize L2 distance
510+ between neighbouring images.
511+
512+ Args:
513+ target (nn.Module): The layer to optimize for.
514+ decay_ratio (float): How much to decay penalty as images move apart in batch.
515+ Defaults to 2.
417516 """
418517
419518 def __init__ (self , target : nn .Module , decay_ratio : float = 2.0 ) -> None :
@@ -442,6 +541,18 @@ class Direction(BaseLoss):
442541 Visualize a general direction vector.
443542 Carter, et al., "Activation Atlas", Distill, 2019.
444543 https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
544+ This loss helps to visualize a specific vector direction in a layer, by maximizing
545+ the alignment between the input vector and the layer’s activation vector. The
546+ dimensionality of the vector should correspond to the number of channels in the
547+ layer.
548+
549+ Args:
550+ target (nn.Module): The layer to optimize for.
551+ vec (torch.Tensor): Vector representing direction to align to.
552+ cossim_pow (float, optional): The desired cosine similarity power to use.
553+ batch_index (int, optional): The index of the image to optimize if we
554+ are optimizing a batch of images. If unspecified, defaults to all images
555+ in the batch.
445556 """
446557
447558 def __init__ (
@@ -468,6 +579,23 @@ class NeuronDirection(BaseLoss):
468579 Visualize a single (x, y) position for a direction vector.
469580 Carter, et al., "Activation Atlas", Distill, 2019.
470581 https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
582+ Extends Direction loss by focusing on visualizing a single neuron within the
583+ kernel.
584+
585+ Args:
586+ target (nn.Module): The layer to optimize for.
587+ vec (torch.Tensor): Vector representing direction to align to.
588+ x (int, optional): The x coordinate of the neuron to optimize for. If
589+ unspecified, defaults to center, or one unit left of center for even
590+ lengths.
591+ y (int, optional): The y coordinate of the neuron to optimize for. If
592+ unspecified, defaults to center, or one unit up from center for even
593+ heights.
594+ channel_index (int): The index of the channel to optimize for.
595+ cossim_pow (float, optional): The desired cosine similarity power to use.
596+ batch_index (int, optional): The index of the image to optimize if we
597+ are optimizing a batch of images. If unspecified, defaults to all images
598+ in the batch.
471599 """
472600
473601 def __init__ (
@@ -597,6 +725,15 @@ class TensorDirection(BaseLoss):
597725 Visualize a tensor direction vector.
598726 Carter, et al., "Activation Atlas", Distill, 2019.
599727 https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
728+ Extends Direction loss by allowing batch-wise direction visualization.
729+
730+ Args:
731+ target (nn.Module): The layer to optimize for.
732+ vec (torch.Tensor): Vector representing direction to align to.
733+ cossim_pow (float, optional): The desired cosine similarity power to use.
734+ batch_index (int, optional): The index of the image to optimize if we
735+ are optimizing a batch of images. If unspecified, defaults to all images
736+ in the batch.
600737 """
601738
602739 def __init__ (
@@ -635,6 +772,23 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
635772class ActivationWeights (BaseLoss ):
636773 """
637774 Apply weights to channels, neurons, or spots in the target.
775+ This loss weighs specific channels or neurons in a given layer, via a weight
776+ vector.
777+
778+ Args:
779+ target (nn.Module): The layer to optimize for.
780+ weights (torch.Tensor): Weights to apply to targets.
781+ neuron (bool): Whether target is a neuron. Defaults to False.
782+ x (int, optional): The x coordinate of the neuron to optimize for. If
783+ unspecified, defaults to center, or one unit left of center for even
784+ lengths.
785+ y (int, optional): The y coordinate of the neuron to optimize for. If
786+ unspecified, defaults to center, or one unit up from center for even
787+ heights.
788+ wx (int, optional): Length of neurons to apply the weights to, along the
789+ x-axis.
790+ wy (int, optional): Length of neurons to apply the weights to, along the
791+ y-axis.
638792 """
639793
640794 def __init__ (
0 commit comments