Skip to content

Commit 32e4f33

Browse files
authored
Merge pull request #642 from ProGamerGov/optim-wip-model-tutorial
add model tutorial
2 parents a1e0633 + c656658 commit 32e4f33

File tree

4 files changed

+479
-4
lines changed

4 files changed

+479
-4
lines changed

captum/optim/_core/loss.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,14 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
175175
return math_op(torch.mean(self(module)), torch.mean(other(module)))
176176

177177
name = f"Compose({', '.join([self.__name__, other.__name__])})"
178+
179+
# ToDo: Refine logic for self.target handling
178180
target = (self.target if isinstance(self.target, list) else [self.target]) + (
179181
other.target if isinstance(other.target, list) else [other.target]
180182
)
183+
184+
# Filter out duplicate targets
185+
target = list(dict.fromkeys(target))
181186
else:
182187
raise TypeError(
183188
"Can only apply math operations with int, float or Loss. Received type "
@@ -1398,6 +1403,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
13981403
]
13991404
for target in targets
14001405
]
1406+
1407+
# Filter out duplicate targets
1408+
target = list(dict.fromkeys(target))
14011409
return CompositeLoss(loss_fn, name=name, target=target)
14021410

14031411

captum/optim/models/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def collect_activations(
255255
"""
256256
if not isinstance(targets, list):
257257
targets = [targets]
258+
targets = list(dict.fromkeys(targets))
258259
catch_activ = ActivationFetcher(model, targets)
259260
activ_dict = catch_activ(model_input)
260261
return activ_dict

captum/optim/models/_image/inception_v1.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
import torch.nn as nn
66
from captum.optim.models._common import Conv2dSame, RedirectedReluLayer, SkipLayer
77

8-
GS_SAVED_WEIGHTS_URL = (
9-
"https://github.com/pytorch/captum/raw/"
10-
+ "optim-wip/captum/optim/models/_image/inception5h.pth"
11-
)
8+
GS_SAVED_WEIGHTS_URL = "https://pytorch.s3.amazonaws.com/models/captum/inception5h.pth"
129

1310

1411
def googlenet(

0 commit comments

Comments
 (0)