[Open] kklemon opened this issue 1 year ago
awaelchli commented:

@kklemon As an intermediate solution, there is trainer.save_checkpoint, which the user can call in, e.g., the on_train_start() hook or similar. It is not ideal because one has to pass in the path to save to. Just wanted to mention this alternative in case it helps unblock you.
One question I have about your motivation, though: if we wanted to evaluate a baseline based on a random initialization, wouldn't we just initialize the model normally and run our baseline directly on the untrained model?
To give some additional info, the callback is designed to not save when no steps have been taken: https://github.com/Lightning-AI/lightning/blob/d48ec08d76ec090fb0836b5da7f8f9d136f85426/src/lightning/pytorch/callbacks/model_checkpoint.py#L229
@awaelchli
> As an intermediate solution, there is trainer.save_checkpoint which the user can call in e.g. the on_train_start() hook or similar.
At the moment, I just inherit from ModelCheckpoint and trigger a checkpoint save from on_train_start.
> One question I have about your motivation, though: if we wanted to evaluate a baseline based on a random initialization, wouldn't we just initialize the model normally and run our baseline directly on the untrained model?
I see model checkpoints, and in particular the LightningModule.load_from_checkpoint(...) mechanism, as a nice and easy way to decouple downstream applications or inference from the model implementation and initialization details. The requested feature would allow the untrained state of a model to be loaded with the usual checkpoint-loading API, without having to consider model hyperparameters or initialization specifics.
But I agree that to some extent this is certainly a rather specific use case.
Description & Motivation
Sometimes it can be useful to have access to the randomly initialized weights of a model prior to training. For instance, when evaluating against downstream tasks, the untrained model can serve as a random baseline for comparison, especially since many publications show that the architectural prior of a randomly initialized model exhibits representational structures that can transfer to downstream applications without any training.
The current ModelCheckpoint implementation does not offer an option for saving the initial state of a model before any optimization is performed.

Pitch
The ModelCheckpoint callback could be extended by a save_initial_weights option that controls whether the initial model is checkpointed prior to training. The flag would default to False to avoid breaking existing behaviour.

Alternatives
Inheriting from ModelCheckpoint and triggering a checkpoint save manually, e.g. from on_train_start. This works, of course, but it would be nice to see this feature as part of PyTorch Lightning.

Additional context
No response
cc @borda @awaelchli @carmocca