Skip to content
7 changes: 6 additions & 1 deletion torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
import torch


def get_args():
# fmt: off
parser = argparse.ArgumentParser(description='Training')
Expand Down Expand Up @@ -179,8 +180,12 @@ def main():
trainer.fit(model, data)

# run test set after completing the fit
model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
model = LNNP.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
hparams_file=f"{args.log_dir}/input.yaml",
)
trainer = pl.Trainer(logger=_logger)

trainer.test(model, data)


Expand Down