 #!/usr/bin/env python3
-import typing
-from typing import Any, Callable, List, Tuple, Union
+import functools
+import warnings
+from typing import Any, Callable, List, Tuple, Union, overload
 
 import torch
 from torch import Tensor
-from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter
 
 from captum._utils.common import (
     _extract_device,
     _format_additional_forward_args,
-    _format_output,
+    _format_outputs,
 )
 from captum._utils.gradient import _forward_layer_eval, _run_forward
-from captum._utils.typing import BaselineType, Literal, TargetType
+from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
 from captum.attr._core.integrated_gradients import IntegratedGradients
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
 from captum.attr._utils.common import (
@@ -48,20 +48,33 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
     def __init__(
         self,
         forward_func: Callable,
-        layer: Module,
+        layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
     ) -> None:
         r"""
         Args:
             forward_func (callable): The forward function of the model or any
                     modification of it
-            layer (torch.nn.Module): Layer for which attributions are computed.
-                    Output size of attribute matches this layer's input or
-                    output dimensions, depending on whether we attribute to
-                    the inputs or outputs of the layer, corresponding to
-                    the attribution of each neuron in the input or output
-                    of this layer.
+            layer (ModuleOrModuleList):
+                    Layer or list of layers for which attributions are computed.
+                    For each layer, the output size of the attributions matches
+                    this layer's input or output dimensions, depending on
+                    whether we attribute to the inputs or outputs of the
+                    layer, corresponding to the attribution of each neuron
+                    in the input or output of this layer.
+
+                    Please note that the layers to attribute on cannot be
+                    dependent on each other. That is, a subset of layers in
+                    `layer` cannot produce the inputs for another layer.
+
+                    For example, if your model has a simple linked-list
+                    based graph structure (think nn.Sequential), e.g.
+                    x -> l1 -> l2 -> l3 -> output, and you pass in any one of
+                    those layers, you cannot pass in another due to this
+                    dependence, e.g. if you pass in l2 you cannot pass in
+                    l1 or l3.
+
             device_ids (list(int)): Device ID list, necessary only if forward_func
                     applies a DataParallel model. This allows reconstruction of
                     intermediate outputs from batched results across devices.
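
Aside (not part of the diff): a minimal usage sketch of the new multi-layer `layer` argument described in the docstring above. The `TwoBranch` toy model, its layer names, and the input shapes are assumptions invented for illustration; its two branches are independent, so neither attributed layer feeds the other, as the docstring requires.

import torch
from torch import nn
from captum.attr import LayerIntegratedGradients

class TwoBranch(nn.Module):
    # Toy model with two independent branches, so both layers can be attributed at once.
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(4, 8)
        self.branch_b = nn.Linear(4, 8)
        self.head = nn.Linear(16, 3)

    def forward(self, x_a, x_b):
        return self.head(torch.cat([self.branch_a(x_a), self.branch_b(x_b)], dim=1))

model = TwoBranch()
lig = LayerIntegratedGradients(model, layer=[model.branch_a, model.branch_b])

x_a, x_b = torch.rand(2, 4), torch.rand(2, 4)
# With a list of layers, `attribute` returns a list with one entry per layer,
# in the same order as the layers passed to the constructor.
attrs = lig.attribute((x_a, x_b), target=0)
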
@@ -86,22 +99,48 @@ def __init__(
         GradientAttribution.__init__(self, forward_func)
         self.ig = IntegratedGradients(forward_func, multiply_by_inputs)
 
-    @typing.overload
+        if isinstance(layer, list) and len(layer) > 1:
+            warnings.warn(
+                "Multiple layers provided. Please ensure that each layer is "
+                "**not** solely dependent on the outputs of "
+                "another layer. Please refer to the documentation for more "
+                "detail."
+            )
+
+    @overload
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
-        baselines: BaselineType = None,
-        target: TargetType = None,
-        additional_forward_args: Any = None,
-        n_steps: int = 50,
-        method: str = "gausslegendre",
-        internal_batch_size: Union[None, int] = None,
-        return_convergence_delta: Literal[False] = False,
-        attribute_to_layer_input: bool = False,
-    ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        baselines: BaselineType,
+        target: TargetType,
+        additional_forward_args: Any,
+        n_steps: int,
+        method: str,
+        internal_batch_size: Union[None, int],
+        return_convergence_delta: Literal[False],
+        attribute_to_layer_input: bool,
+    ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+        ...
+
+    @overload
+    def attribute(
+        self,
+        inputs: Union[Tensor, Tuple[Tensor, ...]],
+        baselines: BaselineType,
+        target: TargetType,
+        additional_forward_args: Any,
+        n_steps: int,
+        method: str,
+        internal_batch_size: Union[None, int],
+        return_convergence_delta: Literal[True],
+        attribute_to_layer_input: bool,
+    ) -> Tuple[
+        Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+        Tensor,
+    ]:
         ...
 
-    @typing.overload
+    @overload
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -111,10 +150,15 @@ def attribute(
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
-        *,
-        return_convergence_delta: Literal[True],
+        return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
-    ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
+    ) -> Union[
+        Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+        Tuple[
+            Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+            Tensor,
+        ],
+    ]:
         ...
 
     @log_usage()
@@ -130,7 +174,11 @@ def attribute(
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
     ) -> Union[
-        Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
+        Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+        Tuple[
+            Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
+            Tensor,
+        ],
     ]:
         r"""
         This method attributes the output of the model with given target index
@@ -257,16 +305,25 @@ def attribute(
                     Default: False
         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
-            - **attributions** (*tensor* or tuple of *tensors*):
+            - **attributions** (*tensor*, tuple of *tensors*, or list of *tensors*/tuples of *tensors*):
                         Integrated gradients with respect to `layer`'s inputs or
                         outputs. Attributions will always be the same size and
                         dimensionality as the input or output of the given layer,
                         depending on whether we attribute to the inputs or outputs
                         of the layer which is decided by the input flag
                         `attribute_to_layer_input`.
-                        Attributions are returned in a tuple if
+
+                        For a single layer, attributions are returned in a tuple if
                         the layer inputs / outputs contain multiple tensors,
                         otherwise a single tensor is returned.
+
+                        For multiple layers, attributions will always be
+                        returned as a list. Each element of this list is
+                        equivalent to a single layer's output, i.e. if one of
+                        the given layers takes or produces multiple tensors,
+                        the corresponding element of the list will be a tuple
+                        of tensors. The elements are ordered in the same order
+                        as the layers given to the constructor.
             - **delta** (*tensor*, returned if return_convergence_delta=True):
                         The difference between the total approximated and true
                         integrated gradients. This is computed using the property
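
Aside (not part of the diff): a sketch of how a caller might consume the multi-layer return value documented above, reusing the assumed `lig`, `x_a`, and `x_b` names from the earlier sketch.

attrs, delta = lig.attribute(
    (x_a, x_b), target=0, return_convergence_delta=True
)
for layer_attrs in attrs:
    # One entry per layer, in constructor order; an entry is a tuple of
    # tensors only if that layer's inputs/outputs contain multiple tensors.
    if isinstance(layer_attrs, tuple):
        print([t.shape for t in layer_attrs])
    else:
        print(layer_attrs.shape)
print(delta.shape)  # convergence delta (per example)
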
@@ -298,6 +355,11 @@ def attribute(
             additional_forward_args
         )
 
+        def flatten_tuple(tup):
+            return tuple(
+                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+            )
+
         if self.device_ids is None:
             self.device_ids = getattr(self.forward_func, "device_ids", None)
         inputs_layer = _forward_layer_eval(
@@ -309,6 +371,16 @@ def attribute(
             attribute_to_layer_input=attribute_to_layer_input,
         )
 
+        # wrap a single layer's result so it matches the multi-layer case
+        if not isinstance(self.layer, list):
+            inputs_layer = (inputs_layer,)
+
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs_cumsum = torch.cumsum(
+            torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
+        )
+        inputs_layer = flatten_tuple(inputs_layer)
+
         baselines_layer = _forward_layer_eval(
             self.forward_func,
             baselines,
@@ -317,6 +389,7 @@ def attribute(
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
         )
+        baselines_layer = flatten_tuple(baselines_layer)
 
         # inputs -> these inputs are scaled
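
Aside (not part of the diff): a standalone sketch of the bookkeeping introduced above, with strings standing in for tensors. Per-layer activations are flattened into a single tuple, and the cumulative offsets record where each layer's slice starts; only the `isinstance` check differs from the real code, which tests for `Tensor`.

import torch

def flatten_tuple(tup):
    return tuple(
        sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
    )

per_layer = (("a1", "a2"), "b1")  # layer 0 exposes two tensors, layer 1 exposes one
num_outputs = [1 if not isinstance(x, tuple) else len(x) for x in per_layer]  # [2, 1]
offsets = torch.cumsum(torch.IntTensor([0] + num_outputs), dim=0)  # tensor([0, 2, 3])

flat = flatten_tuple(per_layer)  # ("a1", "a2", "b1")

# Slicing with the cumulative offsets recovers the per-layer grouping,
# mirroring how the flat attributions are regrouped later in `attribute`.
regrouped = [
    flat[int(offsets[i]) : int(offsets[i + 1])] for i in range(len(num_outputs))
]
# [("a1", "a2"), ("b1",)]
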
322395 def gradient_func (
@@ -341,30 +414,60 @@ def gradient_func(
341414
342415 with torch .autograd .set_grad_enabled (True ):
343416
344- def layer_forward_hook (module , hook_inputs , hook_outputs = None ):
417+ def layer_forward_hook (
418+ module , hook_inputs , hook_outputs = None , layer_idx = 0
419+ ):
345420 device = _extract_device (module , hook_inputs , hook_outputs )
346421 is_layer_tuple = (
347422 isinstance (hook_outputs , tuple )
423+ # hook_outputs is None if attribute_to_layer_input == True
348424 if hook_outputs is not None
349425 else isinstance (hook_inputs , tuple )
350426 )
427+
351428 if is_layer_tuple :
352- return scattered_inputs_dict [device ]
353- return scattered_inputs_dict [device ][0 ]
429+ return scattered_inputs_dict [device ][
430+ num_outputs_cumsum [layer_idx ] : num_outputs_cumsum [
431+ layer_idx + 1
432+ ]
433+ ]
434+
435+ return scattered_inputs_dict [device ][num_outputs_cumsum [layer_idx ]]
354436
355- hook = None
437+ hooks = []
356438 try :
357- if attribute_to_layer_input :
358- hook = self .layer .register_forward_pre_hook (layer_forward_hook )
359- else :
360- hook = self .layer .register_forward_hook (layer_forward_hook )
439+
440+ layers = self .layer
441+ if not isinstance (layers , list ):
442+ layers = [self .layer ]
443+
444+ for layer_idx , layer in enumerate (layers ):
445+ hook = None
446+ # TODO:
447+ # Allow multiple attribute_to_layer_input flags for
448+ # each layer, i.e. attribute_to_layer_input[layer_idx]
449+ if attribute_to_layer_input :
450+ hook = layer .register_forward_pre_hook (
451+ functools .partial (
452+ layer_forward_hook , layer_idx = layer_idx
453+ )
454+ )
455+ else :
456+ hook = layer .register_forward_hook (
457+ functools .partial (
458+ layer_forward_hook , layer_idx = layer_idx
459+ )
460+ )
461+
462+ hooks .append (hook )
361463
362464 output = _run_forward (
363465 self .forward_func , tuple (), target_ind , additional_forward_args
364466 )
365467 finally :
366- if hook is not None :
367- hook .remove ()
468+ for hook in hooks :
469+ if hook is not None :
470+ hook .remove ()
368471
369472 assert output [0 ].numel () == 1 , (
370473 "Target not provided when necessary, cannot"
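
Aside (not part of the diff): a minimal sketch of the PyTorch hook mechanics that `gradient_func` relies on above. A forward pre-hook that returns a value replaces the hooked module's inputs, and `functools.partial` binds a per-layer index into a shared hook function; the tiny network and replacement tensors are assumptions for illustration.

import functools

import torch
from torch import nn

net = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 2))
replacements = {0: torch.ones(1, 3), 1: torch.zeros(1, 3)}

def override_inputs(module, inputs, layer_idx=0):
    # Returning a value from a forward pre-hook overrides the module's inputs.
    return replacements[layer_idx]

handles = [
    net[i].register_forward_pre_hook(functools.partial(override_inputs, layer_idx=i))
    for i in range(2)
]
try:
    # Both Linear layers now run on their replacement inputs, regardless of
    # what is actually fed into the network.
    out = net(torch.rand(1, 3))
finally:
    for handle in handles:
        handle.remove()
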
@@ -381,6 +484,7 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
             if additional_forward_args is not None
             else inps
         )
+
         attributions = self.ig.attribute.__wrapped__(  # type: ignore
             self.ig,  # self
             inputs_layer,
@@ -393,6 +497,16 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
             return_convergence_delta=False,
         )
 
+        # regroup the flat attributions into one tuple per layer
+        output: List[Tuple[Tensor, ...]] = [
+            tuple(
+                attributions[
+                    int(num_outputs_cumsum[i]) : int(num_outputs_cumsum[i + 1])
+                ]
+            )
+            for i in range(len(num_outputs))
+        ]
+
         if return_convergence_delta:
             start_point, end_point = baselines, inps
             # computes approximation error based on the completeness axiom
@@ -403,8 +517,8 @@ def layer_forward_hook(module, hook_inputs, hook_outputs=None):
                 additional_forward_args=additional_forward_args,
                 target=target,
             )
-            return _format_output(len(attributions) > 1, attributions), delta
-        return _format_output(len(attributions) > 1, attributions)
+            return _format_outputs(isinstance(self.layer, list), output), delta
+        return _format_outputs(isinstance(self.layer, list), output)
 
     def has_convergence_delta(self) -> bool:
         return True