nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

[FEATURE]: Model Saving and Checkpointing #28

Closed: nmichlo closed this 1 year ago

nmichlo commented 2 years ago

**Is your feature request related to a problem? Please describe.**
Model saving and checkpointing is currently disabled for `experiment/run.py`. This was due to old pickling errors and the extensive use of wandb for logging; actual saved models were not needed at the time.

**Describe the solution you'd like**
Re-enable model checkpointing, and allow training to be continued from a checkpoint.
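For reference, re-enabling this would presumably lean on Lightning's built-in `ModelCheckpoint` callback. A minimal sketch; the directory, interval, and the `framework`/`datamodule` names are assumptions, not the project's actual wiring:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    # save the latest weights every epoch so training can be resumed
    checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", save_last=True, every_n_epochs=1)
    trainer = pl.Trainer(max_epochs=100, callbacks=[checkpoint_callback])
    # resuming later (Lightning >= 1.5):
    # trainer.fit(framework, datamodule=datamodule, ckpt_path="checkpoints/last.ckpt")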

nmichlo commented 2 years ago

Fixed optimiser pickling in 82fa093adaac2d9d0fb9f56168af2c0b5a073872

TODO: add checkpointing

meffmadd commented 1 year ago

Hi 👋 First of all thanks for the fantastic framework!

I started working on this issue since I need model saving for my visualization project. Saving a model seems very simple; loading a checkpoint, however, appears to be a bit more complicated as far as I understand. I found that a saved checkpoint can be loaded, but when run with hydra the `model` parameter for e.g. `BetaVae` is an `omegaconf.dictconfig.DictConfig` instead of an actual `AutoEncoder` object. I think this is because in `run.py` the framework is instantiated via hydra, and maybe some magic is happening there. It works fine if I save and load the models with the standard Python API as in your example. Will investigate...
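A minimal sketch of how the symptom shows up (the checkpoint path here is assumed):

    from disent.frameworks.vae import BetaVae

    # load a checkpoint produced by an experiment/run.py (hydra) run
    vae = BetaVae.load_from_checkpoint("path/to/checkpoint.ckpt")
    print(type(vae._model))
    # when trained via hydra: <class 'omegaconf.dictconfig.DictConfig'>
    # expected: a disent AutoEncoder (a torch.nn.Module)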

meffmadd commented 1 year ago

Ah, it seems like someone hit the same roadblock before: https://github.com/Lightning-AI/lightning/discussions/6144

nmichlo commented 1 year ago

Hi @meffmadd, really glad you are finding it useful!

I managed to get by with just the wandb results a while back, so I never got around to fixing this.

I'll investigate based on the information you provided and get back to you. Thank you for that!

nmichlo commented 1 year ago

You noted the object is an OmegaConf instance. I don't think it would break too much if we switch that over to a dictionary and recursively convert all the values; there is a built-in function for this.
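The built-in in question is presumably `OmegaConf.to_container`, which recursively converts a `DictConfig` (and all nested values) into plain dicts and lists; a small sketch with made-up config keys:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"model": {"z_size": 9, "beta": 4}})
    plain = OmegaConf.to_container(cfg, resolve=True)  # recursive conversion
    assert isinstance(plain, dict) and isinstance(plain["model"], dict)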

(As for PyTorch Lightning, I have become a bit disillusioned with it, as it has placed certain constraints on the framework that were never intended.)

To get to my question, how important is API stability for you right now?

meffmadd commented 1 year ago

API stability is not a major concern for me. I only planned on using the hydra configs, but maybe using the Python API is more useful for me, since it removes a layer of complexity. Yeah, frameworks are nice as long as their magic works and there is documentation 😅

For model saving when running with hydra, I think I found a simple workaround in the `Vae` class:

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # store the actual nn.Module in the checkpoint's hyper-parameters,
        # overwriting the DictConfig hydra put there, so loading restores it
        checkpoint["hyper_parameters"]["model"] = self._model
        return super().on_save_checkpoint(checkpoint)

This manually sets the model in the checkpoint and also works for loading! I can implement simple model saving and open a pull request with this if you like, so no API changes are necessary.
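With that override in place, standard Lightning loading should hand back the real module rather than a `DictConfig`; a sketch (checkpoint path assumed):

    from torch import nn
    from disent.frameworks.vae import BetaVae

    vae = BetaVae.load_from_checkpoint("path/to/last.ckpt")
    assert isinstance(vae._model, nn.Module)  # the actual AutoEncoder, not a DictConfig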

And a quick unrelated question: could you tell me a beta-VAE config that works well with dSprites? I use the BCE loss, but this somehow creates NaN values in the encoder output. If I use the norm_conv64 framework it works well, but this seems not to be standard as per your warning.

meffmadd commented 1 year ago

Created pull request #37

nmichlo commented 1 year ago

Thank you so much for the PR! I left a few comments about tests. We just need to make sure to add the new keys to the configs and (possibly) update the tests to cover the checkpointing.


As for your question: that should not be happening with the BCE loss. It may be due to largely unrelated things, like the strength of the regularization term or the learning rate.

EDIT: On this note, another reason for the BCE loss failing could be the dataset normalization. I am not sure whether that plays a part, as the output is also normalized; there may be a logic/precision error there.
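To illustrate the normalization hypothesis (a sketch, not disent's actual loss code): BCE assumes targets in [0, 1], so mean/std-normalized pixels fall out of range and make the loss mathematically invalid, which can destabilize training:

    import torch
    import torch.nn.functional as F

    x = torch.rand(4, 1, 64, 64)   # raw pixels in [0, 1] -- valid BCE targets
    x_norm = (x - 0.5) / 0.5       # normalization shifts values into [-1, 1]
    logits = torch.randn_like(x)

    F.binary_cross_entropy_with_logits(logits, x)       # well-defined
    F.binary_cross_entropy_with_logits(logits, x_norm)  # targets out of range: loss is meaningless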

meffmadd commented 1 year ago

I will fix the configs now!

Thanks for your answer! I will try MSE with a higher learning rate, because when I tested the beta-VAE with MSE it did not converge at all.

nmichlo commented 1 year ago

Then I would possibly also try a lower beta value.

nmichlo commented 1 year ago

Closing this with your changes from #37.

Thank you for contributing!

Now released in v0.7.0.