AMHermansen closed this 10 months ago
Hey @AMHermansen! Thanks for giving this a crack.
After spending some time thinking about this, I realized that we can solve this without adding a custom callback class. In StandardModel.fit we process the callbacks argument, and we have functionality that creates the callbacks in case none is given. This part of the code base is old, and when I reviewed it today I realized it's unnecessarily complicated: the complexity arises from trying to infer whether early stopping was given and whether a validation loader is present. I think this is an ideal place to add a ModelCheckpoint callback on behalf of the user while giving this part of the code a little love.

We introduce minor changes to .fit like so:
def fit(
    self,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    *,
    max_epochs: int = 10,
    early_stopping_patience: int = 5,
    gpus: Optional[Union[List[int], int]] = None,
    callbacks: Optional[List[Callback]] = None,
    ckpt_path: Optional[str] = None,
    logger: Optional[LightningLogger] = None,
    log_every_n_steps: int = 1,
    gradient_clip_val: Optional[float] = None,
    distribution_strategy: Optional[str] = "ddp",
    **trainer_kwargs: Any,
) -> None:
    """Fit `StandardModel` using `pytorch_lightning.Trainer`."""
    # Checks
    if callbacks is None:
        # We create the bare-minimum callbacks for you.
        callbacks = self._create_default_callbacks(
            val_dataloader=val_dataloader,
            early_stopping_patience=early_stopping_patience,
        )
    else:
        # You are on your own!
        # We just add the progress bar if you forgot it.
        has_progress_bar = any(
            isinstance(callback, ProgressBar) for callback in callbacks
        )
        if not has_progress_bar:
            callbacks.append(ProgressBar())

    has_early_stopping = self._has_early_stopping(callbacks)
    has_model_checkpoint = self._has_model_checkpoint(callbacks)

    self.train(mode=True)
    trainer = self._construct_trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        gradient_clip_val=gradient_clip_val,
        distribution_strategy=distribution_strategy,
        **trainer_kwargs,
    )

    try:
        trainer.fit(
            self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
        )
    except KeyboardInterrupt:
        self.warning("[ctrl+c] Exiting gracefully.")

    # Load weights from best-fit model after training if possible
    if has_early_stopping and has_model_checkpoint:
        for callback in callbacks:
            if isinstance(callback, ModelCheckpoint):
                checkpoint_callback = callback
        self.load_state_dict(
            torch.load(checkpoint_callback.best_model_path)["state_dict"]
        )
    else:
        # Raise an informative warning that best-fit weights were not loaded.
        self.warning(
            "Weights from the best-fit epoch were not loaded automatically; "
            "this requires both an EarlyStopping and a ModelCheckpoint callback."
        )
The idea here is to toggle between two cases: either the user gave no callbacks, or they did. If none are given, we infer bare-minimum callbacks based on the other arguments; notice this introduces the new argument early_stopping_patience. If callbacks are given, the user is "on their own". This allows us to simplify the code a little bit.

After training is finished, we check has_early_stopping and has_model_checkpoint and load the best-fit model parameters if possible. This would result in the "expected" behavior of model.fit when a validation loader is given, and it would allow us to shave off a few lines of code in the example scripts, because specifying callbacks is no longer needed; see the sketch below. @AMHermansen what do you think?
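For illustration, a training script could then reduce to roughly the following call (a sketch only; the model and dataloader names are placeholders, not taken from the actual example scripts):

# Sketch: no callbacks are passed, so the defaults (ProgressBar,
# EarlyStopping, ModelCheckpoint) are created on the user's behalf,
# and the best-fit weights are loaded back after training.
model.fit(
    train_dataloader,
    val_dataloader,
    max_epochs=50,
    early_stopping_patience=5,
    gpus=[0],
)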
The default callback function could then be simplified to:
def _create_default_callbacks(
    self,
    val_dataloader: Optional[DataLoader],
    early_stopping_patience: int,
) -> List[Callback]:
    """Create default callbacks.

    Used in cases where no callbacks are specified by the user in `.fit`.
    """
    callbacks = [ProgressBar()]
    if val_dataloader is not None:
        # Add Early Stopping
        callbacks.append(
            EarlyStopping(
                monitor="val_loss",
                patience=early_stopping_patience,
            )
        )
        # Add Model Checkpoint
        callbacks.append(
            ModelCheckpoint(
                save_top_k=1,
                monitor="val_loss",
                mode="min",
                filename=f"{self._gnn.__class__.__name__}"
                + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
            )
        )
        self.info(
            f"EarlyStopping has been added with a patience of "
            f"{early_stopping_patience}."
        )
    return callbacks
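The two helpers referenced in .fit are not shown above; a minimal sketch of what they could look like (the method names are taken from the .fit snippet, the bodies are my assumption):

def _has_early_stopping(self, callbacks: List[Callback]) -> bool:
    """Check whether an EarlyStopping callback is present."""
    return any(isinstance(callback, EarlyStopping) for callback in callbacks)

def _has_model_checkpoint(self, callbacks: List[Callback]) -> bool:
    """Check whether a ModelCheckpoint callback is present."""
    return any(isinstance(callback, ModelCheckpoint) for callback in callbacks)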
Hello @RasmusOrsoe Thank you for your input.
I think your solution also works. I'm personally not a big fan of "enforcing" non-essential default callbacks, but I understand if you would prefer to have them in order to reduce boilerplate-y code in the training scripts.

If you do end up adding a mixture of EarlyStopping and ModelCheckpoint as default callbacks, you should make sure the model logs the checkpoint file-path used (see the sketch below).
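For example, the weight-loading branch in .fit could report the path it loads from, roughly like this (a sketch based on the snippet above, not an agreed-upon implementation):

# Inside the weight-loading branch of .fit, once the ModelCheckpoint
# callback has been found:
self.info(
    f"Loading best-fit weights from {checkpoint_callback.best_model_path}"
)
self.load_state_dict(
    torch.load(checkpoint_callback.best_model_path)["state_dict"]
)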
@AMHermansen as mentioned in the call today, I think we should offer this callback as a complement to the changes I proposed earlier. I left a few minor comments. I'll introduce my proposed changes in a separate PR.
Can you confirm that the callback works as intended?
I can confirm that it works as intended. I tried a slightly modified version of it that also logged, in a verbose manner, whenever it saved something to disk, and it was doing so correctly (i.e. only at epochs where it achieved an improved validation loss).
Adds an early-stopping callback which also saves and loads the best weights, along with the model config. For the design I've added an outdir to save the state_dict/model_config to, so I'm not sure if it can replace the default early stopping. I couldn't come up with a good descriptive name for the callback that wasn't overly verbose, but I'm more than open to suggestions.
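A rough sketch of the idea described above (my own reconstruction, not the PR's actual code; the class name, the outdir argument, the file names, and the save_config helper are assumptions):

import os
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping
import torch


class EarlyStoppingWithCheckpointing(EarlyStopping):  # name is a placeholder
    """Early stopping that also saves/loads the best weights and model config."""

    def __init__(self, outdir: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self._outdir = outdir

    def on_validation_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        previous_best = self.best_score
        super().on_validation_end(trainer, pl_module)
        # Save weights (and config) only when the monitored score improves.
        if self.best_score != previous_best:
            torch.save(
                pl_module.state_dict(),
                os.path.join(self._outdir, "best_state_dict.pth"),
            )
            # Assumes the model exposes a `save_config` helper, as GraphNeT
            # models do via their config system; adjust if the API differs.
            pl_module.save_config(os.path.join(self._outdir, "model_config.yml"))

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Load the best weights back into the model after training, if any
        # checkpoint was written during fitting.
        path = os.path.join(self._outdir, "best_state_dict.pth")
        if os.path.exists(path):
            pl_module.load_state_dict(torch.load(path))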