luigibonati / mlcolvar

A unified framework for machine learning collective variables for enhanced sampling simulations

DeepTDA cannot be loaded from checkpoint #103

Closed · luigibonati closed this issue 8 months ago

luigibonati commented 9 months ago

Loading a DeepTDA CV from a checkpoint does not work:

Minimal (non)working example:

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from mlcolvar.cvs import DeepTDA

# `model` is a DeepTDA instance and `datamodule` its datamodule, built beforehand
checkpoint = ModelCheckpoint(save_top_k=1, monitor="valid_loss")

trainer = pl.Trainer(callbacks=[checkpoint], enable_checkpointing=True)
trainer.fit(model, datamodule)

best_model = DeepTDA.load_from_checkpoint(checkpoint.best_model_path)
```

which gives an error during initialization:

```
File ~/software/mambaforge/envs/mlcolvar/lib/python3.10/site-packages/mlcolvar/cvs/supervised/deeptda.py:67, in DeepTDA.__init__(self, n_states, n_cvs, target_centers, target_sigmas, layers, options, **kwargs)
     35 def __init__(
     36     self,
     37     n_states: int,
   (...)
     43     **kwargs,
     44 ):
     45     """
     46     Define Deep Targeted Discriminant Analysis (Deep-TDA) CV composed by a neural network module.
     47     By default a module standardizing the inputs is also used.
   (...)
     64         Set 'block_name' = None or False to turn off that block
     65     """
---> 67     super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)
     69     # =======   LOSS  =======
     70     self.loss_fn = TDALoss(
     71         n_states=n_states,
     72         target_centers=target_centers,
     73         target_sigmas=target_sigmas,
     74     )

TypeError: mlcolvar.cvs.cv.BaseCV.__init__() got multiple values for keyword argument 'in_features'
```

We should also check the other CVs and add regtests for this feature (as of now only the regression CV is tested, in this notebook: https://mlcolvar.readthedocs.io/en/stable/notebooks/tutorials/intro_3_loss_optim.html#Model-checkpointing). A possible starting point is sketched below.
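
A sketch of what such a regtest could look like (untested; the toy dataset built with `DictDataset`/`DictModule` and the `DeepTDA` arguments are my assumptions, to be adapted):

```python
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from mlcolvar.cvs import DeepTDA
from mlcolvar.data import DictDataset, DictModule

def test_deeptda_load_from_checkpoint(tmp_path):
    # toy two-state dataset: random inputs with integer state labels
    X = torch.randn(100, 10)
    y = torch.randint(0, 2, (100,))
    datamodule = DictModule(DictDataset({"data": X, "labels": y}), lengths=[0.8, 0.2])

    model = DeepTDA(n_states=2, n_cvs=1, target_centers=[-7, 7],
                    target_sigmas=[0.2, 0.2], layers=[10, 5, 1])

    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="valid_loss")
    trainer = pl.Trainer(max_epochs=1, callbacks=[checkpoint],
                         enable_checkpointing=True, logger=False)
    trainer.fit(model, datamodule)

    # the call that currently raises the TypeError
    loaded = DeepTDA.load_from_checkpoint(checkpoint.best_model_path)
    assert torch.allclose(loaded(X), model(X))
```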

luigibonati commented 8 months ago

@andrrizzi we looked into it. The problem is that when loading a checkpoint, in the call `super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)` the `kwargs` already contain `in_features` and `out_features`, since `load_from_checkpoint` passes back all saved hyperparameters. If we drop those keys before calling the `__init__` of the parent class, it works:

```python
# drop the duplicated keys restored from the checkpoint's hyperparameters
kwargs.pop("in_features", None)
kwargs.pop("out_features", None)
super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)
```

But what I don't like is that we would need to do this in every class that inherits from BaseCV... do you have any suggestions?
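
To make the collision concrete, here is a self-contained sketch of the mechanism (plain Python, no Lightning; the class bodies are illustrative stand-ins, not mlcolvar's actual code):

```python
# Illustrative stand-ins for BaseCV and DeepTDA, not the real classes.
class BaseCV:
    def __init__(self, in_features, out_features, **kwargs):
        self.in_features = in_features
        self.out_features = out_features

class DeepTDA(BaseCV):
    def __init__(self, layers, **kwargs):
        # the explicit values collide with the same keys arriving via **kwargs
        super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)

# load_from_checkpoint effectively re-instantiates the class with the saved
# hyperparameters, which include in_features and out_features:
saved_hparams = {"layers": [45, 24, 12, 1], "in_features": 45, "out_features": 1}
DeepTDA(**saved_hparams)  # TypeError: got multiple values for 'in_features'
```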

andrrizzi commented 8 months ago

If all the inherited CVs explicitly pass in/out_features to `BaseCV.__init__` based on some other init argument, an alternative might be to modify `BaseCV.__init__` to call `self.save_hyperparameters(ignore=['in_features', 'out_features'])`. I'm not sure, but I seem to remember that only saved hyperparameters are then restored from the checkpoint.

If only a handful are doing it, then we might add that `save_hyperparameters(ignore=...)` bit individually in their `__init__`.
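
A minimal sketch of that first option, assuming BaseCV is a plain `LightningModule` (untested; the attribute handling is illustrative):

```python
import lightning.pytorch as pl

class BaseCV(pl.LightningModule):
    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__(**kwargs)
        # Exclude in/out_features from the saved hyperparameters: subclasses
        # derive them from their own saved arguments (e.g. layers[0] and
        # layers[-1]), so they must not be stored in the checkpoint and passed
        # back to __init__ by load_from_checkpoint.
        self.save_hyperparameters(ignore=["in_features", "out_features"])
        self.in_features = in_features
        self.out_features = out_features
```

If I remember correctly, `save_hyperparameters` inspects the whole `__init__` call chain, so the subclass arguments (`layers`, `n_states`, ...) would still be saved even though the call happens in the base class.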