jeffjennings closed this 9 months ago.
Yeah, right now I'm defining the `torch.optim` optimizer and `torch.optim.lr_scheduler` scheduler in `CrossValidate` to pass to `TrainTest`, but ultimately it could be better to have the user pass in a template optimizer and scheduler that are re-initialized per k-fold.
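One way the template idea could look, as a minimal sketch: the user passes callables (here `functools.partial` objects wrapping `torch.optim.Adam` and `torch.optim.lr_scheduler.ReduceLROnPlateau`), and the cross-validation loop calls them once per fold. The function `cross_validate`, the toy `torch.nn.Linear` model, and `lr=0.3` are hypothetical placeholders, not the actual `CrossValidate` API.

```python
from functools import partial

import torch

# Hypothetical user-supplied templates: partially applied constructors.
# Each call builds a brand-new object, so optimizer state (e.g. Adam's moment
# estimates) and scheduler state (reduced lr, patience counter) never leak
# from one fold into the next.
optimizer_template = partial(torch.optim.Adam, lr=0.3)
scheduler_template = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau, factor=0.995
)


def cross_validate(n_folds, make_model, optimizer_template, scheduler_template):
    """Hypothetical k-fold loop that re-initializes optimizer and scheduler per fold."""
    per_fold = []
    for _ in range(n_folds):
        model = make_model()  # fresh model for this fold
        optimizer = optimizer_template(model.parameters())
        scheduler = scheduler_template(optimizer)
        # ... train on the k-1 training folds and test on the held-out fold here ...
        per_fold.append((optimizer, scheduler))
    return per_fold


folds = cross_validate(
    3, lambda: torch.nn.Linear(4, 1), optimizer_template, scheduler_template
)
```

Passing callables rather than pre-built instances matters because an optimizer instance is bound to the parameters of the model it was constructed with, so it cannot simply be reused for the freshly initialized model of the next fold.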
For now the scheduler's threshold is a relative factor of the loss, so I think it should be fairly flexible across datasets. I'm less certain the scheduler's reduction factor will generalize, which is why I made it an arg of `CrossValidate`. I'll get a sense of this by running cross-val on more datasets.
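To see why a relative threshold is insensitive to the scale of the loss, here is a small pure-Python sketch of the improvement test `ReduceLROnPlateau` applies with its defaults `mode='min'` and `threshold_mode='rel'`: a new loss counts as an improvement only if it is below `best * (1 - threshold)`. The function `is_improvement` is a hypothetical restatement of that rule for illustration, not the torch source.

```python
def is_improvement(loss, best, threshold=1e-4):
    """Relative-threshold test in 'min' mode: loss must beat best by a fraction `threshold`."""
    return loss < best * (1.0 - threshold)


# The same fractional change is judged identically whatever the loss scale.
results = {}
for scale in (1e-3, 1.0, 1e6):
    best = 2.0 * scale
    small_drop = best * (1 - 5e-5)  # 0.005% drop: smaller than the 0.01% threshold
    large_drop = best * (1 - 5e-4)  # 0.05% drop: larger than the threshold
    results[scale] = (is_improvement(small_drop, best), is_improvement(large_drop, best))
```

With an absolute threshold (`threshold_mode='abs'`), by contrast, a fixed `1e-4` would be negligible for a loss of order 10^6 and a 10% change for a loss of order 10^-3.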
NOTE: This PR should only be reviewed after #206 is merged into `main` and `main` is merged into this branch (it was branched from there).

Choice and defaults of scheduler: Of the several `torch.optim.lr_scheduler` schedulers, `ReduceLROnPlateau` is one of the few that updates the learning rate according to a monitored metric (see its docs, which are not well written), rather than just at some user-supplied set of epochs (which would not be at all general). I use this scheduler and give it the loss as the metric.

The scheduler has a threshold below which it judges the metric to no longer be changing; once the metric has failed to improve by more than this threshold for more than `patience` consecutive steps (10 by default), the learning rate is decreased. I keep this threshold at the default relative factor of 1e-4. For the factor by which to reduce the learning rate, I found 0.995 to be a good choice (each reduction sets the learning rate to 99.5% of its previous value); the factor is an arg of the `TrainTest` class. For the one dsharp dataset I tested, this choice keeps gradually reducing the brightness scale of the gradient image after the loss has plateaued by eye, while avoiding transient spikes in the loss at large iterations. The learning rate update is done at each iteration of the training loop, after `optimizer.step()`.

Because the scheduler gradually improves the gradient image even when the loss appears to have plateaued, I've also tested tightening the convergence tolerance for the loss in the training loop. I got the best result by setting the tolerance to 1 part in 10^5 (the loss must change by less than this for 10 iterations to be considered converged; previously it was 1 part in 10^3). This tolerance is an arg of `TrainTest`.
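A minimal sketch of the training-loop arrangement described above, assuming a toy quadratic loss `(w - 3)^2 + 1` and plain SGD in place of the real imaging loss and optimizer: `ReduceLROnPlateau` is built with `factor=0.995` and the default relative `threshold=1e-4`, `scheduler.step(...)` is called every iteration right after `optimizer.step()`, and the loop stops once the relative change in the loss has stayed below `tol=1e-5` for 10 consecutive iterations. The names `train`, `tol`, `n_converge`, `max_iter`, and the exact form of the convergence test (relative change between consecutive iterations) are assumptions for illustration, not the `TrainTest` implementation.

```python
import torch


def train(params, loss_fn, lr=0.1, factor=0.995, tol=1e-5, n_converge=10, max_iter=20000):
    """Hypothetical loop: plateau-based lr decay plus a relative convergence test."""
    optimizer = torch.optim.SGD(params, lr=lr)
    # Multiply lr by `factor` whenever the loss has not improved by more than the
    # relative threshold (1e-4) for more than `patience` (default 10) steps.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=factor, threshold=1e-4, threshold_mode="rel"
    )
    prev_loss = None
    n_small = 0  # consecutive iterations whose relative change is below tol
    for it in range(max_iter):
        optimizer.zero_grad()
        loss = loss_fn()
        loss.backward()
        optimizer.step()
        scheduler.step(loss.item())  # one lr update per iteration, after optimizer.step()

        current = loss.item()
        if prev_loss is not None:
            rel_change = abs(current - prev_loss) / max(abs(prev_loss), 1e-30)
            n_small = n_small + 1 if rel_change < tol else 0
            if n_small >= n_converge:
                return it + 1, current, optimizer.param_groups[0]["lr"]
        prev_loss = current
    return max_iter, prev_loss, optimizer.param_groups[0]["lr"]


w = torch.zeros(1, requires_grad=True)
n_iter, final_loss, final_lr = train([w], lambda: ((w - 3.0) ** 2).sum() + 1.0)
```

Because each reduction only multiplies the learning rate by 0.995, halving it takes about ln(0.5)/ln(0.995) ≈ 138 reductions, so the decay stays gradual over a long run.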