@@ -102,6 +102,7 @@ def __init__(
         checkpoints_load_func: Callable = _load_flexible_state_dict,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
     ) -> None:
         r"""
         Args:
@@ -152,6 +153,19 @@ def __init__(
                     `train_dataset` is a Dataset. If `train_dataset`
                     is a DataLoader, then `batch_size` is ignored as an argument.
                     Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use
+                    separate loss functions for training examples, i.e. those in
+                    `train_dataset`, and for test examples, i.e. those
+                    represented by the `inputs` and `targets` arguments to the
+                    `influence` method. For example, if one wants to calculate the
+                    influence score of a training example on a test example's
+                    prediction for a fixed class, `test_loss_fn` could map from the
+                    logits for all classes to the logits for that fixed class.
+                    `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                    If not provided, the loss function for test examples is assumed
+                    to be the same as the loss function for training examples, i.e.
+                    `loss_fn`.
+                    Default: None
         """

         self.model = model
@@ -167,6 +181,8 @@ def __init__(

         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
+        # If `test_loss_fn` is not provided, it is assumed to be the same as `loss_fn`
+        self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
         self.batch_size = batch_size

         if not isinstance(train_dataset, DataLoader):
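As an illustration of the fixed-class use case described in the new `test_loss_fn` docstring, here is a minimal sketch. The class name `TracInCP`, the `loss_fn`, `test_loss_fn`, and `batch_size` arguments, and the `influence(inputs, targets)` call are taken from this diff; the import path, toy model, dataset, checkpoint directory, and `fixed_class_logit` helper are assumptions for illustration. With the default `sample_wise_grads_per_batch=False`, both losses must be "per-example", so the test loss returns one value per example: the logit of a fixed class.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from captum.influence import TracInCP  # assumed import path

# Hypothetical toy classifier and training data
model = nn.Linear(10, 5)
train_dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 5, (100,)))

FIXED_CLASS = 3

def fixed_class_logit(outputs, targets):
    # "Per-example" test loss: one value per example, namely the logit of the
    # fixed class whose prediction we want to explain.
    return outputs[:, FIXED_CLASS]

tracin = TracInCP(
    model,
    train_dataset,
    "checkpoints/",  # placeholder checkpoint directory
    loss_fn=nn.CrossEntropyLoss(reduction="none"),  # per-example training loss
    test_loss_fn=fixed_class_logit,                 # argument added in this commit
    batch_size=16,
)
scores = tracin.influence(torch.randn(8, 10), torch.randint(0, 5, (8,)))
```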
@@ -496,6 +512,7 @@ def __init__(
         layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
         sample_wise_grads_per_batch: bool = False,
     ) -> None:
         r"""
@@ -568,6 +585,24 @@ def __init__(
                     `train_dataset` is a Dataset. If `train_dataset`
                     is a DataLoader, then `batch_size` is ignored as an argument.
                     Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use
+                    separate loss functions for training examples, i.e. those in
+                    `train_dataset`, and for test examples, i.e. those
+                    represented by the `inputs` and `targets` arguments to the
+                    `influence` method. For example, if one wants to calculate the
+                    influence score of a training example on a test example's
+                    prediction for a fixed class, `test_loss_fn` could map from the
+                    logits for all classes to the logits for that fixed class.
+                    `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                    Thus, the same checks that we apply to `loss_fn` are also applied
+                    to `test_loss_fn`, if the latter is provided. Note that the
+                    constraints on both `loss_fn` and `test_loss_fn` depend on
+                    `sample_wise_grads_per_batch`. This means `loss_fn` and
+                    `test_loss_fn` must either both be "per-example" loss functions,
+                    or both be "reduction" loss functions. If not provided, the loss
+                    function for test examples is assumed to be the same as the loss
+                    function for training examples, i.e. `loss_fn`.
+                    Default: None
             sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
                     computations w.r.t. model parameters aggregates the results for a
                     batch and does not allow to access sample-wise gradients w.r.t.
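A hedged sketch of the constraint spelled out above: when `sample_wise_grads_per_batch=True`, `loss_fn` and `test_loss_fn` must both be "reduction" losses (sum or mean); with the default `False`, both must be "per-example". The toy `model` and `train_dataset` are reused from the earlier sketch, and the checkpoint directory is a placeholder.

```python
import torch.nn as nn

# Both losses are "reduction" losses, so the sample-wise gradient trick applies;
# mixing a reduction loss with a per-example loss would trip the checks below.
tracin_reduced = TracInCP(
    model,
    train_dataset,
    "checkpoints/",  # placeholder
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    test_loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    sample_wise_grads_per_batch=True,
)

# With the default sample_wise_grads_per_batch=False, both losses must instead
# be "per-example", i.e. reduction="none" (or a callable returning one value
# per example).
tracin_per_example = TracInCP(
    model,
    train_dataset,
    "checkpoints/",  # placeholder
    loss_fn=nn.CrossEntropyLoss(reduction="none"),
    test_loss_fn=nn.CrossEntropyLoss(reduction="none"),
)
```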
@@ -597,51 +632,20 @@ def __init__(
             checkpoints_load_func,
             loss_fn,
             batch_size,
+            test_loss_fn,
         )

         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

-        # If we are able to access the reduction used by `loss_fn`, we check whether
-        # the reduction is compatible with `sample_wise_grads_per_batch`
-        if isinstance(loss_fn, Module) and hasattr(
-            loss_fn, "reduction"
-        ):  # TODO: allow loss_fn to be Callable
-            if self.sample_wise_grads_per_batch:
-                assert loss_fn.reduction in ["sum", "mean"], (
-                    'reduction for `loss_fn` must be "sum" or "mean" when '
-                    "`sample_wise_grads_per_batch` is True"
-                )
-                self.reduction_type = str(loss_fn.reduction)
-            else:
-                assert loss_fn.reduction == "none", (
-                    'reduction for `loss_fn` must be "none" when '
-                    "`sample_wise_grads_per_batch` is False"
-                )
+        # check `loss_fn`
+        self.reduction_type = _check_tracincp_loss_fn(self, loss_fn, "loss_fn")
+        # check `test_loss_fn` if it was provided
+        if not (test_loss_fn is None):
+            self.test_reduction_type = _check_tracincp_loss_fn(
+                self, test_loss_fn, "test_loss_fn"
+            )
         else:
-            # if we are unable to access the reduction used by `loss_fn`, we warn
-            # the user about the assumptions we are making regarding the reduction
-            # used by `loss_fn`
-            if self.sample_wise_grads_per_batch:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
-                    'that `loss_fn` is a "reduction" loss function that reduces the '
-                    "per-example losses by taking their *sum*. If `loss_fn` "
-                    "instead reduces the per-example losses by taking their mean, "
-                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
-                    '`loss_fn.reduction = "mean"`. Note that if '
-                    "`sample_wise_grads_per_batch` is True, the implementation "
-                    "assumes the reduction is either a sum or mean reduction."
-                )
-                self.reduction_type = "sum"
-            else:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is False, the implementation "
-                    'assumes that `loss_fn` is a "per-example" loss function (see '
-                    "documentation for `loss_fn` for details). Please ensure that "
-                    "this is the case."
-                )
+            self.test_reduction_type = self.reduction_type

         r"""
         TODO: Either restore model state after done (would have to place functionality
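One consequence of the refactored check above: a plain callable loss has no `reduction` attribute, so when `sample_wise_grads_per_batch=True` the implementation warns and assumes a *sum* reduction. A minimal sketch of a callable that actually matches that assumption (the function name and fixed class are illustrative, not from this patch):

```python
def summed_fixed_class_logit(outputs, targets, fixed_class=3):
    # No `reduction` attribute to inspect: the check warns and assumes the
    # per-example values are reduced by a sum, which is exactly what this does.
    return outputs[:, fixed_class].sum()
```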
@@ -801,11 +805,15 @@ def get_checkpoint_contribution(checkpoint):
             input_jacobians = self._basic_computation_tracincp(
                 inputs,
                 targets,
+                self.test_loss_fn,
+                self.test_reduction_type,
             )
             return (
                 _gradient_dot_product(
                     input_jacobians,
-                    self._basic_computation_tracincp(batch[0:-1], batch[-1]),
+                    self._basic_computation_tracincp(
+                        batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
+                    ),
                 )
                 * learning_rate
             )
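Conceptually, the hunk above pairs test-example jacobians (now computed with `test_loss_fn` and `test_reduction_type`) with training-example jacobians (computed with `loss_fn` and `reduction_type`) and takes their dot product, scaled by the checkpoint's learning rate. A rough standalone sketch of that pairing follows; it is illustrative only, not captum's `_gradient_dot_product` helper.

```python
import torch
from typing import Tuple

def pairwise_grad_dot(
    test_jacobians: Tuple[torch.Tensor, ...],
    train_jacobians: Tuple[torch.Tensor, ...],
    learning_rate: float = 1.0,
) -> torch.Tensor:
    # Each element is a per-example jacobian of shape (batch, *param_shape);
    # flatten the parameter dimensions and accumulate dot products over all
    # parameter groups, yielding a (test_batch, train_batch) score matrix.
    total = None
    for tj, rj in zip(test_jacobians, train_jacobians):
        contrib = tj.flatten(start_dim=1) @ rj.flatten(start_dim=1).T
        total = contrib if total is None else total + contrib
    return total * learning_rate
```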
@@ -1053,7 +1061,10 @@ def get_checkpoint_contribution(checkpoint):
             for batch in _inputs_dataset:

                 layer_jacobians = self._basic_computation_tracincp(
-                    batch[0:-1], batch[-1]
+                    batch[0:-1],
+                    batch[-1],
+                    self.loss_fn,
+                    self.reduction_type,
                 )

                 # Note that all variables in this function are for an entire batch.
@@ -1182,11 +1193,14 @@ def _basic_computation_tracincp(
         self,
         inputs: Tuple[Any, ...],
         targets: Optional[Tensor] = None,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
     ) -> Tuple[Tensor, ...]:
         """
         For instances of TracInCP, computation of influence scores or self influence
         scores repeatedly calls this function for different checkpoints
-        and batches.
+        and batches. In particular, this function computes the jacobian of a loss
+        function w.r.t. the parameters of the layers specified in `layers`.

         Args:

@@ -1196,20 +1210,84 @@ def _basic_computation_tracincp(
                 that `model(*inputs)` produces the predictions for the batch.
             targets (tensor or None): If computing influence scores on a loss function,
                 these are the labels corresponding to the batch `inputs`.
+                Default: None
+            loss_fn (Callable, optional): The loss function to use when computing the
+                jacobian.
+            reduction_type (str, optional): The reduction type of `loss_fn`. This
+                argument is only used if `sample_wise_grads_per_batch` was True in
+                initialization.
         """
         if self.sample_wise_grads_per_batch:
             return _compute_jacobian_wrt_params_with_sample_wise_trick(
                 self.model,
                 inputs,
                 targets,
-                self.loss_fn,
-                self.reduction_type,
+                loss_fn,
+                reduction_type,
                 self.layer_modules,
             )
         return _compute_jacobian_wrt_params(
             self.model,
             inputs,
             targets,
-            self.loss_fn,
+            loss_fn,
             self.layer_modules,
         )
+
+
+def _check_tracincp_loss_fn(
+    influence_instance: TracInCP,
+    loss_fn: Optional[Union[Module, Callable]],
+    loss_fn_name: str,
+) -> Optional[str]:
+    """
+    This checks whether `loss_fn` satisfies the requirements described in
+    `TracInCP.__init__`. It returns the reduction type of `loss_fn`, which is not
+    None only if `influence_instance.sample_wise_grads_per_batch` is True.
+    """
+
+    reduction_type = None
+
+    # If we are able to access the reduction used by `loss_fn`, we check whether
+    # the reduction is compatible with `sample_wise_grads_per_batch`
+    if isinstance(loss_fn, Module) and hasattr(
+        loss_fn, "reduction"
+    ):  # TODO: allow loss_fn to be Callable
+        if influence_instance.sample_wise_grads_per_batch:
+            assert loss_fn.reduction in ["sum", "mean"], (
+                'reduction for `loss_fn` must be "sum" or "mean" when '
+                "`sample_wise_grads_per_batch` is True"
+            )
+            reduction_type = str(loss_fn.reduction)
+        else:
+            assert loss_fn.reduction == "none", (
+                'reduction for `loss_fn` must be "none" when '
+                "`sample_wise_grads_per_batch` is False"
+            )
+    else:
+        # if we are unable to access the reduction used by `loss_fn`, we warn
+        # the user about the assumptions we are making regarding the reduction
+        # used by `loss_fn`
+        if influence_instance.sample_wise_grads_per_batch:
+            warnings.warn(
+                f'Since `{loss_fn_name}` has no "reduction" attribute, and '
+                "`sample_wise_grads_per_batch` is True, the implementation assumes "
+                f'that `{loss_fn_name}` is a "reduction" loss function that reduces '
+                f"the per-example losses by taking their *sum*. If `{loss_fn_name}` "
+                "instead reduces the per-example losses by taking their mean, "
+                f'please set the reduction attribute of `{loss_fn_name}` to "mean", '
+                f'i.e. `{loss_fn_name}.reduction = "mean"`. Note that if '
+                "`sample_wise_grads_per_batch` is True, the implementation "
+                "assumes the reduction is either a sum or mean reduction."
+            )
+            reduction_type = "sum"
+        else:
+            warnings.warn(
+                f'Since `{loss_fn_name}` has no "reduction" attribute, and '
+                "`sample_wise_grads_per_batch` is False, the implementation "
+                f'assumes that `{loss_fn_name}` is a "per-example" loss function (see '
+                f"documentation for `{loss_fn_name}` for details). Please ensure "
+                "that this is the case."
+            )
+
+    return reduction_type
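For intuition about what `_basic_computation_tracincp` returns when a "per-example" loss is used, here is a deliberately naive sketch of per-example parameter jacobians. It is illustrative only, not captum's `_compute_jacobian_wrt_params` nor the sample-wise trick used for reduction losses, and the function name is hypothetical.

```python
import torch

def naive_per_example_jacobians(model, inputs, targets, loss_fn):
    # Compute per-example losses (reduction="none" style), then backpropagate
    # each example's loss separately and stack the gradients per parameter,
    # yielding tensors of shape (batch, *param_shape).
    losses = loss_fn(model(*inputs), targets)  # shape: (batch,)
    params = [p for p in model.parameters() if p.requires_grad]
    grads_per_param = [[] for _ in params]
    for loss in losses:
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        for acc, g in zip(grads_per_param, grads):
            acc.append(g)
    return tuple(torch.stack(acc) for acc in grads_per_param)
```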