
Commit ee140d2

Fulton Wang authored and facebook-github-bot committed
add test loss
Summary:

- For all `TracInCPBase` implementations, this adds an additional `test_loss_fn` initialization argument, which is the loss function to apply to test examples when computing the influence of a training example on a test example. With this change, the influence score is a sum over terms for each checkpoint, where each term is the gradient of `loss_fn` for a given training example, multiplied with the gradient of `test_loss_fn` for a given test example. Before, `test_loss_fn` was assumed to be the same as `loss_fn`.
- Checks regarding the reduction type of both `loss_fn` and `test_loss_fn` are now handled by the helper functions `_check_tracincp_loss_fn` and `_check_tracincp_fast_loss_fn`.
- Documentation is updated. One detail: for `TracInCP`, we assume that `sample_wise_grads_per_batch` applies to both `loss_fn` and `test_loss_fn` (if provided), and this is mentioned in the documentation.
- `test_tracin_regression.test_tracin_regression` is slightly modified: `DataInfluenceConstructor` can now explicitly pass in the same loss function for both `loss_fn` and `test_loss_fn` (done when `duplicate_loss_fn=True`). Doing so has the same effect as not passing in `test_loss_fn`, so the original tests are also applied to the case when `duplicate_loss_fn=True`, since the expected behavior should be the same as before.
- A new test, `test_tracin_regression.test_tracin_constant_test_loss_fn`, is added. For all implementations of `TracInCPBase`, it checks that if `test_loss_fn` is a constant loss function, the influence scores are all 0's. This should be the case, because if `test_loss_fn` is constant, its gradients are all 0's, so training examples have 0 influence on test examples.

Reviewed By: cyrjano

Differential Revision: D41202866

fbshipit-source-id: d22fcf873d9d63e5749ae94bee5ed7e868de80d3
1 parent ada8c0d commit ee140d2
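As a rough usage sketch of the new argument (not part of this commit; the toy model, dataset, checkpoint file, `FIXED_CLASS`, and `fixed_class_logit` below are hypothetical), one could pair a per-example training loss with a per-example test "loss" that returns the logit of a fixed class, the use case mentioned in the updated docstrings:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from captum.influence import TracInCP

# Hypothetical 3-class model, training set, and checkpoint.
model = nn.Linear(10, 3)
train_dataset = TensorDataset(torch.randn(32, 10), torch.randint(0, 3, (32,)))
torch.save(model.state_dict(), "checkpoint-0.pt")  # placeholder checkpoint

FIXED_CLASS = 2  # hypothetical class of interest


def fixed_class_logit(outputs, targets):
    # A "per-example" test loss: the logit of a fixed class for each test
    # example. Its gradients measure influence on that class's prediction;
    # `targets` is ignored.
    return outputs[:, FIXED_CLASS]


tracin = TracInCP(
    model,
    train_dataset,
    checkpoints=["checkpoint-0.pt"],
    loss_fn=nn.CrossEntropyLoss(reduction="none"),  # per-example training loss
    test_loss_fn=fixed_class_logit,                 # per-example test loss
    batch_size=8,
    sample_wise_grads_per_batch=False,  # both losses must then be per-example
)

test_inputs = (torch.randn(4, 10),)
test_targets = torch.randint(0, 3, (4,))
# Influence of each of the 32 training examples on each of the 4 test examples.
scores = tracin.influence(test_inputs, test_targets)

Because `fixed_class_logit` is a plain callable without a `reduction` attribute, construction falls into the warning branch shown in the diff below, which assumes a per-example test loss when `sample_wise_grads_per_batch` is False.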

File tree

4 files changed (+356, -80 lines)


captum/influence/_core/tracincp.py

Lines changed: 124 additions & 46 deletions
@@ -102,6 +102,7 @@ def __init__(
         checkpoints_load_func: Callable = _load_flexible_state_dict,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
     ) -> None:
         r"""
         Args:
@@ -152,6 +153,19 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                If not provided, the loss function for test examples is assumed to
+                be the same as the loss function for training examples, i.e.
+                `loss_fn`.
+                Default: None
         """

         self.model = model
@@ -167,6 +181,8 @@ def __init__(

         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
+        # If test_loss_fn is not provided, it is assumed to be the same as loss_fn
+        self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
         self.batch_size = batch_size

         if not isinstance(train_dataset, DataLoader):
@@ -496,6 +512,7 @@ def __init__(
         layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
         sample_wise_grads_per_batch: bool = False,
     ) -> None:
         r"""
@@ -568,6 +585,24 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                Thus, the same checks that we apply to `loss_fn` are also applied
+                to `test_loss_fn`, if the latter is provided. Note that the
+                constraints on both `loss_fn` and `test_loss_fn` depend on
+                `sample_wise_grads_per_batch`. This means `loss_fn` and
+                `test_loss_fn` must either both be "per-example" loss functions,
+                or both be "reduction" loss functions. If not provided, the loss
+                function for test examples is assumed to be the same as the loss
+                function for training examples, i.e. `loss_fn`.
+                Default: None
             sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
                 computations w.r.t. model parameters aggregates the results for a
                 batch and does not allow to access sample-wise gradients w.r.t.
@@ -597,51 +632,20 @@ def __init__(
             checkpoints_load_func,
             loss_fn,
             batch_size,
+            test_loss_fn,
         )

         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

-        # If we are able to access the reduction used by `loss_fn`, we check whether
-        # the reduction is compatible with `sample_wise_grads_per_batch`
-        if isinstance(loss_fn, Module) and hasattr(
-            loss_fn, "reduction"
-        ): # TODO: allow loss_fn to be Callable
-            if self.sample_wise_grads_per_batch:
-                assert loss_fn.reduction in ["sum", "mean"], (
-                    'reduction for `loss_fn` must be "sum" or "mean" when '
-                    "`sample_wise_grads_per_batch` is True"
-                )
-                self.reduction_type = str(loss_fn.reduction)
-            else:
-                assert loss_fn.reduction == "none", (
-                    'reduction for `loss_fn` must be "none" when '
-                    "`sample_wise_grads_per_batch` is False"
-                )
+        # check `loss_fn`
+        self.reduction_type = _check_tracincp_loss_fn(self, loss_fn, "loss_fn")
+        # check `test_loss_fn` if it was provided
+        if not (test_loss_fn is None):
+            self.test_reduction_type = _check_tracincp_loss_fn(
+                self, test_loss_fn, "test_loss_fn"
+            )
         else:
-            # if we are unable to access the reduction used by `loss_fn`, we warn
-            # the user about the assumptions we are making regarding the reduction
-            # used by `loss_fn`
-            if self.sample_wise_grads_per_batch:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
-                    'that `loss_fn` is a "reduction" loss function that reduces the '
-                    "per-example losses by taking their *sum*. If `loss_fn` "
-                    "instead reduces the per-example losses by taking their mean, "
-                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
-                    '`loss_fn.reduction = "mean"`. Note that if '
-                    "`sample_wise_grads_per_batch` is True, the implementation "
-                    "assumes the reduction is either a sum or mean reduction."
-                )
-                self.reduction_type = "sum"
-            else:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is False, the implementation "
-                    'assumes that `loss_fn` is a "per-example" loss function (see '
-                    "documentation for `loss_fn` for details). Please ensure that "
-                    "this is the case."
-                )
+            self.test_reduction_type = self.reduction_type

         r"""
         TODO: Either restore model state after done (would have to place functionality
@@ -801,11 +805,15 @@ def get_checkpoint_contribution(checkpoint):
             input_jacobians = self._basic_computation_tracincp(
                 inputs,
                 targets,
+                self.test_loss_fn,
+                self.test_reduction_type,
             )
             return (
                 _gradient_dot_product(
                     input_jacobians,
-                    self._basic_computation_tracincp(batch[0:-1], batch[-1]),
+                    self._basic_computation_tracincp(
+                        batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
+                    ),
                 )
                 * learning_rate
             )
@@ -1053,7 +1061,10 @@ def get_checkpoint_contribution(checkpoint):
             for batch in _inputs_dataset:

                 layer_jacobians = self._basic_computation_tracincp(
-                    batch[0:-1], batch[-1]
+                    batch[0:-1],
+                    batch[-1],
+                    self.loss_fn,
+                    self.reduction_type,
                 )

                 # Note that all variables in this function are for an entire batch.
@@ -1182,11 +1193,14 @@ def _basic_computation_tracincp(
         self,
         inputs: Tuple[Any, ...],
         targets: Optional[Tensor] = None,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
     ) -> Tuple[Tensor, ...]:
         """
         For instances of TracInCP, computation of influence scores or self influence
         scores repeatedly calls this function for different checkpoints
-        and batches.
+        and batches. In particular, this function computes the jacobian of a loss
+        function w.r.t. parameters in the `layers` initialization argument.

         Args:

@@ -1196,20 +1210,84 @@ def _basic_computation_tracincp(
                 that `model(*inputs)` produces the predictions for the batch.
             targets (tensor or None): If computing influence scores on a loss function,
                 these are the labels corresponding to the batch `inputs`.
+                Default: None
+            loss_fn (Callable, optional): The loss function to use when computing the
+                jacobian.
+            reduction_type (str, optional): The reduction type of `loss_fn`. This
+                argument is only used if `sample_wise_grads_per_batch` was True in
+                initialization.
         """
         if self.sample_wise_grads_per_batch:
             return _compute_jacobian_wrt_params_with_sample_wise_trick(
                 self.model,
                 inputs,
                 targets,
-                self.loss_fn,
-                self.reduction_type,
+                loss_fn,
+                reduction_type,
                 self.layer_modules,
             )
         return _compute_jacobian_wrt_params(
             self.model,
             inputs,
             targets,
-            self.loss_fn,
+            loss_fn,
             self.layer_modules,
         )
+
+
+def _check_tracincp_loss_fn(
+    influence_instance: TracInCP,
+    loss_fn: Optional[Union[Module, Callable]],
+    loss_fn_name: str,
+) -> Optional[str]:
+    """
+    This checks whether `loss_fn` satisfies the requirements described in
+    `TracInCP.__init__`. It returns the reduction type of `loss_fn`, which is not
+    None only if `influence_instance.sample_wise_grads_per_batch` is True.
+    """
+
+    reduction_type = None
+
+    # If we are able to access the reduction used by `loss_fn`, we check whether
+    # the reduction is compatible with `sample_wise_grads_per_batch`
+    if isinstance(loss_fn, Module) and hasattr(
+        loss_fn, "reduction"
+    ): # TODO: allow loss_fn to be Callable
+        if influence_instance.sample_wise_grads_per_batch:
+            assert loss_fn.reduction in ["sum", "mean"], (
+                'reduction for `loss_fn` must be "sum" or "mean" when '
+                "`sample_wise_grads_per_batch` is True"
+            )
+            reduction_type = str(loss_fn.reduction)
+        else:
+            assert loss_fn.reduction == "none", (
+                'reduction for `loss_fn` must be "none" when '
+                "`sample_wise_grads_per_batch` is False"
+            )
+    else:
+        # if we are unable to access the reduction used by `loss_fn`, we warn
+        # the user about the assumptions we are making regarding the reduction
+        # used by `loss_fn`
+        if influence_instance.sample_wise_grads_per_batch:
+            warnings.warn(
+                f"Since `{loss_fn_name}` has no 'reduction' attribute, and "
+                "`sample_wise_grads_per_batch` is True, the implementation assumes "
+                f"that `{loss_fn_name}` is a 'reduction' loss function that reduces "
+                f"the per-example losses by taking their *sum*. If `{loss_fn_name}` "
+                "instead reduces the per-example losses by taking their mean, "
+                f'please set the reduction attribute of `{loss_fn_name}` to "mean", '
+                f'i.e. `{loss_fn_name}.reduction = "mean"`. Note that if '
+                "`sample_wise_grads_per_batch` is True, the implementation "
+                "assumes the reduction is either a sum or mean reduction."
+            )
+            reduction_type = "sum"
+        else:
+            warnings.warn(
+                f'Since `{loss_fn_name}` has no "reduction" attribute, and '
+                "`sample_wise_grads_per_batch` is False, the implementation "
+                f'assumes that `{loss_fn_name}` is a "per-example" loss function (see '
+                f"documentation for `{loss_fn_name}` for details). Please ensure "
+                "that this is the case."
+            )
+
+    return reduction_type
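
To make the reduction constraints above concrete, here is a hedged sketch (again not from the commit; the model, dataset, and checkpoint are placeholders) of the two regimes that `_check_tracincp_loss_fn` accepts: with `sample_wise_grads_per_batch=False`, `loss_fn` and `test_loss_fn` must both be per-example losses (`reduction="none"`), and with `sample_wise_grads_per_batch=True`, both must reduce the per-example losses via `"sum"` or `"mean"`.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from captum.influence import TracInCP

# Placeholder model, dataset, and checkpoint for illustration only.
model = nn.Linear(10, 3)
train_dataset = TensorDataset(torch.randn(32, 10), torch.randint(0, 3, (32,)))
torch.save(model.state_dict(), "checkpoint-0.pt")
checkpoints = ["checkpoint-0.pt"]

# Regime 1: both losses are "per-example" (reduction="none"),
# with sample_wise_grads_per_batch left at False.
tracin_per_example = TracInCP(
    model,
    train_dataset,
    checkpoints,
    loss_fn=nn.CrossEntropyLoss(reduction="none"),
    test_loss_fn=nn.CrossEntropyLoss(reduction="none"),
    sample_wise_grads_per_batch=False,
)

# Regime 2: both losses are "reduction" losses ("sum" or "mean"),
# which requires sample_wise_grads_per_batch=True.
tracin_reduction = TracInCP(
    model,
    train_dataset,
    checkpoints,
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    test_loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    sample_wise_grads_per_batch=True,
)

# Mixing regimes, e.g. reduction="none" together with
# sample_wise_grads_per_batch=True, trips the assertions in
# _check_tracincp_loss_fn at construction time.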

Comments (0)