google / jaxopt

Hardware-accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

support for scalar variables #480

Open · mblondel opened this issue 1 year ago


Some solvers in JAXopt don't seem to work properly with scalar variables. Since scalars are valid pytrees, we need to ensure that all solvers in JAXopt handle scalar variables correctly, and write a common test that exercises every solver with one.
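A minimal sketch of what such a common test could check, assuming a solver exposes a JAXopt-style `run(init_params)` method that returns a `(params, state)` pair. The test problem (the quadratic `(x - 1)**2` with minimizer `x* = 1`), the helper `check_scalar_support`, and the solver `ToyGradientDescent` are all hypothetical and only illustrate the idea. The check requires that a scalar in (Python float or 0-d array) yields a scalar out, with shape `()`, close to the minimizer.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

# A scalar is a valid pytree: it is a single leaf.
assert jax.tree_util.tree_leaves(2.0) == [2.0]


class OptStep(NamedTuple):
    params: Any
    state: Any


class ToyGradientDescent:
    """Hypothetical stand-in mimicking a JAXopt-style `run(init_params)` interface."""

    def __init__(self, fun, stepsize=0.1, maxiter=500):
        self.fun = fun
        self.stepsize = stepsize
        self.maxiter = maxiter

    def run(self, init_params):
        # Convert every leaf (including a bare Python float) to a float32 array
        # so the loop carry has a fixed dtype.
        params = jax.tree_util.tree_map(
            lambda x: jnp.asarray(x, dtype=jnp.float32), init_params)
        grad_fun = jax.grad(self.fun)

        def body(_, p):
            g = grad_fun(p)
            # tree_map works unchanged when the pytree is a single scalar leaf.
            return jax.tree_util.tree_map(lambda a, b: a - self.stepsize * b, p, g)

        params = jax.lax.fori_loop(0, self.maxiter, body, params)
        return OptStep(params=params, state=None)


def check_scalar_support(make_solver, atol=1e-3):
    """Hypothetical common test: run a solver on a scalar problem and check the result."""
    fun = lambda x: (x - 1.0) ** 2  # minimizer is x* = 1.0
    for init in (3.0, jnp.asarray(3.0)):  # Python float and 0-d array
        params, _ = make_solver(fun).run(init)
        params = jnp.asarray(params)
        assert params.shape == (), f"expected a scalar, got shape {params.shape}"
        assert jnp.allclose(params, 1.0, atol=atol), f"did not converge: {params}"


check_scalar_support(lambda f: ToyGradientDescent(f))
```

In a real test suite, the same check could be parametrized over solver factories such as `lambda f: jaxopt.GradientDescent(fun=f, maxiter=500)` or `lambda f: jaxopt.LBFGS(fun=f)`. Whether each JAXopt solver passes is exactly what this issue asks to verify, so no outcome is assumed here.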