google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
932 stars, 65 forks

Attempted boolean conversion of traced array - for hager-zhang #557

Open SNMS95 opened 11 months ago

SNMS95 commented 11 months ago

I was trying to use jaxopt.LBFGS with linesearch='hager-zhang', and the following error appeared:

Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function _update_interval at python3.10/site-packages/jaxopt/_src/hager_zhang_linesearch.py:286 for cond. This value became a tracer due to JAX operations on these lines:

  operation a:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b

SNMS95 commented 11 months ago

Hi @mblondel,

Do you know why this could happen?

mblondel commented 11 months ago

Do you have a short script to reproduce?

CC @emilyfertig
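For context, this JAX error means that some code ran a Python `if` (or an implicit `bool(...)`) on a value that was being traced, so no concrete True/False existed yet. The sketch below shows that failure mode in isolation and the usual `lax.cond` rewrite. It assumes nothing about jaxopt's internals: `step_python_if`, `step_lax_cond` and `demo` are hypothetical illustrations, not the code in `hager_zhang_linesearch.py`, and it is not a reproduction of this issue.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step_python_if(x):
    # Hypothetical helper: a Python `if` needs a concrete bool, but under
    # jax.jit the comparison `x > 0` is a tracer, so this raises.
    if x > 0:
        return x - 1.0
    return x + 1.0


def step_lax_cond(x):
    # Same logic staged into the traced program with lax.cond:
    # lax.cond(pred, true_fun, false_fun, *operands).
    return lax.cond(x > 0, lambda v: v - 1.0, lambda v: v + 1.0, x)


def demo():
    x = jnp.float32(2.0)
    try:
        jax.jit(step_python_if)(x)
        failed = False
    except jax.errors.ConcretizationTypeError:
        # Recent JAX raises TracerBoolConversionError, a subclass of
        # ConcretizationTypeError; catching the base class covers both.
        failed = True
    ok = jax.jit(step_lax_cond)(x)
    return failed, float(ok)


print(demo())
```

The `lax.cond` version traces both branches and selects one at run time. `jnp.where` is an alternative when both branches are cheap and free of side effects.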