Skip to content

Commit 8374e96

Browse files
authored
Merge pull request #345 from AntonioMirarchi/update_csv_logger
Add backup for metrics.csv
2 parents f013b80 + e1bc7ef commit 8374e96

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

torchmdnet/scripts/train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from torchmdnet.models import output_modules
2222
from torchmdnet.models.model import create_prior_models
2323
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
24-
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
24+
from torchmdnet.utils import (
25+
LoadFromFile,
26+
LoadFromCheckpoint,
27+
save_argparse,
28+
number,
29+
check_logs,
30+
)
2531
from lightning_utilities.core.rank_zero import rank_zero_warn
2632

2733

@@ -219,9 +225,11 @@ def main():
219225
args.early_stopping_monitor, patience=args.early_stopping_patience
220226
)
221227
callbacks.append(early_stopping)
222-
228+
229+
check_logs(args.log_dir)
223230
csv_logger = CSVLogger(args.log_dir, name="", version="")
224231
_logger = [csv_logger]
232+
225233
if args.wandb_use:
226234
wandb_logger = WandbLogger(
227235
project=args.wandb_project,

torchmdnet/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,13 @@ def wrapped_init(self, *args, **kwargs):
397397

398398
cls.__init__ = wrapped_init
399399
return cls
400+
401+
def check_logs(log_dir):
402+
import os
403+
import time
404+
metr_file_path = os.path.join(log_dir, 'metrics.csv')
405+
if os.path.exists(metr_file_path):
406+
# we make a backup of the metrics file (rename)
407+
bckp_date = f'{time.strftime("%Y%m%d")}-{time.strftime("%H%M%S")}'
408+
os.rename(metr_file_path, metr_file_path.replace(".csv", f"_{bckp_date}.csv"))
409+
return

0 commit comments

Comments
 (0)