diff --git a/train.py b/train.py
index 131260dca4..9c180acdd8 100755
--- a/train.py
+++ b/train.py
@@ -63,6 +63,12 @@
 except ImportError as e:
     has_functorch = False
 
+try:
+    from sklearn.metrics import precision_score, recall_score, f1_score
+    has_sklearn = True
+except ImportError:
+    has_sklearn = False
+
 has_compile = hasattr(torch, 'compile')
 
 
@@ -400,6 +406,10 @@
                    help='wandb tags')
 group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
                    help='If resuming a run, the id of the run in wandb')
+group.add_argument('--metrics-avg', type=str, default=None,
+                   choices=['micro', 'macro', 'weighted'],
+                   help='Enable precision, recall, and F1-score calculation with the given averaging method. '
+                        'Requires scikit-learn. (default: None)')
 
 # NaFlex scheduled loader arguments
 group.add_argument('--naflex-loader', action='store_true', default=False,
@@ -1318,6 +1328,15 @@ def validate(
     top1_m = utils.AverageMeter()
     top5_m = utils.AverageMeter()
 
+    if args.metrics_avg:
+        if not has_sklearn:
+            _logger.warning(
+                "scikit-learn is not installed; disabling extra metrics. Install it with 'pip install scikit-learn'.")
+            args.metrics_avg = None
+        else:
+            all_preds = []
+            all_targets = []
+
     model.eval()
 
     end = time.time()
@@ -1345,6 +1364,11 @@
             loss = loss_fn(output, target)
             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
 
+            if args.metrics_avg:
+                # Accumulate per-batch predictions/targets on CPU for epoch-level metrics
+                all_preds.append(torch.argmax(output, dim=1).cpu())
+                all_targets.append(target.cpu())
+
             if args.distributed:
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                 acc1 = utils.reduce_tensor(acc1, args.world_size)
@@ -1374,7 +1398,35 @@
                 f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
             )
 
     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+
+    if args.metrics_avg:
+        all_preds = torch.cat(all_preds)
+        all_targets = torch.cat(all_targets)
+
+        if args.distributed:
+            # Gather predictions/targets from all processes. NCCL cannot gather CPU
+            # tensors, so move them to the active device first; this assumes every
+            # rank evaluated the same number of samples (the distributed sampler pads).
+            all_preds = all_preds.to(device)
+            all_targets = all_targets.to(device)
+            pred_list = [torch.zeros_like(all_preds) for _ in range(args.world_size)]
+            target_list = [torch.zeros_like(all_targets) for _ in range(args.world_size)]
+            torch.distributed.all_gather(pred_list, all_preds)
+            torch.distributed.all_gather(target_list, all_targets)
+
+            if utils.is_primary(args):
+                all_preds = torch.cat(pred_list).cpu()
+                all_targets = torch.cat(target_list).cpu()
+
+        if utils.is_primary(args):
+            precision = precision_score(all_targets.numpy(), all_preds.numpy(), average=args.metrics_avg, zero_division=0)
+            recall = recall_score(all_targets.numpy(), all_preds.numpy(), average=args.metrics_avg, zero_division=0)
+            f1 = f1_score(all_targets.numpy(), all_preds.numpy(), average=args.metrics_avg, zero_division=0)
+
+            metrics[f'{args.metrics_avg}_precision'] = round(precision, 4)
+            metrics[f'{args.metrics_avg}_recall'] = round(recall, 4)
+            metrics[f'{args.metrics_avg}_f1_score'] = round(f1, 4)
 
     return metrics
 
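
For reference, a hypothetical invocation exercising the new flag; the dataset path and model name below are placeholders, not taken from this patch:

    python train.py /path/to/imagenet --model resnet50 --metrics-avg macro

With `--metrics-avg macro`, validate() adds `macro_precision`, `macro_recall`, and `macro_f1_score` to the returned metrics alongside the usual loss/top1/top5 (on the primary process in distributed runs).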
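
As a quick refresher on the three `choices`, here is a minimal standalone sketch of sklearn's averaging modes; the toy labels are illustrative only and not part of this patch:

    # Illustration only: how sklearn's averaging modes differ.
    from sklearn.metrics import f1_score

    y_true = [0, 0, 1, 2]
    y_pred = [0, 1, 1, 2]

    # micro:    F1 from global TP/FP/FN counts (equals accuracy for single-label tasks)
    # macro:    unweighted mean of per-class F1 scores
    # weighted: per-class F1 weighted by each class's support
    for avg in ('micro', 'macro', 'weighted'):
        print(avg, f1_score(y_true, y_pred, average=avg, zero_division=0))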