davidtvs / pytorch-lr-finder

A learning rate range test implementation in PyTorch
MIT License

Can we plot learning rate vs. accuracy and get the LR at max accuracy using your library? (needed for Super-Convergence) #63

Closed by nagapavan525 4 years ago

nagapavan525 commented 4 years ago

Hi David @davidtvs,

Is there a way we can plot the learning rate vs accuracy and get LR at max accuracy using your library?

I am trying to use Super-Convergence (https://arxiv.org/pdf/1708.07120.pdf) by Leslie N. Smith, so I am using PyTorch's OneCycleLR scheduler, which expects a max_lr value.
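For context, `OneCycleLR` takes the peak learning rate of the cycle through its `max_lr` argument. A minimal sketch with a toy model; `max_lr = 0.01` is only a placeholder for whatever value the range test suggests, not a recommendation.

```python
import torch
import torch.nn as nn

# Toy model and optimizer, only for illustration.
model = nn.Linear(10, 2)
# OneCycleLR cycles momentum by default, so the optimizer needs a momentum term.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

max_lr = 0.01  # placeholder: use the value picked from the range test
total_steps = 100
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, total_steps=total_steps
)

lrs = []
for _ in range(total_steps):
    # A real loop would do forward, backward and optimizer.step() per batch.
    optimizer.step()
    lrs.append(scheduler.get_last_lr()[0])
    scheduler.step()

print(f"start {lrs[0]:.2e}, peak {max(lrs):.2e}, end {lrs[-1]:.2e}")
```

With the defaults (`div_factor=25`, `pct_start=0.3`), the cycle starts at `max_lr / 25`, rises to `max_lr` over roughly the first 30% of `total_steps`, then anneals well below the starting value.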

I used your lr-finder, but it plots the loss against the learning rate and suggests the LR at the steepest descent of the loss. What I am looking for is a plot of learning rate vs. accuracy, and the LR at which accuracy is highest.
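Once accuracy has been recorded once per iteration alongside the learning rate, finding the LR at maximum accuracy reduces to an argmax. A minimal pure-Python sketch; `lr_at_max_metric` is a hypothetical helper, not part of the library, and its optional exponential smoothing (with `smooth_f` as the weight of the newest value) follows the same idea as LRFinder's `smooth_f` without claiming to reproduce its code.

```python
def lr_at_max_metric(lrs, values, smooth_f=0.0):
    """Return (lr, value) at the iteration where the metric peaks.

    lrs, values: equal-length sequences, one entry per iteration.
    smooth_f: weight of the newest value in an exponential moving average;
              0.0 disables smoothing.
    """
    if not lrs or len(lrs) != len(values):
        raise ValueError("lrs and values must be non-empty and of equal length")

    smoothed = []
    for v in values:
        if smoothed and smooth_f > 0:
            # Blend the new value with the previous smoothed value.
            v = smooth_f * v + (1 - smooth_f) * smoothed[-1]
        smoothed.append(v)

    # Index of the first maximum of the (smoothed) metric.
    best = max(range(len(smoothed)), key=smoothed.__getitem__)
    return lrs[best], smoothed[best]


lrs = [1e-4, 1e-3, 1e-2, 1e-1]
accs = [0.10, 0.35, 0.60, 0.20]
print(lr_at_max_metric(lrs, accs))  # (0.01, 0.6)
```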

Please advise.

Thanks in advance, Naga Pavan

NaleRaphael commented 4 years ago

Hi @nagapavan525, you can try creating a wrapper around the loss function, adding the extra metric functions you want to it, and storing their results in the wrapper instance:

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch_lr_finder import LRFinder

class LossWrapper(nn.Module):
    def __init__(self, loss_func, extra_metrics=None):
        super(LossWrapper, self).__init__()
        self.loss_func = loss_func
        self.extra_metrics = extra_metrics
        if extra_metrics is not None:
            self.extra_metrics_log = {self._get_metric_name(fn): [] for fn in extra_metrics}
        else:
            self.extra_metrics_log = None

    def _get_metric_name(self, fn):
        # `_get_name()` is a private helper of `nn.Module`, so this only
        # works for metrics that are modules (e.g. `nn.NLLLoss()`)
        return fn._get_name()
        # alternatively, use the class name directly:
        # return fn.__class__.__qualname__

    def forward(self, inputs, labels):
        # `inputs` are the model outputs that LRFinder passes to the criterion
        loss = self.loss_func(inputs, labels)

        # evaluate the extra metrics without tracking gradients
        if self.extra_metrics is not None:
            with torch.no_grad():
                for fn in self.extra_metrics:
                    fn_name = self._get_metric_name(fn)
                    self.extra_metrics_log[fn_name].append(fn(inputs, labels))
        return loss

# Run LRFinder, using the MNIST dataset as an example
def main():
    # ... other setups for dataset, `model`, `optimizer` and `trainloader`

    # NOTE: Here we use a loss wrapper instead, and we use `NLLLoss` as one of
    # the `extra_metrics` as a demonstration, so the plotted curves should
    # coincide.
    loss_wrapper = LossWrapper(nn.NLLLoss(), extra_metrics=[nn.NLLLoss()])
    criterion = loss_wrapper

    lr_finder = LRFinder(model, optimizer, criterion, device='cuda')
    # NOTE: Here we set `smooth_f` to 0 so the loss recorded by LRFinder is
    # not smoothed and can be compared directly with the raw values logged
    # by the wrapper.
    lr_finder.range_test(
        trainloader, end_lr=0.1, num_iter=100, step_mode='exp', smooth_f=0
    )

    fig, ax = plt.subplots()
    # NOTE: Set `skip_start` and `skip_end` to 0 because we want to plot the whole series of loss
    lr_finder.plot(ax=ax, skip_start=0, skip_end=0, suggest_lr=False)

    for i, (metric_name, metric_log) in enumerate(loss_wrapper.extra_metrics_log.items()):
        np_metric_log = torch.stack(metric_log).cpu().numpy()
        ax.plot(lr_finder.history['lr'], np_metric_log, linewidth=2+i, zorder=1+i)

    plt.show()

if __name__ == '__main__':
    main()

And here is the result I got:

[image: lr_finder_plot_multiple_metrics]
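To answer the original question with this wrapper, the `nn.NLLLoss()` in `extra_metrics` can be replaced by an accuracy metric. A minimal sketch, assuming a classifier whose outputs have shape `(batch, num_classes)`; the `Accuracy` class is hypothetical and not part of torch_lr_finder. Subclassing `nn.Module` gives it the `_get_name()` method that the `LossWrapper` above uses as the log key.

```python
import torch
import torch.nn as nn


class Accuracy(nn.Module):
    """Hypothetical top-1 accuracy metric for use in LossWrapper's extra_metrics."""

    def forward(self, outputs, labels):
        # Predicted class = index of the largest score in each row.
        preds = outputs.argmax(dim=1)
        # Fraction of correct predictions in the batch, as a 0-dim tensor.
        return (preds == labels).float().mean()


# Hand-made batch: predictions are [0, 1, 0], so 2 of 3 match the labels.
metric = Accuracy()
outputs = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
labels = torch.tensor([0, 1, 1])
acc = metric(outputs, labels)
print(metric._get_name(), acc.item())
```

It would be plugged in as `LossWrapper(nn.NLLLoss(), extra_metrics=[Accuracy()])`, after which `loss_wrapper.extra_metrics_log['Accuracy']` holds one per-batch accuracy per criterion call. Those entries presumably line up one-to-one with `lr_finder.history['lr']` only when the criterion runs once per iteration (assumed here: `accumulation_steps=1` and no `val_loader`). Per-batch training accuracy is noisy, so smoothing it before taking the maximum is advisable.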

If you need something that requires more complicated control, you may need to create a subclass of LRFinder and override _train_batch() to compute the loss and metrics yourself. See also this part:

https://github.com/davidtvs/pytorch-lr-finder/blob/9cfcbecba35866711647a251b4527c6f29b9c9f5/torch_lr_finder/lr_finder.py#L377-L378
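As a lighter alternative to subclassing, the same idea can be sketched without LRFinder at all: a hand-rolled exponential LR sweep that records batch accuracy next to the loss. This is a simplified, swapped-in technique, not the library's implementation: it has no loss smoothing, no early stop on divergence, and it does not restore the model or optimizer afterwards. `lr_range_test` and the synthetic batches are hypothetical.

```python
import torch
import torch.nn as nn


def lr_range_test(model, optimizer, criterion, batches, start_lr=1e-5, end_lr=1.0):
    """Hypothetical minimal LR range test that also records batch accuracy.

    `batches` is a sized iterable of (inputs, labels) pairs. The LR grows
    exponentially from start_lr to end_lr over len(batches) steps. Model and
    optimizer are updated in place, so run this on a throwaway copy.
    """
    num_iter = len(batches)
    history = {"lr": [], "loss": [], "acc": []}
    model.train()
    for i, (inputs, labels) in enumerate(batches):
        # Exponential interpolation: start_lr at i == 0, end_lr at the last step.
        lr = start_lr * (end_lr / start_lr) ** (i / max(num_iter - 1, 1))
        for group in optimizer.param_groups:
            group["lr"] = lr

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy of the outputs computed before this step's update.
        with torch.no_grad():
            acc = (outputs.argmax(dim=1) == labels).float().mean().item()

        history["lr"].append(lr)
        history["loss"].append(loss.item())
        history["acc"].append(acc)
    return history


# Synthetic data only exercises the mechanics; random labels carry no signal.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()
batches = [(torch.randn(32, 4), torch.randint(0, 3, (32,))) for _ in range(50)]

history = lr_range_test(model, optimizer, criterion, batches, start_lr=1e-5, end_lr=0.1)
best = max(range(len(history["acc"])), key=history["acc"].__getitem__)
best_lr = history["lr"][best]
print(f"LR at max batch accuracy: {best_lr:.2e}")
```

On real data, per-batch accuracy is noisy, so smoothing `history["acc"]` before the argmax is usually worthwhile.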

Feel free to let me know if you still have problems!

nagapavan525 commented 4 years ago

Hi @NaleRaphael,

Thank you for your suggestion. It looks like creating a wrapper for the loss function is enough for my needs.

Regards, Naga Pavan