Skip to content

Commit 4e7cea8

Browse files
avoid error when multiple gpus are used
1 parent c5a05a5 commit 4e7cea8

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

torchmdnet/scripts/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ def main():
184184
auto_insert_metric_name=False,
185185
)
186186
early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)
187-
187+
188+
check_logs(args.log_dir)
188189
csv_logger = CSVLogger(args.log_dir, name="", version="")
189-
check_logs(csv_logger)
190190
_logger = [csv_logger]
191191

192192
if args.wandb_use:

torchmdnet/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,10 @@ def wrapped_init(self, *args, **kwargs):
398398
cls.__init__ = wrapped_init
399399
return cls
400400

401-
def check_logs(csvlogger):
401+
def check_logs(log_dir):
402402
import os
403403
import time
404-
metr_file_path = csvlogger.experiment.metrics_file_path
404+
metr_file_path = os.path.join(log_dir, 'metrics.csv')
405405
if os.path.exists(metr_file_path):
406406
# we make a backup of the metrics file (rename)
407407
bckp_date = f'{time.strftime("%Y%m%d")}-{time.strftime("%H%M%S")}'

0 commit comments

Comments
 (0)