AdamCoxson opened 4 months ago
In general, I think your approach is valid, and making such modifications to skorch classes is a perfectly good approach.
I wonder if the `on_epoch_end` method could be simplified by using an early return, like this:
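```python
def on_epoch_end(self, net, **kwargs):
    # net.history holds one entry per epoch, including the current one
    n_epoch = len(net.history)
    if n_epoch <= self.epoch_start:
        print("Not starting lr scheduler yet")
        return
    return super().on_epoch_end(net, **kwargs)
```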
Regarding `SequentialLR`, I haven't checked in detail, but wouldn't something like this work?
```python
lr_scheduler_seq = LRScheduler(policy=SequentialLR, ...)  # add more args to SequentialLR here
```
Check out the docs here.
Hi,
I'm using `ReduceLROnPlateau` and `EarlyStopping`, monitoring my own custom concordance correlation coefficient metric (RhoC). I want my networks to train with `.fit` with these callbacks inactive for the first 50 or 100 epochs, and only then activate. As far as I'm aware, there are no arguments for this. So I created my own version of the `lr_scheduler` module and modified the `LRScheduler(Callback)` class to take an argument called `epoch_start`. I've shown the modified methods (`def __init__`, `def kwargs`, `def _on_epoch_end`) at the bottom. This does the job for me for now, and I can do something similar to modify `EarlyStopping`. I just wanted to check whether there is already a way to do this, or a workaround that doesn't require modifying the source code.
I looked into `SequentialLR` and could apply it in a similar way as in this post, just with `ConstantLR` for the first 50 epochs. That would work in plain PyTorch if I were writing my own fit function, but I'm unsure how to integrate `SequentialLR` with skorch's `fit` and callbacks system.
So my questions are:
1) Other than modifying the source code, how can I add an activation delay to callbacks based on the epoch number? Is there already a way to do this with skorch's `.fit`?
2) How could I implement `SequentialLR`, or an equivalent chain of learning rate schedulers, as callbacks?
This is more out of interest, since modifying the source code works for me. Any pointers are appreciated :)
My callbacks are defined like: