Commit bb64aef

Merge branch 'master' into master-patch-1
2 parents efcdeda + 03f89a5 commit bb64aef

File tree

5 files changed: +386 -54 lines changed


captum/_utils/common.py

Lines changed: 37 additions & 0 deletions
@@ -354,6 +354,43 @@ def _format_output(
     return output if is_inputs_tuple else output[0]
 
 
+@typing.overload
+def _format_outputs(
+    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
+) -> Union[Tensor, Tuple[Tensor, ...]]:
+    ...
+
+
+@typing.overload
+def _format_outputs(
+    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
+) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
+    ...
+
+
+@typing.overload
+def _format_outputs(
+    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
+) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+    ...
+
+
+def _format_outputs(
+    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
+) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+    assert isinstance(outputs, list), "Outputs must be a list"
+    assert is_multiple_inputs or len(outputs) == 1, (
+        "outputs should contain multiple inputs or have a single output,"
+        f" however the number of outputs is: {len(outputs)}"
+    )
+
+    return (
+        [_format_output(len(output) > 1, output) for output in outputs]
+        if is_multiple_inputs
+        else _format_output(len(outputs[0]) > 1, outputs[0])
+    )
+
+
 def _run_forward(
     forward_func: Callable,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
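A minimal usage sketch of the new helper (not part of the commit; t1 and t2 are placeholder tensors): with a single layer the lone tuple is unwrapped exactly as `_format_output` would do, while with multiple layers each per-layer tuple is formatted independently and collected into a list.

import torch
from captum._utils.common import _format_outputs

t1, t2 = torch.rand(3), torch.rand(3)  # placeholder attribution tensors

# Single layer: the one tuple is unwrapped, mirroring _format_output.
_format_outputs(False, [(t1,)])       # -> t1
_format_outputs(False, [(t1, t2)])    # -> (t1, t2)

# Multiple layers: one formatted entry per layer, in order.
_format_outputs(True, [(t1,), (t1, t2)])  # -> [t1, (t1, t2)]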

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 155 additions & 41 deletions
@@ -1,19 +1,19 @@
 #!/usr/bin/env python3
-import typing
-from typing import Any, Callable, List, Tuple, Union
+import functools
+import warnings
+from typing import Any, Callable, List, Tuple, Union, overload
 
 import torch
 from torch import Tensor
-from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter
 
 from captum._utils.common import (
     _extract_device,
     _format_additional_forward_args,
-    _format_output,
+    _format_outputs,
 )
 from captum._utils.gradient import _forward_layer_eval, _run_forward
-from captum._utils.typing import BaselineType, Literal, TargetType
+from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
 from captum.attr._core.integrated_gradients import IntegratedGradients
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
 from captum.attr._utils.common import (
@@ -48,20 +48,33 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
     def __init__(
         self,
         forward_func: Callable,
-        layer: Module,
+        layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
     ) -> None:
         r"""
         Args:
             forward_func (callable): The forward function of the model or any
                     modification of it
-            layer (torch.nn.Module): Layer for which attributions are computed.
-                    Output size of attribute matches this layer's input or
-                    output dimensions, depending on whether we attribute to
-                    the inputs or outputs of the layer, corresponding to
-                    the attribution of each neuron in the input or output
-                    of this layer.
+            layer (ModuleOrModuleList):
+                    Layer or list of layers for which attributions are computed.
+                    For each layer the output size of the attribute matches
+                    this layer's input or output dimensions, depending on
+                    whether we attribute to the inputs or outputs of the
+                    layer, corresponding to the attribution of each neuron
+                    in the input or output of this layer.
+
+                    Please note that the layers to attribute on cannot be
+                    dependent on each other. That is, a subset of layers in
+                    `layer` cannot produce the inputs for another layer.
+
+                    For example, if your model is a simple chain of
+                    layers (think nn.Sequential), e.g. x -> l1 -> l2
+                    -> l3 -> output, then passing in any one of those
+                    layers rules out passing in another due to the
+                    dependence, e.g. if you pass in l2 you cannot
+                    pass in l1 or l3.
+
             device_ids (list(int)): Device ID list, necessary only if forward_func
                     applies a DataParallel model. This allows reconstruction of
                     intermediate outputs from batched results across devices.
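A usage sketch for the list-of-layers constructor described above; the two-branch model here is hypothetical, chosen so that neither attributed layer feeds the other.

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

class TwoBranch(nn.Module):
    # Hypothetical model: branch_a and branch_b run in parallel, so they
    # satisfy the independence requirement from the docstring.
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(10, 4)
        self.branch_b = nn.Linear(10, 4)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(torch.cat([self.branch_a(x), self.branch_b(x)], dim=1))

model = TwoBranch()
lig = LayerIntegratedGradients(model, [model.branch_a, model.branch_b])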
@@ -86,22 +99,48 @@ def __init__(
         GradientAttribution.__init__(self, forward_func)
         self.ig = IntegratedGradients(forward_func, multiply_by_inputs)
 
-    @typing.overload
+        if isinstance(layer, list) and len(layer) > 1:
+            warnings.warn(
+                "Multiple layers provided. Please ensure that each layer is"
+                " **not** solely dependent on the outputs of"
+                " another layer. Please refer to the documentation for more"
+                " detail."
+            )
+
+    @overload
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
-        baselines: BaselineType = None,
-        target: TargetType = None,
-        additional_forward_args: Any = None,
-        n_steps: int = 50,
-        method: str = "gausslegendre",
-        internal_batch_size: Union[None, int] = None,
-        return_convergence_delta: Literal[False] = False,
-        attribute_to_layer_input: bool = False,
-    ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        baselines: BaselineType,
+        target: TargetType,
+        additional_forward_args: Any,
+        n_steps: int,
+        method: str,
+        internal_batch_size: Union[None, int],
+        return_convergence_delta: Literal[False],
+        attribute_to_layer_input: bool,
+    ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+        ...
+
+    @overload
+    def attribute(
+        self,
+        inputs: Union[Tensor, Tuple[Tensor, ...]],
+        baselines: BaselineType,
+        target: TargetType,
+        additional_forward_args: Any,
+        n_steps: int,
+        method: str,
+        internal_batch_size: Union[None, int],
+        return_convergence_delta: Literal[True],
+        attribute_to_layer_input: bool,
+    ) -> Tuple[
+        Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+        Tensor,
+    ]:
         ...
 
-    @typing.overload
+    @overload
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -111,10 +150,15 @@ def attribute(
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
-        *,
-        return_convergence_delta: Literal[True],
+        return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
-    ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
+    ) -> Union[
+        Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+        Tuple[
+            Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+            Tensor,
+        ],
+    ]:
         ...
 
     @log_usage()
@@ -130,7 +174,11 @@ def attribute(
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
     ) -> Union[
-        Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
+        Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+        Tuple[
+            Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+            Tensor,
+        ],
     ]:
         r"""
         This method attributes the output of the model with given target index
@@ -257,16 +305,25 @@ def attribute(
                 Default: False
         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
-            - **attributions** (*tensor* or tuple of *tensors*):
+            - **attributions** (*tensor*, tuple of *tensors*, or list of *tensors*):
                 Integrated gradients with respect to `layer`'s inputs or
                 outputs. Attributions will always be the same size and
                 dimensionality as the input or output of the given layer,
                 depending on whether we attribute to the inputs or outputs
                 of the layer which is decided by the input flag
                 `attribute_to_layer_input`.
-                Attributions are returned in a tuple if
+
+                For a single layer, attributions are returned in a tuple if
                 the layer inputs / outputs contain multiple tensors,
                 otherwise a single tensor is returned.
+
+                For multiple layers, attributions will always be
+                returned as a list. Each element in this list is
+                equivalent to the output for a single layer, i.e. if
+                one of the given layers takes or produces multiple
+                tensors, the corresponding element will be a tuple
+                of tensors. The ordering of the outputs matches the
+                order of the layers given in the constructor.
             - **delta** (*tensor*, returned if return_convergence_delta=True):
                 The difference between the total approximated and true
                 integrated gradients. This is computed using the property
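A sketch of the return shapes described above, continuing the hypothetical two-branch model from the earlier sketch.

import torch

x = torch.rand(2, 10)

# Single layer: a Tensor (or a tuple if the layer takes / produces several tensors).
single = LayerIntegratedGradients(model, model.branch_a).attribute(x, target=0)

# Multiple layers: a list with one entry per layer, in constructor order.
multi = lig.attribute(x, target=0)
assert isinstance(multi, list) and len(multi) == 2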
@@ -298,6 +355,11 @@ def attribute(
             additional_forward_args
         )
 
+        def flatten_tuple(tup):
+            return tuple(
+                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+            )
+
         if self.device_ids is None:
             self.device_ids = getattr(self.forward_func, "device_ids", None)
         inputs_layer = _forward_layer_eval(
@@ -309,6 +371,16 @@ def attribute(
             attribute_to_layer_input=attribute_to_layer_input,
         )
 
+        # if we have one output
+        if not isinstance(self.layer, list):
+            inputs_layer = (inputs_layer,)
+
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs_cumsum = torch.cumsum(
+            torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
+        )
+        inputs_layer = flatten_tuple(inputs_layer)
+
         baselines_layer = _forward_layer_eval(
             self.forward_func,
             baselines,
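A small sketch (placeholder activations, not from the commit) of the bookkeeping this hunk adds: `num_outputs_cumsum` records where each layer's tensors start and end once all per-layer activations are flattened into one tuple.

import torch

# Per-layer activations: layer 0 yields one tensor, layer 1 yields a pair.
inputs_layer = ((torch.rand(2, 4),), (torch.rand(2, 3), torch.rand(2, 3)))

num_outputs = [1 if isinstance(x, torch.Tensor) else len(x) for x in inputs_layer]
num_outputs_cumsum = torch.cumsum(torch.IntTensor([0] + num_outputs), dim=0)
# num_outputs == [1, 2]; num_outputs_cumsum == tensor([0, 1, 3])

flat = tuple(sum((list(x) for x in inputs_layer), []))  # flattened to 3 tensors
chunk_for_layer_1 = flat[int(num_outputs_cumsum[1]) : int(num_outputs_cumsum[2])]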
@@ -317,6 +389,7 @@ def attribute(
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
         )
+        baselines_layer = flatten_tuple(baselines_layer)
 
         # inputs -> these inputs are scaled
         def gradient_func(
@@ -341,30 +414,60 @@ def gradient_func(
 
             with torch.autograd.set_grad_enabled(True):
 
-                def layer_forward_hook(module, hook_inputs, hook_outputs=None):
+                def layer_forward_hook(
+                    module, hook_inputs, hook_outputs=None, layer_idx=0
+                ):
                     device = _extract_device(module, hook_inputs, hook_outputs)
                     is_layer_tuple = (
                         isinstance(hook_outputs, tuple)
+                        # hook_outputs is None if attribute_to_layer_input == True
                        if hook_outputs is not None
                         else isinstance(hook_inputs, tuple)
                     )
+
                     if is_layer_tuple:
-                        return scattered_inputs_dict[device]
-                    return scattered_inputs_dict[device][0]
+                        return scattered_inputs_dict[device][
+                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                layer_idx + 1
+                            ]
+                        ]
+
+                    return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
 
-                hook = None
+                hooks = []
                 try:
-                    if attribute_to_layer_input:
-                        hook = self.layer.register_forward_pre_hook(layer_forward_hook)
-                    else:
-                        hook = self.layer.register_forward_hook(layer_forward_hook)
+
+                    layers = self.layer
+                    if not isinstance(layers, list):
+                        layers = [self.layer]
+
+                    for layer_idx, layer in enumerate(layers):
+                        hook = None
+                        # TODO:
+                        # Allow multiple attribute_to_layer_input flags for
+                        # each layer, i.e. attribute_to_layer_input[layer_idx]
+                        if attribute_to_layer_input:
+                            hook = layer.register_forward_pre_hook(
+                                functools.partial(
+                                    layer_forward_hook, layer_idx=layer_idx
+                                )
+                            )
+                        else:
+                            hook = layer.register_forward_hook(
+                                functools.partial(
+                                    layer_forward_hook, layer_idx=layer_idx
+                                )
+                            )
+
+                        hooks.append(hook)
 
                     output = _run_forward(
                         self.forward_func, tuple(), target_ind, additional_forward_args
                     )
                 finally:
-                    if hook is not None:
-                        hook.remove()
+                    for hook in hooks:
+                        if hook is not None:
+                            hook.remove()
 
                 assert output[0].numel() == 1, (
                     "Target not provided when necessary, cannot"
@@ -381,6 +484,7 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
             if additional_forward_args is not None
             else inps
         )
+
         attributions = self.ig.attribute.__wrapped__(  # type: ignore
             self.ig,  # self
             inputs_layer,
@@ -393,6 +497,16 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
             return_convergence_delta=False,
         )
 
+        # handle multiple outputs
+        output: List[Tuple[Tensor, ...]] = [
+            tuple(
+                attributions[
+                    int(num_outputs_cumsum[i]) : int(num_outputs_cumsum[i + 1])
+                ]
+            )
+            for i in range(len(num_outputs))
+        ]
+
         if return_convergence_delta:
             start_point, end_point = baselines, inps
             # computes approximation error based on the completeness axiom
@@ -403,8 +517,8 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
                 additional_forward_args=additional_forward_args,
                 target=target,
             )
-            return _format_output(len(attributions) > 1, attributions), delta
-        return _format_output(len(attributions) > 1, attributions)
+            return _format_outputs(isinstance(self.layer, list), output), delta
+        return _format_outputs(isinstance(self.layer, list), output)
 
     def has_convergence_delta(self) -> bool:
         return True
