Closed mirkobunse closed 1 month ago
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
Hello @mirkobunse ,
Thanks for the fix! There is a larger issue in optax: any optimizer with mixed precision (different precisions for gradients and parameters) may compile twice. A fix would probably be to change the signature of the init functions rather than to add arguments to the definition of the optimizer. We will discuss that internally soon and I'll keep you posted.
If this is blocking your research, you may also consider optimistix, which has carefully taken care of some recompilation issues.
Problems
I experienced two problems that the zoom_linesearch exhibits when used within a JIT-compiled training step:
1) The training step is compiled twice because some dtypes of the linesearch's state change after the first iteration.
2) Compilation fails for dtypes other than float32.
I want to give a minimal example to reproduce both problems and I want to provide a work-around for the first one. The second problem cannot be worked around; hence, I propose this PR to ultimately solve both problems.
Minimal working example for the re-compilation issue
I sample some random data and create an L-BFGS optimizer with zoom linesearch. The training step is JIT-compiled, and a print statement informs us whenever a compilation is triggered (this statement is silent in the compiled variant of the function).

The output of this script is the following. We recognize that the function is compiled twice and that, between iterations, dtypes have changed from weak types to strong types.
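The mechanism behind the double compilation can be illustrated without optax. The following is a minimal sketch in plain JAX, not the script from this PR: `step`, `trace_log`, and the toy step-size state are hypothetical names chosen for illustration. It assumes only that a Python scalar is traced as a weakly typed value and that combining it with a strongly typed array gives a strongly typed result.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces `step`, i.e. once per compilation


@jax.jit
def step(params, step_size):
    # Python side effects run during tracing and are absent from the compiled code.
    trace_log.append(f"step_size weak_type={step_size.aval.weak_type}")
    grad = 2.0 * params  # gradient of sum(params ** 2)
    new_params = params - step_size * grad
    # Combining the weakly typed state with a strongly typed array
    # yields a strongly typed result.
    new_step_size = step_size / (1.0 + jnp.sum(grad ** 2))
    return new_params, new_step_size


params = jnp.ones(3, dtype=jnp.float32)
step_size = 1.0  # hypothetical initial state: a Python scalar, hence weakly typed
for _ in range(3):
    params, step_size = step(params, step_size)

print(len(trace_log))  # number of traces over three iterations
print(trace_log)
```

The first call sees a weakly typed `step_size`, every later call sees the strongly typed value returned by the previous iteration, so the jitted function is traced a second time even though shapes and dtypes are otherwise unchanged.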
Workaround for the re-compilation issue
We can change the initial state to have strong dtypes right from the beginning:
As desired, the training step is now compiled only once.
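The same idea can be sketched in plain JAX, again as an illustration under stated assumptions rather than the optax code from this PR. The helper `strip_weak_type` and the dictionary-shaped state are hypothetical; the sketch relies on `jnp.asarray` returning a strongly typed array when a dtype is passed explicitly.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces `step`


def strip_weak_type(leaf):
    # Hypothetical helper: re-create the leaf with an explicit dtype,
    # which makes the resulting array strongly typed.
    leaf = jnp.asarray(leaf)
    return jnp.asarray(leaf, dtype=leaf.dtype)


@jax.jit
def step(params, state):
    trace_log.append("traced")  # runs during tracing only
    grad = 2.0 * params  # gradient of sum(params ** 2)
    new_params = params - state["step_size"] * grad
    new_state = {"step_size": state["step_size"] / (1.0 + jnp.sum(grad ** 2))}
    return new_params, new_state


params = jnp.ones(3, dtype=jnp.float32)
# Make every leaf of the initial state strongly typed before the first call.
state = jax.tree_util.tree_map(strip_weak_type, {"step_size": 1.0})
for _ in range(3):
    params, state = step(params, state)

print(len(trace_log))  # number of traces over three iterations
```

Because the state has the same abstract type before and after each iteration, the cached compilation from the first call is reused.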
Minimal working example for non-default dtypes
In the above example, we make two slight changes,
and
Unfortunately, these changes break the compilation entirely. Since this error, shown in the following, appears in the internals of the zoom linesearch, I do not see an easy work-around for this problem.
Solution: this PR
I suggest the following improvements:
The specification of a value dtype works as follows: