Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Allow checkpointing initial model weights #17469

Open kklemon opened 1 year ago

kklemon commented 1 year ago

Description & Motivation

Sometimes it can be useful to have access to the randomly initialized weights of a model prior to training. For instance, when evaluating against downstream tasks, the untrained model can serve as a random baseline for comparison, especially since many publications show that the architectural prior of a randomly initialized model exhibits representational structures that can transfer to downstream applications without any training.

The current ModelCheckpoint implementation does not offer an option to save the initial state of a model before any optimization is performed.

Pitch

The ModelCheckpoint callback could be extended with a save_initial_weights option that controls whether the initial model is checkpointed prior to training. The flag would default to False to avoid breaking existing behaviour.
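
For illustration, usage of the proposed option might look like the following hypothetical sketch; save_initial_weights is not an existing ModelCheckpoint argument and the name is only a suggestion.

```python
# Hypothetical usage of the proposed flag; `save_initial_weights` does not
# exist in ModelCheckpoint today and the name is illustrative only.
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    save_initial_weights=True,  # proposed: checkpoint the model before the first optimizer step
)
```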

Alternatives

No response

Additional context

No response

cc @borda @awaelchli @carmocca

awaelchli commented 1 year ago

@kklemon As an intermediate solution, there is trainer.save_checkpoint which the user can call in e.g. the on_train_start() hook or similar. Not ideal, because one has to pass in the path to save to. Just wanted to mention this alternative in case it helps unblock you.
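
A minimal sketch of this workaround as a standalone callback; the class name and the "initial.ckpt" path are illustrative, not part of Lightning's API:

```python
# Minimal sketch of the suggested workaround: save a checkpoint of the
# untrained model from the on_train_start hook. Names and paths are illustrative.
from lightning.pytorch.callbacks import Callback


class SaveInitialWeights(Callback):
    def on_train_start(self, trainer, pl_module):
        # Write the untrained weights to disk before the first optimizer step.
        trainer.save_checkpoint("initial.ckpt")
```

The callback would be passed to the Trainer alongside the regular ModelCheckpoint, e.g. Trainer(callbacks=[SaveInitialWeights(), ModelCheckpoint(...)]).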

One question I have though about your motivation: If we wanted to evaluate a baseline based on a random initialization, wouldn't we just initialize the model normally and run our baseline directly, on the untrained model?

carmocca commented 1 year ago

To give some additional info, the callback is designed to not save when no steps have been taken: https://github.com/Lightning-AI/lightning/blob/d48ec08d76ec090fb0836b5da7f8f9d136f85426/src/lightning/pytorch/callbacks/model_checkpoint.py#L229

kklemon commented 1 year ago

@awaelchli

As an intermediate solution, there is trainer.save_checkpoint which the user can call in e.g. the on_train_start() hook or similar.

At the moment, I just inherit from ModelCheckpoint and trigger a checkpoint save from on_train_start.
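
Roughly like the following sketch, which assumes dirpath is already resolved when training starts; the subclass name and the "initial.ckpt" filename are illustrative:

```python
# Sketch of the ModelCheckpoint subclass described above. The filename is
# illustrative; dirpath is assumed to have been resolved during setup().
import os

from lightning.pytorch.callbacks import ModelCheckpoint


class InitialStateModelCheckpoint(ModelCheckpoint):
    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        # Save the randomly initialized weights before any optimization is performed.
        trainer.save_checkpoint(os.path.join(self.dirpath, "initial.ckpt"))
```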

One question I have though about your motivation: If we wanted to evaluate a baseline based on a random initialization, wouldn't we just initialize the model normally and run our baseline directly, on the untrained model?

I see model checkpoints, and in particular the LightningModule.load_from_checkpoint(...) mechanism, as a nice and easy way to decouple downstream applications or inference from the model implementation and initialization details. The requested feature would allow the untrained state of a model to be loaded with the usual checkpoint loading API, without having to consider model hyperparameters or initialization specifics.
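
As a sketch of that decoupled flow (MyModel is a hypothetical LightningModule subclass and the checkpoint path is illustrative):

```python
# Sketch of loading the untrained state via the usual checkpoint API.
# `MyModel` and the checkpoint path are illustrative, not from this issue.
from my_project import MyModel

# Hyperparameters stored via save_hyperparameters() are restored automatically,
# so downstream code only needs the checkpoint file.
untrained_model = MyModel.load_from_checkpoint("checkpoints/initial.ckpt")
untrained_model.eval()
```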

But I agree that this is, to some extent, a rather specific use case.