|
| 1 | +from abc import ABC, abstractmethod |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | + |
| 7 | +from captum.optim._utils.images import get_neuron_pos |
| 8 | +from captum.optim._utils.typing import ModuleOutputMapping |
| 9 | + |
| 10 | + |
| 11 | +class Loss(ABC): |
| 12 | + """ |
| 13 | + Abstract Class to describe loss. |
| 14 | + """ |
| 15 | + |
| 16 | + def __init__(self, target: nn.Module) -> None: |
| 17 | + super(Loss, self).__init__() |
| 18 | + self.target = target |
| 19 | + |
| 20 | + @abstractmethod |
| 21 | + def __call__(self, targets_to_values: ModuleOutputMapping): |
| 22 | + pass |
| 23 | + |
| 24 | + |
| 25 | +class LayerActivation(Loss): |
| 26 | + """ |
| 27 | + Maximize activations at the target layer. |
| 28 | + """ |
| 29 | + |
| 30 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 31 | + return targets_to_values[self.target] |
| 32 | + |
| 33 | + |
| 34 | +class ChannelActivation(Loss): |
| 35 | + """ |
| 36 | + Maximize activations at the target layer and target channel. |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__(self, target: nn.Module, channel_index: int) -> None: |
| 40 | + super(Loss, self).__init__() |
| 41 | + self.target = target |
| 42 | + self.channel_index = channel_index |
| 43 | + |
| 44 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 45 | + activations = targets_to_values[self.target] |
| 46 | + assert activations is not None |
| 47 | + # ensure channel_index is valid |
| 48 | + assert self.channel_index < activations.shape[1] |
| 49 | + # assume NCHW |
| 50 | + # NOTE: not necessarily true e.g. for Linear layers |
| 51 | + # assert len(activations.shape) == 4 |
| 52 | + return activations[:, self.channel_index, ...] |
| 53 | + |
| 54 | + |
| 55 | +class NeuronActivation(Loss): |
| 56 | + def __init__( |
| 57 | + self, |
| 58 | + target: nn.Module, |
| 59 | + channel_index: int, |
| 60 | + x: Optional[int] = None, |
| 61 | + y: Optional[int] = None, |
| 62 | + ) -> None: |
| 63 | + super(Loss, self).__init__() |
| 64 | + self.target = target |
| 65 | + self.channel_index = channel_index |
| 66 | + self.x = x |
| 67 | + self.y = y |
| 68 | + |
| 69 | + # ensure channel_index will be valid |
| 70 | + assert self.channel_index < self.target.out_channels |
| 71 | + |
| 72 | + def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 73 | + activations = targets_to_values[self.target] |
| 74 | + assert activations is not None |
| 75 | + assert len(activations.shape) == 4 # assume NCHW |
| 76 | + _x, _y = get_neuron_pos( |
| 77 | + activations.size(2), activations.size(3), self.x, self.y |
| 78 | + ) |
| 79 | + |
| 80 | + return activations[:, self.channel_index, _x, _y] |
| 81 | + |
| 82 | + |
| 83 | +class DeepDream(Loss): |
| 84 | + """ |
| 85 | + Maximize 'interestingness' at the target layer. |
| 86 | + Mordvintsev et al., 2015. |
| 87 | + """ |
| 88 | + |
| 89 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 90 | + activations = targets_to_values[self.target] |
| 91 | + return activations ** 2 |
| 92 | + |
| 93 | + |
| 94 | +class TotalVariation(Loss): |
| 95 | + """ |
| 96 | + Total variation denoising penalty for activations. |
| 97 | + See Mahendran, V. 2014. Understanding Deep Image Representations by Inverting Them. |
| 98 | + https://arxiv.org/abs/1412.0035 |
| 99 | + """ |
| 100 | + |
| 101 | + def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 102 | + activations = targets_to_values[self.target] |
| 103 | + x_diff = activations[..., 1:, :] - activations[..., :-1, :] |
| 104 | + y_diff = activations[..., :, 1:] - activations[..., :, :-1] |
| 105 | + return torch.sum(torch.abs(x_diff)) + torch.sum(torch.abs(y_diff)) |
| 106 | + |
| 107 | + |
| 108 | +class L1(Loss): |
| 109 | + """ |
| 110 | + L1 norm of the target layer, generally used as a penalty. |
| 111 | + """ |
| 112 | + |
| 113 | + def __init__(self, target: nn.Module, constant: float = 0.0) -> None: |
| 114 | + super(Loss, self).__init__() |
| 115 | + self.target = target |
| 116 | + self.constant = constant |
| 117 | + |
| 118 | + def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 119 | + activations = targets_to_values[self.target] |
| 120 | + return torch.abs(activations - self.constant).sum() |
| 121 | + |
| 122 | + |
| 123 | +class L2(Loss): |
| 124 | + """ |
| 125 | + L2 norm of the target layer, generally used as a penalty. |
| 126 | + """ |
| 127 | + |
| 128 | + def __init__( |
| 129 | + self, target: nn.Module, constant: float = 0.0, epsilon: float = 1e-6 |
| 130 | + ) -> None: |
| 131 | + self.target = target |
| 132 | + self.constant = constant |
| 133 | + self.epsilon = epsilon |
| 134 | + |
| 135 | + def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 136 | + activations = targets_to_values[self.target] |
| 137 | + activations = (activations - self.constant).sum() |
| 138 | + return torch.sqrt(self.epsilon + activations) |
| 139 | + |
| 140 | + |
| 141 | +class Diversity(Loss): |
| 142 | + """ |
| 143 | + Use a cosine similarity penalty to extract features from a polysemantic neuron. |
| 144 | + Olah, Mordvintsev & Schubert, 2017. |
| 145 | + https://distill.pub/2017/feature-visualization/#diversity |
| 146 | + """ |
| 147 | + |
| 148 | + def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 149 | + activations = targets_to_values[self.target] |
| 150 | + return -sum( |
| 151 | + [ |
| 152 | + sum( |
| 153 | + [ |
| 154 | + ( |
| 155 | + torch.cosine_similarity( |
| 156 | + activations[j].view(1, -1), activations[i].view(1, -1) |
| 157 | + ) |
| 158 | + ).sum() |
| 159 | + for i in range(activations.size(0)) |
| 160 | + if i != j |
| 161 | + ] |
| 162 | + ) |
| 163 | + for j in range(activations.size(0)) |
| 164 | + ] |
| 165 | + ) / activations.size(0) |
| 166 | + |
| 167 | + |
| 168 | +class ActivationInterpolation(Loss): |
| 169 | + """ |
| 170 | + Interpolate between two different layers & channels. |
| 171 | + Olah, Mordvintsev & Schubert, 2017. |
| 172 | + https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons |
| 173 | + """ |
| 174 | + |
| 175 | + def __init__( |
| 176 | + self, |
| 177 | + target1: nn.Module = None, |
| 178 | + channel_index1: int = -1, |
| 179 | + target2: nn.Module = None, |
| 180 | + channel_index2: int = -1, |
| 181 | + ) -> None: |
| 182 | + super(Loss, self).__init__() |
| 183 | + self.target_one = target1 |
| 184 | + self.channel_index_one = channel_index1 |
| 185 | + self.target_two = target2 |
| 186 | + self.channel_index_two = channel_index2 |
| 187 | + |
| 188 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 189 | + activations_one = targets_to_values[self.target_one] |
| 190 | + activations_two = targets_to_values[self.target_two] |
| 191 | + |
| 192 | + assert activations_one is not None and activations_two is not None |
| 193 | + # ensure channel indices are valid |
| 194 | + assert ( |
| 195 | + self.channel_index_one < activations_one.shape[1] |
| 196 | + and self.channel_index_two < activations_two.shape[1] |
| 197 | + ) |
| 198 | + assert activations_one.size(0) == activations_two.size(0) |
| 199 | + |
| 200 | + if self.channel_index_one > -1: |
| 201 | + activations_one = activations_one[:, self.channel_index_one] |
| 202 | + if self.channel_index_two > -1: |
| 203 | + activations_two = activations_two[:, self.channel_index_two] |
| 204 | + B = activations_one.size(0) |
| 205 | + |
| 206 | + batch_weights = torch.arange(B, device=activations_one.device) / (B - 1) |
| 207 | + sum_tensor = torch.zeros(1, device=activations_one.device) |
| 208 | + for n in range(B): |
| 209 | + sum_tensor = ( |
| 210 | + sum_tensor + ((1 - batch_weights[n]) * activations_one[n]).mean() |
| 211 | + ) |
| 212 | + sum_tensor = sum_tensor + (batch_weights[n] * activations_two[n]).mean() |
| 213 | + return sum_tensor |
| 214 | + |
| 215 | + |
| 216 | +class Alignment(Loss): |
| 217 | + """ |
| 218 | + Penalize the L2 distance between tensors in the batch to encourage visual |
| 219 | + similarity between them. |
| 220 | + Olah, Mordvintsev & Schubert, 2017. |
| 221 | + https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons |
| 222 | + """ |
| 223 | + |
| 224 | + def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None: |
| 225 | + super(Loss, self).__init__() |
| 226 | + self.target = target |
| 227 | + self.decay_ratio = decay_ratio |
| 228 | + |
| 229 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 230 | + activations = targets_to_values[self.target] |
| 231 | + B = activations.size(0) |
| 232 | + |
| 233 | + sum_tensor = torch.zeros(1, device=activations.device) |
| 234 | + for d in [1, 2, 3, 4]: |
| 235 | + for i in range(B - d): |
| 236 | + a, b = i, i + d |
| 237 | + activ_a, activ_b = activations[a], activations[b] |
| 238 | + sum_tensor = sum_tensor + ( |
| 239 | + (activ_a - activ_b) ** 2 |
| 240 | + ).mean() / self.decay_ratio ** float(d) |
| 241 | + |
| 242 | + return sum_tensor |
| 243 | + |
| 244 | + |
| 245 | +class Direction(Loss): |
| 246 | + """ |
| 247 | + Visualize a general direction vector. |
| 248 | + Carter, et al., "Activation Atlas", Distill, 2019. |
| 249 | + https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images |
| 250 | + """ |
| 251 | + |
| 252 | + def __init__(self, target: nn.Module, vec: torch.Tensor) -> None: |
| 253 | + super(Loss, self).__init__() |
| 254 | + self.target = target |
| 255 | + self.direction = vec.reshape((1, -1, 1, 1)) |
| 256 | + |
| 257 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 258 | + activations = targets_to_values[self.target] |
| 259 | + assert activations.size(1) == self.direction.size(1) |
| 260 | + return torch.cosine_similarity(self.direction, activations) |
| 261 | + |
| 262 | + |
| 263 | +class DirectionNeuron(Loss): |
| 264 | + """ |
| 265 | + Visualize a single (x, y) position for a direction vector. |
| 266 | + Carter, et al., "Activation Atlas", Distill, 2019. |
| 267 | + https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images |
| 268 | + """ |
| 269 | + |
| 270 | + def __init__( |
| 271 | + self, |
| 272 | + target: nn.Module, |
| 273 | + vec: torch.Tensor, |
| 274 | + channel_index: int, |
| 275 | + x: Optional[int] = None, |
| 276 | + y: Optional[int] = None, |
| 277 | + ) -> None: |
| 278 | + super(Loss, self).__init__() |
| 279 | + self.target = target |
| 280 | + self.direction = vec.reshape((1, -1, 1, 1)) |
| 281 | + self.channel_index = channel_index |
| 282 | + self.x = x |
| 283 | + self.y = y |
| 284 | + |
| 285 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 286 | + activations = targets_to_values[self.target] |
| 287 | + |
| 288 | + assert activations.dim() == 4 |
| 289 | + |
| 290 | + _x, _y = get_neuron_pos( |
| 291 | + activations.size(2), activations.size(3), self.x, self.y |
| 292 | + ) |
| 293 | + activations = activations[:, self.channel_index, _x, _y] |
| 294 | + return torch.cosine_similarity(self.direction, activations[None, None, None]) |
| 295 | + |
| 296 | + |
| 297 | +class TensorDirection(Loss): |
| 298 | + """ |
| 299 | + Visualize a tensor direction vector. |
| 300 | + Carter, et al., "Activation Atlas", Distill, 2019. |
| 301 | + https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images |
| 302 | + """ |
| 303 | + |
| 304 | + def __init__(self, target: nn.Module, vec: torch.Tensor) -> None: |
| 305 | + super(Loss, self).__init__() |
| 306 | + self.target = target |
| 307 | + self.direction = vec |
| 308 | + |
| 309 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 310 | + activations = targets_to_values[self.target] |
| 311 | + |
| 312 | + assert activations.dim() == 4 |
| 313 | + |
| 314 | + H_direction, W_direction = self.direction.size(2), self.direction.size(3) |
| 315 | + H_activ, W_activ = activations.size(2), activations.size(3) |
| 316 | + |
| 317 | + H = (H_activ - H_direction) // 2 |
| 318 | + W = (W_activ - W_direction) // 2 |
| 319 | + |
| 320 | + activations = activations[:, :, H : H + H_direction, W : W + W_direction] |
| 321 | + return torch.cosine_similarity(self.direction, activations) |
| 322 | + |
| 323 | + |
| 324 | +class ActivationWeights(Loss): |
| 325 | + """ |
| 326 | + Apply weights to channels, neurons, or spots in the target. |
| 327 | + """ |
| 328 | + |
| 329 | + def __init__( |
| 330 | + self, |
| 331 | + target: nn.Module, |
| 332 | + weights: torch.Tensor = None, |
| 333 | + neuron: bool = False, |
| 334 | + x: Optional[int] = None, |
| 335 | + y: Optional[int] = None, |
| 336 | + wx: Optional[int] = None, |
| 337 | + wy: Optional[int] = None, |
| 338 | + ) -> None: |
| 339 | + super(Loss, self).__init__() |
| 340 | + self.target = target |
| 341 | + self.x = x |
| 342 | + self.y = y |
| 343 | + self.wx = wx |
| 344 | + self.wy = wy |
| 345 | + self.weights = weights |
| 346 | + self.neuron = x is not None or y is not None or neuron |
| 347 | + assert ( |
| 348 | + wx is None |
| 349 | + and wy is None |
| 350 | + or wx is not None |
| 351 | + and wy is not None |
| 352 | + and x is not None |
| 353 | + and y is not None |
| 354 | + ) |
| 355 | + |
| 356 | + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: |
| 357 | + activations = targets_to_values[self.target] |
| 358 | + if self.neuron: |
| 359 | + assert activations.dim() == 4 |
| 360 | + if self.wx is None and self.wy is None: |
| 361 | + _x, _y = get_neuron_pos( |
| 362 | + activations.size(2), activations.size(3), self.x, self.y |
| 363 | + ) |
| 364 | + activations = activations[..., _x, _y].squeeze() * self.weights |
| 365 | + else: |
| 366 | + activations = activations[ |
| 367 | + ..., self.y : self.y + self.wy, self.x : self.x + self.wx |
| 368 | + ] * self.weights.view(1, -1, 1, 1) |
| 369 | + else: |
| 370 | + activations = activations * self.weights.view(1, -1, 1, 1) |
| 371 | + return activations |
0 commit comments