havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Can I use a PyTorch lr_scheduler in model.fit? #103

Closed kyuchoi closed 2 years ago

kyuchoi commented 2 years ago

First of all, thank you for your great work!!

Can I use a PyTorch lr_scheduler in model.fit? If so, how can I modify the following code to use a learning rate scheduler?

log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)

Many thanks

havakv commented 2 years ago

Thank you for the kind words. To use learning rate schedulers, you would need to implement them as callbacks. There is an example/explanation here of how to do this for gradients. The most basic implementation would look something like this:

from torchtuples.callbacks import Callback

class LRScheduler(Callback):
    """Call a PyTorch lr_scheduler at the end of every epoch."""
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def on_epoch_end(self):
        # advance the learning rate schedule once per epoch
        self.scheduler.step()

And in your example code above, you would have

scheduler = <define-some-scheduler>
callbacks = [<other-callbacks>, LRScheduler(scheduler)]
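
For a concrete (untested) sketch with, say, StepLR, it could look like the following. Note that the torch scheduler needs the underlying torch.optim optimizer; I'm assuming here that it can be reached as model.optimizer.optimizer, which may depend on your torchtuples version.

import torch

# model is any pycox/torchtuples model, e.g. CoxPH(net, tt.optim.Adam)
# assumption: the wrapped torch.optim optimizer lives at model.optimizer.optimizer
scheduler = torch.optim.lr_scheduler.StepLR(model.optimizer.optimizer,
                                            step_size=10, gamma=0.5)
callbacks = [LRScheduler(scheduler)]
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)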

If you want a learning rate scheduler that depends on some metric, that metric would also need to be handled in the callback. You can, for instance, access the training loss through self.model.batch_loss in the callback, as in the sketch below.
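
As a rough sketch of that, a callback for a metric-based scheduler such as ReduceLROnPlateau could look like this (assuming self.model.batch_loss holds the loss tensor of the last batch; averaging over the whole epoch would need a bit more bookkeeping):

from torchtuples.callbacks import Callback

class LRSchedulerOnLoss(Callback):
    """Hypothetical sketch: step a metric-based scheduler on the training loss."""
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def on_epoch_end(self):
        # assumption: self.model.batch_loss is the loss tensor of the last batch
        loss = self.model.batch_loss.item()
        self.scheduler.step(loss)

You would then construct something like scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(...) around the same underlying optimizer and pass LRSchedulerOnLoss(scheduler) in the callbacks list.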

Does this answer your question?

kyuchoi commented 2 years ago

It works!! Thank you so much.