Skip to content

Commit 3ab8ccb

Browse files
add backup for metrics.csv
1 parent f6c0c16 commit 3ab8ccb

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

torchmdnet/scripts/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torchmdnet.models import output_modules
2222
from torchmdnet.models.model import create_prior_models
2323
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
24-
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
24+
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number, check_logs
2525
from lightning_utilities.core.rank_zero import rank_zero_warn
2626

2727

@@ -178,8 +178,9 @@ def main():
178178
auto_insert_metric_name=False,
179179
)
180180
early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)
181-
182181
csv_logger = CSVLogger(args.log_dir, name="", version="")
182+
check_logs(csv_logger)
183+
183184
_logger = [csv_logger]
184185
if args.wandb_use:
185186
wandb_logger = WandbLogger(

torchmdnet/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,13 @@ def wrapped_init(self, *args, **kwargs):
397397

398398
cls.__init__ = wrapped_init
399399
return cls
400+
401+
def check_logs(csvlogger):
402+
import os
403+
import time
404+
metr_file_path = csvlogger.experiment.metrics_file_path
405+
if os.path.exists(metr_file_path):
406+
# we make a backup of the metrics file (rename)
407+
bckp_date = f'{time.strftime("%Y%m%d")}-{time.strftime("%H%M%S")}'
408+
os.rename(metr_file_path, metr_file_path.replace(".csv", f"_{bckp_date}.csv"))
409+
return

0 commit comments

Comments
 (0)