skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

Activating (deactivating) callbacks at specific epochs or milestones and SequentialLR #1049

Open AdamCoxson opened 4 months ago

AdamCoxson commented 4 months ago

Hi,

I'm using ReduceLROnPlateau and EarlyStopping with my own custom Concordance Correlation monitor metric (RhoC). I want my networks to train with .fit with the callbacks inactive for the first 50 or 100 epochs, and only activate them after that. As far as I'm aware, there are no arguments to do this, so I created my own version of the lr_scheduler module and modified the LRScheduler(Callback) class to take an argument called epoch_start.

I've shown the modified class functions at the bottom (__init__, kwargs, and on_epoch_end).

This does the job for me for now, and I can do a similar thing to modify EarlyStopping. I just wanted to check whether there is already a way to do this, or a workaround that doesn't require modifying the source code. I also looked into SequentialLR and could apply it in a similar way as in this post, with ConstantLR for the first 50 epochs. That would work for plain PyTorch if I were manually coding my own fit function, but I'm unsure how to integrate SequentialLR with skorch's fit and callbacks system.
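
To illustrate what I mean, the plain-PyTorch version might look roughly like this (untested sketch; the model, optimizer, and StepLR second stage are just placeholders for whatever the real setup would be):

    import torch
    from torch import nn, optim
    from torch.optim.lr_scheduler import ConstantLR, StepLR, SequentialLR

    # Placeholder model and optimizer, for illustration only.
    model = nn.Linear(10, 1)
    opt = optim.SGD(model.parameters(), lr=0.1)

    # Keep the lr unchanged (factor=1.0) for the first 50 epochs, then hand
    # over to a second scheduler at the milestone.
    scheduler = SequentialLR(
        opt,
        schedulers=[
            ConstantLR(opt, factor=1.0, total_iters=50),
            StepLR(opt, step_size=10, gamma=0.25),
        ],
        milestones=[50],
    )

    for epoch in range(100):
        # ... training loop for this epoch would go here ...
        opt.step()        # stands in for the per-batch optimizer updates
        scheduler.step()  # advance SequentialLR once per epoch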

So my questions are:

1. Other than modifying the source code, how can I add an activation delay to callbacks based on the epoch number? Is there a way that already exists with skorch's .fit function?
2. How could I implement SequentialLR, or an equivalent chain of learning rate schedulers, via callbacks?

This is more out of interest, as modifying the source code works for me. Any pointers, let me know :)

    def __init__(self,
                 policy='WarmRestartLR',
                 monitor='train_loss',
                 event_name="event_lr",
                 step_every='epoch',
                 epoch_start=1,
                 **kwargs):
        self.policy = policy
        self.monitor = monitor
        self.event_name = event_name
        self.step_every = step_every
        self.epoch_start = epoch_start
        vars(self).update(kwargs)

    @property
    def kwargs(self):
        # These are the parameters that are passed to the
        # scheduler. Parameters that don't belong there must be
        # excluded.
        excluded = ('policy', 'monitor', 'event_name', 'step_every', 'epoch_start')
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs

    def on_epoch_end(self, net, **kwargs):
        if self.step_every != 'epoch':
            return
        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
            if callable(self.monitor):
                score = self.monitor(net)
            else:
                try:
                    score = net.history[-1, self.monitor]
                except KeyError as e:
                    raise ValueError(
                        f"'{self.monitor}' was not found in history. A "
                        f"Scoring callback with name='{self.monitor}' "
                        "should be placed before the LRScheduler callback"
                    ) from e
            n_epoch = len(net.history)
            if n_epoch <= self.epoch_start:
                print("Not starting lr scheduler yet")
                return
            self._step(net, self.lr_scheduler_, score=score)
            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
        else:
            if (
                    (self.event_name is not None)
                    and hasattr(self.lr_scheduler_, "get_last_lr")
            ):
                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
            self._step(net, self.lr_scheduler_)

My callbacks are defined like this:

    r_score = EpochScoring(scoring=rhoc_score, lower_is_better=False, name='valid_rhoc')
    r_early_stop = EarlyStopping(monitor='valid_rhoc', patience=30, lower_is_better=False, threshold_mode='abs', threshold=1e-4, load_best=True)
    lr_plateau = LRScheduler(epoch_start=50, monitor='valid_loss', policy='ReduceLROnPlateau', factor=0.25, patience=10, threshold=1e-3, verbose=True)
    # scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
    callbacks = [r_score, r_early_stop, lr_plateau]
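
For reference, a rough sketch of how these get passed to the estimator (MyModule, X, and y are placeholders for my actual module and data):

    from skorch import NeuralNetRegressor

    # Sketch only: MyModule, X, and y stand in for the real module and data.
    net = NeuralNetRegressor(
        MyModule,
        max_epochs=300,
        callbacks=callbacks,
    )
    net.fit(X, y)  # the callbacks above run at the end of every epoch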
BenjaminBossan commented 4 months ago

In general, I think your approach is valid, and making modifications like this to skorch classes is a perfectly fine way to go.

I wonder if the on_epoch_end method could not be simplified by using an early return like this:

    def on_epoch_end(self, net, **kwargs):
        n_epoch = len(net.history)
        if n_epoch <= self.epoch_start:
            print("Not starting lr scheduler yet")
            return
        return super().on_epoch_end(net, **kwargs)
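
Fleshing that out, a small subclass should avoid touching the source entirely. A rough, untested sketch (DelayedLRScheduler is just an illustrative name; the extra parameter also needs to be kept out of the kwargs property so it isn't forwarded to the torch scheduler):

    from skorch.callbacks import LRScheduler

    class DelayedLRScheduler(LRScheduler):
        """LRScheduler that stays inactive for the first `epoch_start` epochs."""

        def __init__(self, *args, epoch_start=50, **kwargs):
            super().__init__(*args, **kwargs)
            self.epoch_start = epoch_start

        @property
        def kwargs(self):
            # Keep epoch_start out of the kwargs that get forwarded to the
            # underlying torch scheduler (e.g. ReduceLROnPlateau).
            kw = super().kwargs
            kw.pop('epoch_start', None)
            return kw

        def on_epoch_end(self, net, **kwargs):
            if len(net.history) <= self.epoch_start:
                # Scheduler stays inactive until epoch_start is reached.
                return
            return super().on_epoch_end(net, **kwargs)

It could then be used like your lr_plateau above, e.g. DelayedLRScheduler(epoch_start=50, policy='ReduceLROnPlateau', monitor='valid_loss', factor=0.25, patience=10).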

Regarding SequentialLR, I haven't checked in detail, but wouldn't something like this work?

lr_scheduler_seq = LRScheduler(policy=SequentialLR, ...)  # add more args to SequentialLR here

Check out the docs here.