
Commit 322c28a

Fulton Wang authored and facebook-github-bot committed
add compute_intermediate_quantities to TracInCP
Summary: This diff adds the `compute_intermediate_quantities` method to `TracInCP`, which returns influence embeddings such that the influence of one example on another is the dot-product of their respective influence embeddings. In the case of `TracInCP`, an example's influence embedding is simply its parameter-gradients, concatenated over the different checkpoints.

There is also an `aggregate` option that, if True, returns not the influence embeddings of each example in the given dataset, but instead their *sum*. This is useful for the validation diff workflow (the next diff in the stack), where we want to calculate the influence of a given training example on an entire validation dataset. This can be accomplished by taking the dot-product of the training example's influence embedding with the *sum* of the influence embeddings over the validation dataset (i.e. with `aggregate=True`).

For tests, the tests currently used for `TracInCPFastRandProj.compute_intermediate_quantities` (`test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_api`, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_consistent`) are applied to `TracInCP.compute_intermediate_quantities`. In addition, `test_tracin_intermediate_quantities.test_tracin_intermediate_quantities_aggregate` is added to test the `aggregate=True` option, checking that with `aggregate=True`, the returned influence embedding is indeed the sum of the influence embeddings for the given dataset.

Reviewed By: cyrjano

Differential Revision: D40688327

fbshipit-source-id: dd882a59c4463ceb3f79132011975f9f890657d5
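As an illustration of the intended workflow, here is a minimal sketch (not part of this diff): `tracin`, `train_batch`, and `val_dataloader` are hypothetical names for an already-constructed `TracInCP` instance, a batch of training examples, and a validation `DataLoader`.

# Sketch only: assumes `tracin` is a configured TracInCP instance, and that
# `train_batch` / `val_dataloader` are supplied by the caller.
train_embeddings = tracin.compute_intermediate_quantities(train_batch)

# With `aggregate=True`, the result is the *sum* of the validation examples'
# influence embeddings, of shape (1, D * C).
val_embedding_sum = tracin.compute_intermediate_quantities(
    val_dataloader, aggregate=True
)

# By linearity of the dot-product, each entry is a training example's
# influence on the validation dataset as a whole.
influence_on_val_set = train_embeddings @ val_embedding_sum.T  # shape (N, 1)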
1 parent a7610be commit 322c28a

File tree

2 files changed: +197 -0 lines changed

captum/influence/_core/tracincp.py

Lines changed: 135 additions & 0 deletions
@@ -780,6 +780,141 @@ def influence(  # type: ignore[override]
             show_progress,
         )

+    def _sum_jacobians(self, inputs_dataset: DataLoader):
+        """
+        Sums the jacobians of all examples in `inputs_dataset`. The result is of
+        the same format as layer_jacobians, but the batch dimension has size 1.
+        """
+        inputs_dataset_iter = iter(inputs_dataset)
+
+        inputs_batch = next(inputs_dataset_iter)
+
+        def get_batch_contribution(inputs_batch):
+            _input_jacobians = self._basic_computation_tracincp(
+                inputs_batch[0:-1],
+                inputs_batch[-1],
+            )
+
+            return tuple(
+                torch.sum(jacobian, dim=0).unsqueeze(0)
+                for jacobian in _input_jacobians
+            )
+
+        inputs_jacobians = get_batch_contribution(inputs_batch)
+
+        for inputs_batch in inputs_dataset_iter:
+            inputs_batch_jacobians = get_batch_contribution(inputs_batch)
+            inputs_jacobians = tuple(
+                [
+                    inputs_jacobian + inputs_batch_jacobian
+                    for (inputs_jacobian, inputs_batch_jacobian) in zip(
+                        inputs_jacobians, inputs_batch_jacobians
+                    )
+                ]
+            )
+
+        return inputs_jacobians
+
+    def _concat_jacobians(self, inputs_dataset: DataLoader):
+        all_inputs_batch_jacobians = [
+            self._basic_computation_tracincp(
+                inputs_batch[0:-1],
+                inputs_batch[-1],
+            )
+            for inputs_batch in inputs_dataset
+        ]
+
+        return tuple(
+            torch.cat(all_inputs_batch_jacobian, dim=0)
+            for all_inputs_batch_jacobian in zip(*all_inputs_batch_jacobians)
+        )
+
+    def compute_intermediate_quantities(
+        self,
+        inputs_dataset: Union[Tuple[Any, ...], DataLoader],
+        aggregate: bool = False,
+    ) -> Tensor:
+        """
+        Computes "embedding" vectors for all examples in a single batch, or a
+        `DataLoader` that yields batches. These embedding vectors are constructed
+        so that the influence score of a training example on a test example is
+        simply the dot-product of their corresponding vectors. Allowing a
+        `DataLoader` yielding batches to be passed in (as opposed to a single
+        batch) gives the potential to improve efficiency, because we load each
+        checkpoint only once in this method call. Thus if a `DataLoader` yielding
+        batches is passed in, this reduces the total number of times each
+        checkpoint is loaded for a dataset, compared to if a single batch is
+        passed in. The reason we do not just increase the batch size is that for
+        large models, large batches do not fit in memory.
+
+        If `aggregate` is True, the *sum* of the vectors for all examples is
+        returned, instead of the vectors for each example. This can be useful
+        for computing the influence of a given training example on the total
+        loss over a validation dataset, because due to properties of the
+        dot-product, this influence is the dot-product of the training example's
+        vector with the sum of the vectors in the validation dataset. Also,
+        doing the sum aggregation within this method, as opposed to outside of
+        it (by computing all vectors for the validation dataset, then taking
+        the sum), allows memory usage to be reduced.
+
+        Args:
+            inputs_dataset (Tuple, or DataLoader): Either a single tuple of any,
+                    or a `DataLoader`, where each batch yielded is a tuple of
+                    any. In either case, the tuple represents a single batch,
+                    where the last element is assumed to be the labels for the
+                    batch. That is, `model(*batch[0:-1])` produces the output
+                    for `model`, and `batch[-1]` are the labels, if any. Here,
+                    `model` is the model provided in initialization. This is
+                    the same assumption made for each batch yielded by training
+                    dataset `train_dataset`.
+
+        Returns:
+            intermediate_quantities (Tensor): A tensor of dimension
+                    (N, D * C). Here, N is the total number of examples in
+                    `inputs_dataset` if `aggregate` is False, and 1 otherwise
+                    (so that a 2D tensor is always returned). C is the number
+                    of checkpoints passed as the `checkpoints` argument of
+                    `TracInCP.__init__`, and each row represents the vector for
+                    an example. Regarding D: Let I be the dimension of the
+                    output of the last fully-connected layer times the
+                    dimension of the input of the last fully-connected layer.
+                    If `self.projection_dim` is specified in initialization,
+                    D = min(I * C, `self.projection_dim` * C). Otherwise,
+                    D = I * C. In summary, if `self.projection_dim` is None,
+                    the dimension of each vector will be determined by the size
+                    of the input and output of the last fully-connected layer
+                    of `model`. Otherwise, `self.projection_dim` must be an
+                    int, and random projection will be performed to ensure that
+                    the vector is of dimension no more than
+                    `self.projection_dim` * C. `self.projection_dim`
+                    corresponds to the variable d in the top of page 15 of the
+                    TracIn paper: https://arxiv.org/pdf/2002.08484.pdf.
+        """
+        # If `inputs_dataset` is not a `DataLoader`, turn it into one.
+        inputs_dataset = _format_inputs_dataset(inputs_dataset)
+
+        def get_checkpoint_contribution(checkpoint):
+            assert (
+                checkpoint is not None
+            ), "None returned from `checkpoints`, cannot load."
+
+            learning_rate = self.checkpoints_load_func(self.model, checkpoint)
+            # get jacobians as tuple of tensors
+            if aggregate:
+                inputs_jacobians = self._sum_jacobians(inputs_dataset)
+            else:
+                inputs_jacobians = self._concat_jacobians(inputs_dataset)
+            # flatten into single tensor
+            return learning_rate * torch.cat(
+                [
+                    input_jacobian.flatten(start_dim=1)
+                    for input_jacobian in inputs_jacobians
+                ],
+                dim=1,
+            )
+
+        return torch.cat(
+            [
+                get_checkpoint_contribution(checkpoint)
+                for checkpoint in self.checkpoints
+            ],
+            dim=1,
+        )
+
     def _influence_batch_tracincp(
         self,
         inputs: Tuple[Any, ...],
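For context, the dot-product property that `compute_intermediate_quantities` provides can be exercised as follows. This is a sketch under assumptions, not code from this commit; `tracin`, `train_batch`, and `test_batch` are hypothetical names.

# Sketch only: `tracin` is assumed to be a configured TracInCP instance.
# Each call returns one row per example, of dimension D * C.
train_embeddings = tracin.compute_intermediate_quantities(train_batch)
test_embeddings = tracin.compute_intermediate_quantities(test_batch)

# The pairwise influence of each training example on each test example is
# then just a matrix of dot-products; this should agree, up to floating-point
# error, with the influence scores TracInCP computes directly.
pairwise_influence = test_embeddings @ train_embeddings.T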

tests/influence/_core/test_tracin_intermediate_quantities.py

Lines changed: 62 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch

 import torch.nn as nn
+from captum.influence._core.tracincp import TracInCP
 from captum.influence._core.tracincp_fast_rand_proj import (
     TracInCPFast,
     TracInCPFastRandProj,
@@ -19,12 +20,68 @@


 class TestTracInIntermediateQuantities(BaseTest):
+    @parameterized.expand(
+        [
+            (reduction, constructor, unpack_inputs)
+            for unpack_inputs in [True, False]
+            for (reduction, constructor) in [
+                ("none", DataInfluenceConstructor(TracInCP)),
+            ]
+        ],
+        name_func=build_test_name_func(),
+    )
+    def test_tracin_intermediate_quantities_aggregate(
+        self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool
+    ) -> None:
+        """
+        tests that calling `compute_intermediate_quantities` with
+        `aggregate=True` gives the same result as calling it with
+        `aggregate=False`, and then summing
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            (net, train_dataset,) = get_random_model_and_data(
+                tmpdir,
+                unpack_inputs,
+                return_test_data=False,
+            )
+
+            # create a dataloader that yields batches from the dataset
+            train_dataset = DataLoader(train_dataset, batch_size=5)
+
+            # create tracin instance
+            criterion = nn.MSELoss(reduction=reduction)
+            batch_size = 5
+
+            tracin = tracin_constructor(
+                net,
+                train_dataset,
+                tmpdir,
+                batch_size,
+                criterion,
+            )
+
+            intermediate_quantities = tracin.compute_intermediate_quantities(
+                train_dataset, aggregate=False
+            )
+            aggregated_intermediate_quantities = tracin.compute_intermediate_quantities(
+                train_dataset, aggregate=True
+            )
+
+            assertTensorAlmostEqual(
+                self,
+                torch.sum(intermediate_quantities, dim=0, keepdim=True),
+                aggregated_intermediate_quantities,
+                delta=1e-4,  # due to numerical issues, we can't set this to 0.0
+                mode="max",
+            )
+
     @parameterized.expand(
         [
             (reduction, constructor, unpack_inputs)
             for unpack_inputs in [True, False]
             for (reduction, constructor) in [
                 ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+                ("none", DataInfluenceConstructor(TracInCP)),
             ]
         ],
         name_func=build_test_name_func(),
@@ -103,6 +160,11 @@ def test_tracin_intermediate_quantities_api(
                 DataInfluenceConstructor(TracInCPFast),
                 DataInfluenceConstructor(TracInCPFastRandProj),
             ),
+            (
+                "none",
+                DataInfluenceConstructor(TracInCP),
+                DataInfluenceConstructor(TracInCP),
+            ),
         ]
     ],
     name_func=build_test_name_func(),
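The identity this test relies on is just linearity of the dot-product: summing per-example influences equals one dot-product against the summed embeddings. A self-contained sketch of that identity in plain torch (hypothetical shapes, unrelated to the test fixtures), which also shows why an exact delta of 0.0 would be too strict:

import torch

embeddings = torch.randn(10, 8)  # 10 examples, embedding dimension 8
query = torch.randn(1, 8)        # influence embedding of one training example

# The sum of per-example dot-products ...
per_example_sum = (embeddings @ query.T).sum()
# ... equals one dot-product against the summed embeddings. The equality is
# exact in real arithmetic but only approximate in floating point, which is
# why the test uses delta=1e-4 rather than 0.0.
aggregated = embeddings.sum(dim=0, keepdim=True) @ query.T
assert torch.allclose(per_example_sum, aggregated.squeeze(), atol=1e-4)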
