graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Add improved earlystopping #632

Closed AMHermansen closed 10 months ago

AMHermansen commented 10 months ago

Adds an early-stopping callback that also saves and loads the best weights along with the model config. For the design I've added an outdir to save the state_dict/model_config to, so I'm not sure whether it can replace the default early stopping.

I couldn't come up with a good descriptive name for the callback that wasn't overly verbose, but I'm more than open to suggestions.
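
For reference, here is a minimal sketch of what such a callback could look like. The class name, the save_dir argument, and the file layout are assumptions made for illustration, not the actual implementation in this PR; only the pytorch_lightning hooks and attributes used (on_validation_end, on_fit_end, wait_count) are standard.

# Illustrative sketch only; names and details are assumptions, not the PR's code.
import os
from typing import Any, Optional

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class EarlyStoppingWithCheckpoint(EarlyStopping):
    """Early stopping that also saves and restores the best weights (sketch)."""

    def __init__(self, save_dir: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._save_dir = save_dir
        self._best_path: Optional[str] = None

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if trainer.sanity_checking:
            return
        super().on_validation_end(trainer, pl_module)
        # `wait_count` is reset to 0 by the parent class whenever the monitored
        # metric improves, so this is a cheap "did we just improve?" check.
        if self.wait_count == 0:
            os.makedirs(self._save_dir, exist_ok=True)
            self._best_path = os.path.join(self._save_dir, "best_state_dict.pth")
            torch.save(pl_module.state_dict(), self._best_path)

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore the best weights once training has finished.
        if self._best_path is not None:
            pl_module.load_state_dict(torch.load(self._best_path))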

RasmusOrsoe commented 10 months ago

Hey @AMHermansen! Thanks for giving this a crack.

After spending some time thinking about this, I realized that we can solve this without adding a custom callback class.

In StandardModel.fit we process the callbacks argument, and we have functionality that creates the callbacks in case none is given. This part of the code base is old, and when I reviewed it today I realized it's unnecessarily complicated: the complexity arises from trying to infer whether early stopping was given and whether a validation loader is present. I think this is an ideal place to add a ModelCheckpoint callback on behalf of the user, while giving this part of the code a little love.

We introduce minor changes to .fit like so:

def fit(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        *,
        max_epochs: int = 10,
        early_stopping_patience: int = 5,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        ckpt_path: Optional[str] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> None:
    """Fit `StandardModel` using `pytorch_lightning.Trainer`."""
    # Checks
    if callbacks is None:
        # We create the bare-minimum callbacks for you.
        callbacks = self._create_default_callbacks(
            val_dataloader=val_dataloader,
            early_stopping_patience=early_stopping_patience,
        )
    else:
        # You are on your own!
        # We just add the progressbar if you forgot it.
        if not any(isinstance(callback, ProgressBar) for callback in callbacks):
            callbacks.append(ProgressBar())

    has_early_stopping = self._has_early_stopping(callbacks)
    has_model_checkpoint = self._has_model_checkpoint(callbacks)

    self.train(mode=True)
    trainer = self._construct_trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        gradient_clip_val=gradient_clip_val,
        distribution_strategy=distribution_strategy,
        **trainer_kwargs,
    )

    try:
        trainer.fit(
            self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
        )
    except KeyboardInterrupt:
        self.warning("[ctrl+c] Exiting gracefully.")
        pass

    # Load weights from best-fit model after training if possible
    if has_early_stopping and has_model_checkpoint:
        checkpoint_callback = None
        for callback in callbacks:
            if isinstance(callback, ModelCheckpoint):
                checkpoint_callback = callback
        self.load_state_dict(
            torch.load(checkpoint_callback.best_model_path)["state_dict"]
        )
    else:
        # Raise an informative warning instead of silently skipping.
        self.warning(
            "Best-fit weights were not loaded after training, since this "
            "requires both an EarlyStopping and a ModelCheckpoint callback."
        )

The idea here is to toggle between two cases: either the user gave no callbacks, or the user did. If none are given, we infer bare-minimum callbacks based on the other arguments. Notice that this introduces the new argument early_stopping_patience. If callbacks are given, the user is "on their own". This allows us to simplify the code a little bit.

After the training is finished, we check has_early_stopping and has_model_checkpoint and load the best-fit model parameters if possible. This would result in the "expected" behavior of model.fit when a validation loader is given, and would allow us to shave off a few lines of code in the example scripts because specifying callbacks is not needed. @AMHermansen what do you think?
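
With that behavior in place, a training script could look roughly like the sketch below; it assumes model, train_dataloader, and val_dataloader are already constructed, and the argument values are purely illustrative.

# Default callbacks are created on the user's behalf, so an example script
# needs no explicit callback wiring.
model.fit(
    train_dataloader,
    val_dataloader,
    max_epochs=50,
    early_stopping_patience=5,
    gpus=[0],
)
# Best-fit weights are loaded back automatically after training, because
# EarlyStopping and ModelCheckpoint were added by default.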

The default callback function could then be simplified to:

def _create_default_callbacks(
    self,
    val_dataloader: DataLoader,
    early_stopping_patience: int,
) -> List:
    """Create default callbacks.

    Used in cases where no callbacks are specified by the user in `.fit`.
    """
    callbacks = [ProgressBar()]
    if val_dataloader is not None:
        # Add early stopping
        callbacks.append(
            EarlyStopping(
                monitor="val_loss",
                patience=early_stopping_patience,
            )
        )
        # Add model checkpoint
        callbacks.append(
            ModelCheckpoint(
                save_top_k=1,
                monitor="val_loss",
                mode="min",
                filename=f"{self._gnn.__class__.__name__}"
                + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
            )
        )
        self.info(
            f"EarlyStopping has been added with a patience of "
            f"{early_stopping_patience}."
        )
    return callbacks
AMHermansen commented 10 months ago

Hello @RasmusOrsoe, thank you for your input.

I think your solution also works. I'm personally not a big fan of "enforcing" non-essential default callbacks, but I understand if you would prefer to have default callbacks to reduce boilerplate code in the training scripts.

If you end up going with a mixture of EarlyStopping and ModelCheckpoint as default callbacks, then you should make sure the model logs the checkpoint file path used.
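
One way to do that, sketched here on top of the proposed .fit changes above, would be to log the best checkpoint path right after the weights are restored; the exact message wording is just an illustration.

# Sketch: inside .fit, right after loading the best-fit state dict,
# report which checkpoint the weights came from.
self.load_state_dict(
    torch.load(checkpoint_callback.best_model_path)["state_dict"]
)
self.info(
    f"Best-fit weights loaded from checkpoint: "
    f"{checkpoint_callback.best_model_path}"
)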

AMHermansen commented 10 months ago

> @AMHermansen as mentioned in the call today, I think we should offer this callback as a complement to the changes I proposed earlier. I left a few minor comments. I'll introduce my proposed changes in a separate PR.
>
> Can you confirm that the callback works as intended?

I can confirm that it works as intended. I tried a slightly modified version that also logged verbosely whenever it saved something to disk, and it was doing so correctly (i.e. only at epochs where the validation loss improved).