Shreeyak closed this issue 4 years ago.
Hmm, I'm guessing `accumulation_steps` in `range_test` would do what I want? Does passing in `accumulation_steps` > 1 average the loss values, or plot the sum of all the accumulated loss values?
If I understand correctly, the purpose of running multiple batches in each iteration is to get a better lr-loss curve, i.e. one that lets you pick a learning rate that trains the model better.
However, this would be less meaningful, since each batch is usually sampled randomly by the `train_loader` (`DataLoader(..., shuffle=True, ...)`). That means you cannot manually control which batch of data is used in a given iteration, and in general you shouldn't: we usually expect batches not to be fed into the model in the same order in each training epoch. You can check out this awesome post to learn more.
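A minimal pure-Python sketch of why a batch can't be pinned to an iteration when shuffling is on: every epoch draws a fresh permutation of the sample indices. `epoch_batches` is a hypothetical helper that only mimics what `shuffle=True` does; it is not `DataLoader`'s implementation.

```python
import random

def epoch_batches(num_samples, batch_size, rng):
    """Hypothetical helper: split one epoch into index batches after
    reshuffling, loosely mimicking DataLoader(..., shuffle=True)."""
    indices = list(range(num_samples))
    rng.shuffle(indices)  # a new permutation on every call, i.e. every epoch
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]

rng = random.Random(0)
epoch1 = epoch_batches(12, 4, rng)
epoch2 = epoch_batches(12, 4, rng)

# Each epoch still visits every sample exactly once...
print(sorted(i for batch in epoch1 for i in batch))
# ...but "batch 0" will generally hold different samples in each epoch.
print(epoch1[0], epoch2[0])
```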
So, in this case, I would suggest using Leslie Smith's approach, i.e. running `range_test()` with `val_loader`, if you need a better lr-loss curve. Alternatively, you could try rewriting `LRFinder._train_batch()` if you want to experiment with this.
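To illustrate the idea behind the validation-set approach, here is a toy, pure-Python sketch (no PyTorch) on a hypothetical 1-D model `y_hat = w * x` fitted to synthetic `y = 2x` data. It is not `LRFinder`'s code; `make_batch`, `mse_and_grad` and `range_test` are illustrative names. Each lr value gets one SGD step on one training batch, and is then scored both on that batch and on a fixed validation set, so every point of the validation curve is measured on the same data.

```python
import itertools
import random

def make_batch(batch_size, rng):
    """Hypothetical toy data: y = 2x plus small Gaussian noise."""
    xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]
    ys = [2.0 * x + rng.gauss(0.0, 0.1) for x in xs]
    return xs, ys

def mse_and_grad(w, xs, ys):
    """Mean squared error of y_hat = w * x, and its gradient with respect to w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad

def range_test(train_batches, val_set, start_lr, end_lr, num_iter):
    """Exponential lr sweep (num_iter >= 2) with one SGD step per lr value."""
    w = 0.0
    history = []
    batches = itertools.cycle(train_batches)  # reuse batches if we run out
    for i in range(num_iter):
        lr = start_lr * (end_lr / start_lr) ** (i / (num_iter - 1))
        xs, ys = next(batches)
        train_loss, grad = mse_and_grad(w, xs, ys)  # loss on this one batch only
        w -= lr * grad                               # one SGD step at this lr
        val_loss, _ = mse_and_grad(w, *val_set)      # same data for every lr
        history.append((lr, train_loss, val_loss))
    return history

rng = random.Random(0)
train_batches = [make_batch(8, rng) for _ in range(20)]
val_set = make_batch(200, rng)
history = range_test(train_batches, val_set, start_lr=1e-3, end_lr=10.0, num_iter=10)
for lr, train_loss, val_loss in history:
    print(f"lr={lr:.3g}  train={train_loss:.4f}  val={val_loss:.4f}")
```

The trade-off is cost: scoring on a validation set means a full pass over it for every lr value, whereas the training-loss curve is essentially free but is measured on a different random batch at each point.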
Besides, `accumulation_steps` is not meant for this purpose. It's designed to simulate a larger batch size using several smaller batches (a.k.a. gradient accumulation), so it's probably not the feature you want here.
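A small pure-Python check of what gradient accumulation does, assuming equal-sized micro-batches and a toy 1-D least-squares model (`grad_mse` and `accumulated_grad` are hypothetical helpers, not library code). Dividing each micro-batch's contribution by `accumulation_steps` before summing makes the accumulated value the average over micro-batches, which equals the gradient of the full large batch. I believe pytorch-lr-finder scales each micro-batch loss in a similar way, which would make the recorded loss an average rather than a sum, but that is worth verifying against the source of the version you use.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of y_hat = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, accumulation_steps):
    """Split one large batch into equal micro-batches, scale each micro-batch
    gradient by 1/accumulation_steps and sum them (gradient accumulation)."""
    assert len(xs) % accumulation_steps == 0, "micro-batches must be equal-sized"
    size = len(xs) // accumulation_steps
    total = 0.0
    for k in range(accumulation_steps):
        chunk = slice(k * size, (k + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk]) / accumulation_steps
    return total

xs = [0.5, -1.0, 2.0, 1.5, -0.5, 1.0, 0.25, -2.0]
ys = [2.0 * x + 0.1 for x in xs]
full = grad_mse(0.0, xs, ys)                                 # one batch of 8
accum = accumulated_grad(0.0, xs, ys, accumulation_steps=4)  # 4 batches of 2
print(full, accum)  # equal up to floating-point rounding
```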
So you're saying, effectively, that passing multiple batches and averaging the loss won't really improve the precision of the plots? Probably true. And if so, then this feature request is pointless.
Good point about `accumulation_steps`. That'll be handy for simulating, on my local PC, the larger batches I'll be running on the server.
How do I get the lr_finder to run multiple batches for each "iteration"? The logic being that running multiple batches would give a more precise result. Based on the naming, I'd assumed that `num_iter` in `lr_finder.range_test()` would control the number of batches/iterations for each value of lr in the given range. However, `num_iter` controls the number of unique lr values to test within the given interval, running only one batch through the network for each value.
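To make the distinction concrete, here is a minimal pure-Python sketch, not pytorch-lr-finder's code, of an exponential sweep in which `num_iter` sets the number of distinct lr values and a hypothetical `batches_per_lr` parameter sets how many batches are averaged at each value. `exponential_lrs`, `sweep`, `loss_fn` and `batches_per_lr` are illustrative names, and the geometric spacing is an assumption about how an exponential sweep places its lr values.

```python
import itertools

def exponential_lrs(start_lr, end_lr, num_iter):
    """num_iter learning rates spaced geometrically from start_lr to end_lr
    (num_iter must be at least 2)."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_iter - 1))
            for i in range(num_iter)]

def sweep(loss_fn, batches, start_lr, end_lr, num_iter, batches_per_lr=1):
    """Hypothetical range-test variant: run batches_per_lr batches at each lr
    value and record their mean loss. batches_per_lr=1 is one batch per lr."""
    source = itertools.cycle(batches)  # reuse batches if we run out
    history = []
    for lr in exponential_lrs(start_lr, end_lr, num_iter):
        losses = [loss_fn(lr, next(source)) for _ in range(batches_per_lr)]
        history.append((lr, sum(losses) / len(losses)))
    return history

# Toy loss_fn that ignores lr and returns the "batch" value itself, so the
# averaging is visible: 2 lr values x 2 batches each = 4 batches consumed.
history = sweep(lambda lr, batch: batch, [1.0, 3.0, 2.0, 4.0],
                start_lr=1e-3, end_lr=1.0, num_iter=2, batches_per_lr=2)
print(history)
```

With this layout the total number of batches consumed is `num_iter * batches_per_lr`, so averaging several batches per lr value multiplies the cost of the sweep accordingly.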