davidtvs / pytorch-lr-finder

A learning rate range test implementation in PyTorch
MIT License

Prevent unexpected unpacking error when calling `lr_finder.plot()` with `suggest_lr=True` #98

Closed: NaleRaphael closed this pull request 2 months ago

NaleRaphael commented 3 months ago

@davidtvs, this PR should fix #88. If it does, maybe we can move on to #97 and make a new release on PyPI. @chAwater, I've rewritten part of the test case you made; please feel free to let me know if anything can be improved.

Problem

When calling lr_finder.plot(..., suggest_lr=True), we expect the returned value to be a tuple containing ax and suggested_lr, and the API documentation describes it that way.

However, it returns only ax if there are not enough data points to compute the gradient of the lr-loss curve:
https://github.com/davidtvs/pytorch-lr-finder/blob/fd9e949bc709e31881c77a4b0710745524c5091e/torch_lr_finder/lr_finder.py#L536-L542
https://github.com/davidtvs/pytorch-lr-finder/blob/fd9e949bc709e31881c77a4b0710745524c5091e/torch_lr_finder/lr_finder.py#L568-L571

Therefore, users who unpack the returned value directly, as below, can sometimes run into the error TypeError: cannot unpack non-iterable AxesSubplot object.

ax, suggested_lr = lr_finder.plot(ax=ax, suggest_lr=True)
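To make the failure mode concrete, here is a minimal reproduction that needs neither PyTorch nor matplotlib. `FakeAxes` and `old_plot` are hypothetical stand-ins that only mimic the pre-fix return behavior described above; they are not the library's code, and the "at least two points" threshold is an assumption chosen to resemble the gradient requirement.

```python
# Hypothetical stand-ins, NOT the library's code: they only mimic the
# pre-fix return behavior described in this PR.
class FakeAxes:
    """Placeholder for a matplotlib Axes object."""


def old_plot(losses, ax, suggest_lr=True):
    # Pre-fix behavior: a tuple is returned only when a suggestion exists.
    if suggest_lr and len(losses) >= 2:
        return ax, 0.01  # placeholder suggested learning rate
    return ax  # too few points: only the axes come back


ax = FakeAxes()
try:
    # Only one loss value, so old_plot returns a bare FakeAxes.
    ax, suggested_lr = old_plot([0.9], ax=ax, suggest_lr=True)
except TypeError as exc:
    print(exc)  # prints: cannot unpack non-iterable FakeAxes object
```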

Solution

Always return two values when suggest_lr=True. This ensures that both of the following calling styles work:

# 1. unpack returned value directly
ax, suggested_lr = lr_finder.plot(ax=ax, suggest_lr=True)

# 2. capture the returned value in a single variable, then unpack it manually (users can check it before unpacking)
retval = lr_finder.plot(ax=ax, suggest_lr=True)
assert isinstance(retval, tuple) and len(retval) == 2
ax, suggested_lr = retval
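A rough sketch of the fixed return contract follows. It assumes, in line with the gradient wording above, that the suggestion is the learning rate at the steepest negative gradient of the loss curve, computed with `np.gradient` (which raises `ValueError` when given fewer than two points). `plot_with_suggestion` is a hypothetical name, the warning text is illustrative rather than the library's actual message, and plotting is omitted: `ax` is simply passed through.

```python
import warnings

import numpy as np


def plot_with_suggestion(lrs, losses, ax, suggest_lr=True):
    """Hypothetical sketch of the fixed contract (not the library's code).

    Returns (ax, suggested_lr) whenever suggest_lr is True; suggested_lr
    is None when it cannot be computed.
    """
    if not suggest_lr:
        return ax

    suggested_lr = None
    try:
        # Index of the most negative gradient = steepest drop in loss.
        grads = np.gradient(np.asarray(losses, dtype=float))
        suggested_lr = lrs[int(grads.argmin())]
    except ValueError:
        # np.gradient needs at least edge_order + 1 (= 2) points.
        warnings.warn(
            "Failed to compute the gradient of the lr-loss curve; "
            "there might not be enough data points."
        )
    # Always two values, so direct unpacking never fails.
    return ax, suggested_lr
```

With a single data point the function warns and returns `(ax, None)`; with enough points it returns the learning rate where the loss falls fastest.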

The responsibility of checking whether suggested_lr is available (i.e. not None) is now left to users. This should be fine, since a warning message is shown when no suggestion can be computed, and the check is easy to write. The warning message is also more verbose now, to help users figure out the problem.
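On the caller side, the check can be a single `is None` test. A minimal sketch, where `choose_lr` and `fallback_lr` are hypothetical names and `retval` stands for whatever `lr_finder.plot(ax=ax, suggest_lr=True)` returns after this change:

```python
def choose_lr(retval, fallback_lr):
    """Hypothetical helper: unpack (ax, suggested_lr), falling back if None."""
    # Safe to unpack: with suggest_lr=True two values are always returned.
    ax, suggested_lr = retval
    if suggested_lr is None:
        # No suggestion could be computed (a warning was already emitted).
        return ax, fallback_lr
    return ax, suggested_lr
```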

Note

Though I think this issue could be resolved better by moving the "suggest learning rate" feature into a separate function, it seems safer to keep the API unchanged until we decide to support other methods for finding a suggested learning rate.