google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Disable warnings in vmap or print to stderr #589

Open joeryjoery opened 2 months ago

joeryjoery commented 2 months ago

When I call a solver such as BFGS directly, the code runs without warnings. But when I call it inside a vmap, I get spammed with warnings that are not really a problem for my use case.

I also cannot filter the warnings, because they appear to be printed to stdout rather than stderr (at least I think so). The following filter had no effect:

warnings.filterwarnings('ignore', '.*jaxopt.ZoomLineSearch.*')

This is super annoying when I log other information to stdout in my tests. How can I get rid of these messages?
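
A possible explanation for why the filter has no effect, assuming the messages are emitted with plain print calls rather than through Python's `warnings` module (the thread does not confirm this): `warnings.filterwarnings` only affects messages issued via `warnings.warn`. The sketch below illustrates the difference with a hypothetical `noisy` function standing in for the solver; it is not jaxopt code.

```python
import contextlib
import io
import warnings


def noisy():
    # Hypothetical stand-in for a solver call; not part of jaxopt.
    warnings.warn("jaxopt.ZoomLineSearch: issued via warnings.warn")
    print("WARNING: jaxopt.ZoomLineSearch: issued via print")


buf = io.StringIO()
with warnings.catch_warnings(record=True) as caught:
    # Same filter as in the post: ignore warnings whose message matches.
    warnings.filterwarnings("ignore", ".*jaxopt.ZoomLineSearch.*")
    # Capture stdout so the printed line can be inspected afterwards.
    with contextlib.redirect_stdout(buf):
        noisy()

filtered_count = len(caught)  # warnings that got past the filter
printed = buf.getvalue()      # text that went to stdout regardless
```

Here the `warnings.warn` message is dropped by the filter, while the printed line still reaches stdout.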

WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.06082260608673096). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.062071025371551514). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.05304361507296562). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
WARNING: jaxopt.ZoomLineSearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
...
joeryjoery commented 2 months ago

FYI, I found one workaround: wrapping every call to jaxopt in a context manager that disables print statements altogether.
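
A minimal sketch of what such a context manager could look like, using only the standard library. The names `suppress_stdout` and `noisy_solver` are hypothetical, and the sketch assumes the messages are written to Python's `sys.stdout`, as the original post suggests.

```python
import contextlib
import os


@contextlib.contextmanager
def suppress_stdout():
    """Temporarily send anything written to sys.stdout to os.devnull."""
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull):
            yield


def noisy_solver():
    # Hypothetical stand-in for a call into jaxopt.
    print("WARNING: jaxopt.ZoomLineSearch: ...")
    return 42


with suppress_stdout():
    result = noisy_solver()  # the printed warning is discarded
```

Note that this silences all stdout output inside the `with` block, not only the solver warnings. Because JAX dispatches work asynchronously, it may also be necessary to wait for the result (for example with `block_until_ready`) before leaving the block; otherwise delayed output could still appear.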

joeryjoery commented 2 months ago

I just found out that the version on GitHub already fixes this, but the release on PyPI does not include the fix yet.