Closed: tare closed this issue 8 months ago.
Hello @tare, thanks for pointing this out. Under vmap, lax.cond evaluates both branches when the predicate is batched (this does not happen without vmap); see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html. Given that, I'm not sure how we could support failure diagnostics under vmap. For now, I have patched the zoom line search in #544 so that it does not display failure diagnostics unless verbose is set to True. That will avoid unnecessary prints.
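A minimal sketch of the idea behind gating the diagnostic on a flag. This is not the actual #544 patch: `line_search_step`, its success criterion and the message are hypothetical. Because `verbose` is an ordinary Python bool, the `jax.debug.print` call is only traced when it is True. With `verbose=False` the failure branch contains no print at all, so nothing can fire even though `vmap` evaluates both branches of the `lax.cond`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def line_search_step(stepsize, verbose=False):
    """Hypothetical stand-in for a line-search step with a failure branch."""

    def on_success(s):
        return s

    def on_failure(s):
        # The diagnostic is only traced when verbose is True; otherwise the
        # branch contains no print at all, so vmap cannot trigger it.
        if verbose:
            jax.debug.print("line search failed, stepsize={s}", s=s)
        return jnp.zeros_like(s)

    # Hypothetical success criterion: a positive stepsize counts as success.
    return lax.cond(stepsize > 0, on_success, on_failure, stepsize)


stepsizes = jnp.array([0.5, 1.0, 2.0])
quiet = jax.vmap(lambda s: line_search_step(s, verbose=False))
noisy = jax.vmap(lambda s: line_search_step(s, verbose=True))

print(quiet(stepsizes))  # no diagnostic output
print(noisy(stepsizes))  # failure message printed although every stepsize > 0
```

With `verbose=True`, the failure message should appear once per batch element even though every stepsize is positive, which mirrors the behaviour reported in this issue.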
Thanks for the quick reply and for pointing out https://github.com/google/jaxopt/pull/544! I hope that PR gets merged soon.
Closing as #544 has been merged.
Environment
Description
ZoomLineSearch under vmap ends up calling failure_diagnostic() even when safe_stepsize > 0, as shown here. This can result in a lot of printouts, and I didn't see a way to disable the failure diagnostic printouts in the current implementation. I think the relevant commit is https://github.com/google/jaxopt/commit/614dc7bf769628eee6f72e636cb608c0f6678596. Below are minimal reproducible examples.

The following code gives the following warnings:

Whereas the following code does not produce any warnings:
Here is a minimal reproducible example illustrating the issue with jax.debug.print, cond, and vmap; the following code gives the following output:
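A sketch in the same spirit as that example (hypothetical function names, assuming a recent JAX version; this is not the reporter's original code). Each branch of a `lax.cond` calls `jax.debug.print`. Called directly, only the selected branch prints; under `jax.vmap` the batched predicate turns the `cond` into a select, so both branches run and both messages are printed for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Each branch reports that it ran, then returns a value.
    def positive_branch(x):
        jax.debug.print("positive branch: x={x}", x=x)
        return x

    def negative_branch(x):
        jax.debug.print("negative branch: x={x}", x=x)
        return -x

    return lax.cond(x > 0, positive_branch, negative_branch, x)


# Without vmap: only the selected branch runs, so only one message is printed.
f(jnp.array(1.0))

# With vmap over a batched predicate: cond is rewritten into a select, both
# branches run for every element, and both messages are printed for each x.
xs = jnp.array([1.0, -2.0, 3.0])
ys = jax.vmap(f)(xs)
print(ys)
```

`jax.make_jaxpr(jax.vmap(f))(xs)` can be used to confirm the rewrite: the batched jaxpr contains `select_n` where the unbatched one contains `cond`.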