MPoL-dev / MPoL

A flexible Python platform for Regularized Maximum Likelihood imaging
https://mpol-dev.github.io/MPoL/
MIT License

Learn rate scheduler #207

Closed: jeffjennings closed this 9 months ago

jeffjennings commented 9 months ago

NOTE: This PR should only be reviewed after #206 is merged into main and main is merged into this branch (this branch was created from #206's branch).

Choice and defaults of scheduler: Of the several torch.optim.lr_scheduler schedulers, ReduceLROnPlateau is one of the few that adjusts the learning rate in response to a monitored metric (its documentation is not very clear), rather than at user-specified epochs (which would not generalize across datasets). I use this scheduler and pass it the loss as the metric.

The scheduler has a threshold: if the metric fails to improve by more than this relative amount for more than the scheduler's `patience` number of steps (PyTorch default 10), it judges the metric to have plateaued and decreases the learning rate. I keep this threshold at its default value of 1e-4. For the factor by which the learning rate is reduced, I found 0.995 to be a good choice (each reduction sets the learning rate to 99.5% of its previous value); the factor is an arg of the TrainTest class. For the one DSHARP dataset I tested, this choice continues to gradually reduce the brightness scale of the gradient image after the loss appears, by eye, to have plateaued, while avoiding transient spikes in the loss at large iteration numbers. The scheduler steps once per iteration of the training loop, after optimizer.step().
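A minimal sketch of the setup described above, assuming a toy least-squares problem stands in for the MPoL image model and loss (the `params`/`target` tensors and Adam with lr=0.1 are illustrative assumptions, not the PR's code). It builds ReduceLROnPlateau with factor=0.995 and the default relative threshold of 1e-4, and steps it once per iteration after optimizer.step(), passing the loss as the metric.

```python
import torch

# Hypothetical toy problem standing in for the MPoL image model and loss:
# fit a parameter vector to a fixed target by least squares.
torch.manual_seed(0)
target = torch.randn(16)
params = torch.zeros(16, requires_grad=True)

optimizer = torch.optim.Adam([params], lr=0.1)

# ReduceLROnPlateau monitors the metric passed to scheduler.step(metric).
# threshold=1e-4 is PyTorch's default (relative mode); factor=0.995
# multiplies the learning rate by 0.995 each time the metric plateaus.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.995, threshold=1e-4
)

for iteration in range(500):
    optimizer.zero_grad()
    loss = torch.sum((params - target) ** 2)
    loss.backward()
    optimizer.step()
    # Step the scheduler once per iteration, after optimizer.step(),
    # with the current loss as the monitored metric.
    scheduler.step(loss.item())

# The (possibly reduced) learning rate after training.
current_lr = optimizer.param_groups[0]["lr"]
```

Because the scheduler only ever multiplies the learning rate by 0.995, each individual reduction is small; the cumulative effect builds up over many plateaus late in training.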

Because the scheduler continues to improve the gradient image even when the loss appears to have plateaued, I've also tested tightening the convergence tolerance for the loss in the training loop. I get the best result with a tolerance of 1 part in 10^5 (previously 1 part in 10^3): the loss is considered converged once it changes by less than this for 10 iterations. This tolerance is an arg of TrainTest.
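The convergence criterion can be sketched as a small stateful check. `LossConvergence` is a hypothetical helper written for illustration, not TrainTest's actual implementation; it assumes "1 part in 10^5" refers to the relative change between successive loss values, and requires that change to stay below the tolerance for 10 successive iterations.

```python
from collections import deque


class LossConvergence:
    """Hypothetical convergence check (not MPoL's API).

    Reports convergence once the relative change between successive loss
    values has stayed below `tol` for `window` successive iterations.
    """

    def __init__(self, tol=1e-5, window=10):
        self.tol = tol
        self.window = window
        self.prev_loss = None
        # One pass/fail flag per recent iteration; old flags drop off.
        self.recent = deque(maxlen=window)

    def update(self, loss):
        """Record a new loss value and return True if it has converged."""
        if self.prev_loss is not None:
            # Fall back to the absolute change if the previous loss is zero.
            denom = abs(self.prev_loss) if self.prev_loss != 0 else 1.0
            rel_change = abs(loss - self.prev_loss) / denom
            self.recent.append(rel_change < self.tol)
        self.prev_loss = loss
        # Converged only once the window is full and every change passed.
        return len(self.recent) == self.window and all(self.recent)


# Example: a loss shrinking by 0.05% per iteration.
losses = [100.0 * 0.9995**k for k in range(50)]

loose = LossConvergence(tol=1e-3)  # previous tolerance, 1 part in 10^3
tight = LossConvergence(tol=1e-5)  # new tolerance, 1 part in 10^5
loose_flags = [loose.update(x) for x in losses]
tight_flags = [tight.update(x) for x in losses]
```

Under the looser 10^-3 tolerance this slowly improving loss is declared converged at iteration 10; under 10^-5 it is not, which leaves the scheduler more iterations to keep refining the image.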

jeffjennings commented 9 months ago

Yeah, right now I'm defining the torch.optim optimizer and the torch.optim.lr_scheduler scheduler in CrossValidate and passing them to TrainTest, but ultimately it could be better to have the user pass in a template optimizer and scheduler that are re-initialized for each k-fold.
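One way the "template" idea could look, sketched under assumptions: the user passes constructors with their settings already bound via `functools.partial`, and fresh optimizer and scheduler instances are built for each k-fold so that no optimizer state, reduced learning rate, or best-loss record carries over between folds. `make_fold_objects`, the stand-in `torch.nn.Linear` model, and lr=0.3 are hypothetical choices for illustration, not MPoL's API.

```python
import functools

import torch

# Hypothetical user-supplied "templates": constructors with settings bound,
# but not yet attached to any model parameters.
optimizer_template = functools.partial(torch.optim.Adam, lr=0.3)
scheduler_template = functools.partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau, factor=0.995
)


def make_fold_objects(model, optimizer_template, scheduler_template):
    """Build a fresh optimizer and scheduler bound to this fold's model."""
    optimizer = optimizer_template(model.parameters())
    scheduler = scheduler_template(optimizer)
    return optimizer, scheduler


k = 5
fold_objects = []
for kfold in range(k):
    # Stand-in for the per-fold model; a new model is created each fold.
    model = torch.nn.Linear(4, 1)
    fold_objects.append(
        make_fold_objects(model, optimizer_template, scheduler_template)
    )
```

The design choice here is that the templates carry only hyperparameters, so re-initializing per fold is just calling them again on the new model's parameters.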

For now, the scheduler's threshold is relative to the loss, so I think it should be fairly flexible across datasets. I'm less certain the scheduler factor will be, which is why I made it an arg of CrossValidate. I'll get a better sense by running cross-validation on more datasets.
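To illustrate why a relative threshold should transfer across datasets: PyTorch documents that in `mode="min"` with `threshold_mode="rel"` (the defaults), a new value only counts as an improvement if it is below `best * (1 - threshold)`. The hypothetical `is_improvement` function below reproduces that comparison in plain Python and shows that the same 1e-4 threshold gives the same verdict whether the loss is of order 1 or of order 10^6.

```python
def is_improvement(loss, best, threshold=1e-4):
    """Hypothetical helper reproducing ReduceLROnPlateau's comparison for
    mode="min", threshold_mode="rel": the new loss must beat the best loss
    so far by more than the fraction `threshold` of that best value."""
    return loss < best * (1.0 - threshold)


results = {}
for scale in (1.0, 1e6):
    best = 2.0 * scale
    small_drop = best * (1.0 - 5e-5)  # 0.005% below best: too small to count
    large_drop = best * (1.0 - 5e-4)  # 0.05% below best: counts as improvement
    results[scale] = (
        is_improvement(small_drop, best),
        is_improvement(large_drop, best),
    )
```

An absolute threshold would not have this property: a fixed drop of, say, 1e-4 is enormous relative to a loss of order 1e-3 and negligible relative to one of order 1e6.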