@@ -444,7 +444,7 @@ def _check_loss_fn(
444444 influence_instance : Union ["TracInCPBase" , "InfluenceFunctionBase" ],
445445 loss_fn : Optional [Union [Module , Callable ]],
446446 loss_fn_name : str ,
447- sample_wise_grads_per_batch : Optional [ bool ] = None ,
447+ sample_wise_grads_per_batch : bool = True ,
448448) -> str :
449449 """
450450 This checks whether `loss_fn` satisfies the requirements assumed of all
@@ -469,16 +469,13 @@ def _check_loss_fn(
469469 # attribute.
470470 if hasattr (loss_fn , "reduction" ):
471471 reduction = loss_fn .reduction # type: ignore
472- if sample_wise_grads_per_batch is None :
472+ if sample_wise_grads_per_batch :
473473 assert reduction in [
474474 "sum" ,
475475 "mean" ,
476- ], 'reduction for `loss_fn` must be "sum" or "mean"'
477- reduction_type = str (reduction )
478- elif sample_wise_grads_per_batch :
479- assert reduction in ["sum" , "mean" ], (
476+ ], (
480477 'reduction for `loss_fn` must be "sum" or "mean" when '
481- "`sample_wise_grads_per_batch` is True"
478+ "`sample_wise_grads_per_batch` is True (i.e. the default value) "
482479 )
483480 reduction_type = str (reduction )
484481 else :
@@ -490,18 +487,7 @@ def _check_loss_fn(
490487 # if we are unable to access the reduction used by `loss_fn`, we warn
491488 # the user about the assumptions we are making regarding the reduction
492489 # used by `loss_fn`
493- if sample_wise_grads_per_batch is None :
494- warnings .warn (
495- f'Since `{ loss_fn_name } ` has no "reduction" attribute, the '
496- f'implementation assumes that `{ loss_fn_name } ` is a "reduction" loss '
497- "function that reduces the per-example losses by taking their *sum*. "
498- f"If `{ loss_fn_name } ` instead reduces the per-example losses by "
499- f"taking their mean, please set the reduction attribute of "
500- f'`{ loss_fn_name } ` to "mean", i.e. '
501- f'`{ loss_fn_name } .reduction = "mean"`.'
502- )
503- reduction_type = "sum"
504- elif sample_wise_grads_per_batch :
490+ if sample_wise_grads_per_batch :
505491 warnings .warn (
506492 f"Since `{ loss_fn_name } `` has no 'reduction' attribute, and "
507493 "`sample_wise_grads_per_batch` is True, the implementation assumes "
0 commit comments