Skip to content
5 changes: 4 additions & 1 deletion torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def main():
trainer.fit(model, data, ckpt_path=None if args.reset_trainer else args.load_model)

# run test set after completing the fit
model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
model = LNNP.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
hparams_file=f"{args.log_dir}/input.yaml",
)
trainer = pl.Trainer(
logger=_logger,
inference_mode=False,
Expand Down
Loading