Hi @nagapavan525, you can try creating a wrapper for the loss function, add the extra metric functions you want into it, and store their results on that instance.
```python
import torch
import torch.nn as nn


class LossWrapper(nn.Module):
    def __init__(self, loss_func, extra_metrics=None):
        super(LossWrapper, self).__init__()
        self.loss_func = loss_func
        self.extra_metrics = extra_metrics
        if extra_metrics is not None:
            self.extra_metrics_log = {self._get_metric_name(fn): [] for fn in extra_metrics}
        else:
            self.extra_metrics_log = None

    def _get_metric_name(self, fn):
        return fn._get_name()
        # or use this instead, since `_get_name()` is a non-public helper defined by PyTorch:
        # return fn.__class__.__qualname__

    def forward(self, inputs, labels):
        loss = self.loss_func(inputs, labels)
        # Evaluate the extra metrics without tracking gradients
        if self.extra_metrics is not None:
            with torch.no_grad():
                for fn in self.extra_metrics:
                    fn_name = self._get_metric_name(fn)
                    self.extra_metrics_log[fn_name].append(fn(inputs, labels))
        return loss
```
```python
import matplotlib.pyplot as plt
from torch_lr_finder import LRFinder


# Run LRFinder, using the MNIST dataset as an example
def main():
    # ... other setups for dataset, model, optimizer, trainloader
    # NOTE: Here we use the loss wrapper instead, and we pass `NLLLoss` as one of
    #       the `extra_metrics` as a demonstration, so there should be no
    #       difference between the plotted curves.
    loss_wrapper = LossWrapper(nn.NLLLoss(), extra_metrics=[nn.NLLLoss()])
    criterion = loss_wrapper

    lr_finder = LRFinder(model, optimizer, criterion, device='cuda')
    # NOTE: `smooth_f` is set to 0 so that no smoothing is applied to the loss
    #       recorded by LRFinder and the two curves can be compared directly
    #       in the resulting figure.
    lr_finder.range_test(
        trainloader, end_lr=0.1, num_iter=100, step_mode='exp', smooth_f=0
    )

    fig, ax = plt.subplots()
    # NOTE: Set `skip_start` and `skip_end` to 0 because we want to plot the
    #       whole series of losses.
    lr_finder.plot(ax=ax, skip_start=0, skip_end=0, suggest_lr=False)

    for i, (metric_name, metric_log) in enumerate(loss_wrapper.extra_metrics_log.items()):
        np_metric_log = torch.stack(metric_log).cpu().numpy()
        ax.plot(lr_finder.history['lr'], np_metric_log, linewidth=2 + i, zorder=1 + i)
    plt.show()


if __name__ == '__main__':
    main()
```
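Since your goal is the LR at maximum accuracy, here is a rough follow-up sketch (not part of the library): assuming you pass an accuracy-style metric module as one of the `extra_metrics` and it ends up logged under a key named `'Accuracy'`, you can read the LR at its maximum off the finder's history:

```python
# Hypothetical follow-up: 'Accuracy' is a placeholder key for an accuracy-style
# metric module you would pass in `extra_metrics`.
import numpy as np

acc_log = torch.stack(loss_wrapper.extra_metrics_log['Accuracy']).cpu().numpy()
lrs = np.array(lr_finder.history['lr'])
best_lr = lrs[np.argmax(acc_log)]
print(f'LR at maximum accuracy: {best_lr:.2e}')
```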
And here is the result I got:
If you need to implement something requiring more complicated control, you might need to create a subclass inheriting from LRFinder and override `_train_batch()` to handle the computation of loss and metrics yourself. A rough sketch is shown below; see also the corresponding part of the library source.
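This is an untested sketch only: it assumes the `_train_batch()` signature and the `model` / `optimizer` / `criterion` / `device` attributes of recent versions of this package, so please verify them against the source of the version you have installed.

```python
# Rough sketch: a subclass that also records batch accuracy during the range test.
import torch
from torch_lr_finder import LRFinder


class LRFinderWithAccuracy(LRFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.acc_history = []

    def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=True):
        self.model.train()
        total_loss = None
        self.optimizer.zero_grad()
        for _ in range(accumulation_steps):
            inputs, labels = next(train_iter)
            inputs = inputs.to(self.device, non_blocking=non_blocking_transfer)
            labels = labels.to(self.device, non_blocking=non_blocking_transfer)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss /= accumulation_steps
            loss.backward()
            total_loss = loss if total_loss is None else total_loss + loss

            # Extra bookkeeping: record batch accuracy so it can be plotted
            # against `self.history['lr']` afterwards.
            with torch.no_grad():
                preds = outputs.argmax(dim=1)
                self.acc_history.append((preds == labels).float().mean().item())

        self.optimizer.step()
        return total_loss.item()
```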
Feel free to let me know if you still have problems!
Hi @NaleRaphael,
Thank you for your suggestion. It looks like creating a wrapper for the loss function suffices for my needs.
Regards, Naga Pavan
Hi David @davidtvs,
Is there a way we can plot learning rate vs. accuracy and get the LR at maximum accuracy using your library?
I am trying to use super-convergence (https://arxiv.org/pdf/1708.07120.pdf) by Leslie N. Smith, so I am using PyTorch's OneCycleLR scheduler, which expects a max_lr value.
I used your lr-finder, but it plots the loss curve against learning rates and suggests the LR at the steepest descent. What I am looking for is learning rate vs. accuracy, so I can get the LR at maximum accuracy.
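For context, I plan to use the found value roughly like this (an illustrative sketch only; `model`, `optimizer`, `criterion`, `trainloader`, `num_epochs`, and `best_lr` stand for my usual training setup):

```python
import torch

# OneCycleLR ramps the LR up to `max_lr` and back down over the whole run,
# stepping once per batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=best_lr, epochs=num_epochs, steps_per_epoch=len(trainloader)
)
for epoch in range(num_epochs):
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # step the scheduler once per batch
```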
Could you please advise?
Thanks in advance, Naga Pavan