atomistic-machine-learning / schnetpack

SchNetPack - Deep Neural Networks for Atomistic Systems

Testing a trained model #641

Open · MasterLucas opened this issue 2 weeks ago

MasterLucas commented 2 weeks ago

Dear Sir/Madam,

I have been having problems loading and testing a trained model. I trained a set of models in one job and saved their checkpoints, but when I try to load and test them in another job, I get the following error:

TypeError: `Trainer.test()` requires a `LightningModule` when it hasn't been passed in a previous run

This is the code I'm using:

```python
import os

import pytorch_lightning as pl
import schnetpack as spk
import torch

# LossLogger is my custom callback and `dataset` is my data module; both are defined elsewhere.


def testing_model(dataset, seed=42, data_path="./") -> None:
    torch.manual_seed(seed)
    model_dir = "model_data"

    log_test_file = os.path.join(model_dir, "test_loss_log.txt")

    logger = pl.loggers.TensorBoardLogger(save_dir=model_dir)

    checkpoint_callback = spk.train.ModelCheckpoint(
        model_path=os.path.join(model_dir, "best_inference_model"))

    loss_logger = LossLogger(log_test_file=log_test_file)  # custom callback for logging the test loss

    trainer = pl.Trainer(callbacks=[checkpoint_callback, loss_logger],
                         logger=logger,
                         default_root_dir=model_dir,
                         accelerator="gpu" if torch.cuda.is_available() else "cpu")

    print("Testing model.")
    # This is the call that raises the TypeError above (note that no `model=` is passed).
    trainer.test(datamodule=dataset, ckpt_path=os.path.join(model_dir, "best_inference_model"))


testing_model(dataset)
```

Is this the correct way to test a trained model with SchNetPack?

Yours faithfully,

Lucas Bandeira

jnsLs commented 2 weeks ago

Dear Lucas,

Are you using the latest version (commit) of SchNetPack?

Best, Jonas

MasterLucas commented 1 week ago

Dear Jonas,

I tried to check the SchNetPack version I am using but was unable to. I created the environment a few months ago, so I don't know if a new version has been released since then.
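One way to check, assuming SchNetPack was installed into the environment that runs the jobs, is to query the installed package metadata with the standard library. `installed_version` is only an illustrative helper name; from a shell, `pip show schnetpack` should print similar information.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Run this inside the same environment that runs the training/testing jobs.
for name in ("schnetpack", "torch", "pytorch-lightning"):
    print(f"{name}: {installed_version(name)}")
```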

jnsLs commented 1 week ago

Do you remember how you installed it: from PyPI or from source?
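If that is hard to remember, a heuristic sketch: pip records a `direct_url.json` metadata file (PEP 610) for installs from a local directory, VCS URL or archive, but not for ordinary installs from an index such as PyPI. `install_origin` is a hypothetical helper, and other installers may not write this file, so treat the result as a hint only.

```python
import json
from importlib.metadata import PackageNotFoundError, distribution


def install_origin(dist_name):
    """Best-effort guess of where a distribution was installed from."""
    try:
        dist = distribution(dist_name)
    except PackageNotFoundError:
        return "not installed"
    raw = dist.read_text("direct_url.json")  # None if the file was not recorded
    if raw is None:
        return "package index (e.g. PyPI), or an installer that skips PEP 610"
    info = json.loads(raw)
    if info.get("dir_info", {}).get("editable", False):
        return f"editable install from {info['url']}"
    if "vcs_info" in info:
        return f"VCS checkout: {info['url']}"
    return f"direct URL (local source directory or archive): {info['url']}"


print(install_origin("schnetpack"))
```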

jnsLs commented 6 days ago

Hi Lucas,

Have you figured out the problem yet? If not, the problem is most likely that you are loading a `torch.nn.Module`, but the PyTorch Lightning `Trainer.test` method requires a `pytorch_lightning.LightningModule`.

Best, Jonas