|
5 | 5 | import numpy as np |
6 | 6 | import torch |
7 | 7 | from captum.optim._utils.reducer import posneg |
| 8 | +from packaging import version |
8 | 9 |
|
9 | 10 | try: |
10 | 11 | from PIL import Image |
@@ -64,6 +65,21 @@ def save_tensor_as_image(x: torch.Tensor, filename: str, scale: float = 255.0) - |
64 | 65 | def get_neuron_pos( |
65 | 66 | H: int, W: int, x: Optional[int] = None, y: Optional[int] = None |
66 | 67 | ) -> Tuple[int, int]: |
| 68 | + """ |
| 69 | + Args: |
| 70 | +
|
| 71 | + H (int) The height |
| 72 | + W (int) The width |
| 73 | + x (int, optional): Optionally specify and exact x location of the neuron. If |
| 74 | + set to None, then the center x location will be used. |
| 75 | + Default: None |
| 76 | + y (int, optional): Optionally specify and exact y location of the neuron. If |
| 77 | + set to None, then the center y location will be used. |
| 78 | + Default: None |
| 79 | +
|
| 80 | + Return: |
| 81 | + Tuple[_x, _y] (Tuple[int, int]): The x and y dimensions of the neuron. |
| 82 | + """ |
67 | 83 | if x is None: |
68 | 84 | _x = W // 2 |
69 | 85 | else: |
@@ -109,66 +125,93 @@ def _dot_cossim( |
109 | 125 | return dot * torch.clamp(torch.cosine_similarity(x, y, eps=eps), 0.1) ** cossim_pow |
110 | 126 |
|
111 | 127 |
|
112 | | -@torch.jit.ignore |
113 | | -def nchannels_to_rgb(x: torch.Tensor, warp: bool = True) -> torch.Tensor: |
114 | | - """ |
115 | | - Convert an NCHW image with n channels into a 3 channel RGB image. |
| 128 | +# Handle older versions of PyTorch |
| 129 | +# Defined outside of function in order to support JIT |
| 130 | +_torch_norm = ( |
| 131 | + torch.linalg.norm |
| 132 | + if version.parse(torch.__version__) >= version.parse("1.7.0") |
| 133 | + else torch.norm |
| 134 | +) |
116 | 135 |
|
| 136 | + |
| 137 | +def hue_to_rgb( |
| 138 | + angle: float, device: torch.device = torch.device("cpu"), warp: bool = True |
| 139 | +) -> torch.Tensor: |
| 140 | + """ |
| 141 | + Create an RGB unit vector based on a hue of the input angle. |
117 | 142 | Args: |
118 | | - x (torch.Tensor): Image tensor to transform into RGB image. |
119 | | - warp (bool, optional): Whether or not to make colors more distinguishable. |
| 143 | + angle (float): The hue angle to create an RGB color for. |
| 144 | + device (torch.device, optional): The device to create the angle color tensor |
| 145 | + on. |
| 146 | + Default: torch.device("cpu") |
| 147 | + warp (bool, optional): Whether or not to make colors more distinguishable. |
120 | 148 | Default: True |
121 | 149 | Returns: |
122 | | - *tensor* RGB image |
| 150 | + color_vec (torch.Tensor): A color vector. |
123 | 151 | """ |
124 | 152 |
|
125 | | - def hue_to_rgb(angle: float) -> torch.Tensor: |
126 | | - """ |
127 | | - Create an RGB unit vector based on a hue of the input angle. |
128 | | - """ |
129 | | - |
130 | | - angle = angle - 360 * (angle // 360) |
131 | | - colors = torch.tensor( |
132 | | - [ |
133 | | - [1.0, 0.0, 0.0], |
134 | | - [0.7071, 0.7071, 0.0], |
135 | | - [0.0, 1.0, 0.0], |
136 | | - [0.0, 0.7071, 0.7071], |
137 | | - [0.0, 0.0, 1.0], |
138 | | - [0.7071, 0.0, 0.7071], |
139 | | - ] |
| 153 | + angle = angle - 360 * (angle // 360) |
| 154 | + colors = torch.tensor( |
| 155 | + [ |
| 156 | + [1.0, 0.0, 0.0], |
| 157 | + [0.7071, 0.7071, 0.0], |
| 158 | + [0.0, 1.0, 0.0], |
| 159 | + [0.0, 0.7071, 0.7071], |
| 160 | + [0.0, 0.0, 1.0], |
| 161 | + [0.7071, 0.0, 0.7071], |
| 162 | + ], |
| 163 | + device=device, |
| 164 | + ) |
| 165 | + |
| 166 | + idx = math.floor(angle / 60) |
| 167 | + d = (angle - idx * 60) / 60 |
| 168 | + |
| 169 | + if warp: |
| 170 | + # Idea from: https://github.com/tensorflow/lucid/pull/193 |
| 171 | + d = ( |
| 172 | + math.sin(d * math.pi / 2) |
| 173 | + if idx % 2 == 0 |
| 174 | + else 1 - math.sin((1 - d) * math.pi / 2) |
140 | 175 | ) |
141 | 176 |
|
142 | | - idx = math.floor(angle / 60) |
143 | | - d = (angle - idx * 60) / 60 |
| 177 | + vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6] |
| 178 | + return vec / _torch_norm(vec) |
144 | 179 |
|
145 | | - if warp: |
146 | 180 |
|
147 | | - def adj(x: float) -> float: |
148 | | - return math.sin(x * math.pi / 2) |
| 181 | +def nchannels_to_rgb( |
| 182 | + x: torch.Tensor, warp: bool = True, eps: float = 1e-4 |
| 183 | +) -> torch.Tensor: |
| 184 | + """ |
| 185 | + Convert an NCHW image with n channels into a 3 channel RGB image. |
149 | 186 |
|
150 | | - d = adj(d) if idx % 2 == 0 else 1 - adj(1 - d) |
| 187 | + Args: |
151 | 188 |
|
152 | | - vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6] |
153 | | - return vec / torch.norm(vec) |
| 189 | + x (torch.Tensor): NCHW image tensor to transform into RGB image. |
| 190 | + warp (bool, optional): Whether or not to make colors more distinguishable. |
| 191 | + Default: True |
| 192 | + eps (float, optional): An optional epsilon value. |
| 193 | + Default: 1e-4 |
| 194 | + Returns: |
| 195 | + tensor (torch.Tensor): An NCHW RGB image tensor. |
| 196 | + """ |
154 | 197 |
|
155 | 198 | assert x.dim() == 4 |
156 | 199 |
|
157 | 200 | if (x < 0).any(): |
158 | 201 | x = posneg(x.permute(0, 2, 3, 1), -1).permute(0, 3, 1, 2) |
159 | 202 |
|
160 | 203 | rgb = torch.zeros(1, 3, x.size(2), x.size(3), device=x.device) |
161 | | - nc = x.size(1) |
162 | | - for i in range(nc): |
163 | | - rgb = rgb + x[:, i][:, None, :, :] |
164 | | - rgb = rgb * hue_to_rgb(360 * i / nc).to(device=x.device)[None, :, None, None] |
165 | | - |
166 | | - rgb = rgb + torch.ones(x.size(2), x.size(3))[None, None, :, :] * ( |
167 | | - torch.sum(x, 1)[:, None] - torch.max(x, 1)[0][:, None] |
168 | | - ) |
169 | | - return (rgb / (1e-4 + torch.norm(rgb, dim=1, keepdim=True))) * torch.norm( |
170 | | - x, dim=1, keepdim=True |
| 204 | + num_channels = x.size(1) |
| 205 | + for i in range(num_channels): |
| 206 | + rgb_angle = hue_to_rgb(360 * i / num_channels, device=x.device, warp=warp) |
| 207 | + rgb = rgb + (x[:, i][:, None, :, :] * rgb_angle[None, :, None, None]) |
| 208 | + |
| 209 | + rgb = rgb + ( |
| 210 | + torch.ones(1, 1, x.size(2), x.size(3), device=x.device) |
| 211 | + * (torch.sum(x, 1) - torch.max(x, 1)[0])[:, None] |
171 | 212 | ) |
| 213 | + rgb = rgb / (eps + _torch_norm(rgb, dim=1, keepdim=True)) |
| 214 | + return rgb * _torch_norm(x, dim=1, keepdim=True) |
172 | 215 |
|
173 | 216 |
|
174 | 217 | def weights_to_heatmap_2d( |
|
0 commit comments