mblondel opened 1 year ago
It seems like some solvers in JAXopt don't work properly with scalar variables. Since scalars are valid pytrees, all solvers in JAXopt should work properly with scalar variables, and we should write a common test that checks this for every solver.
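Here is a minimal sketch of what such a common test could look like, written in plain JAX. `ToyGradientDescent` and `check_solver_on_pytrees` are hypothetical names used only for illustration and are not part of JAXopt. The toy solver only mimics a `run(init_params)` calling convention. The idea is that a solver written purely with `jax.tree_util` operations treats a scalar like any other pytree. The check runs the same solver on a scalar, an array, and a dict, and verifies that the solution keeps the input's tree structure and leaf shapes.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# A scalar is a valid pytree: a single leaf with no children.
assert tree_util.tree_leaves(1.5) == [1.5]


class ToyGradientDescent:
    """Hypothetical pytree-generic solver (not JAXopt's implementation)."""

    def __init__(self, fun, stepsize=0.1, maxiter=200):
        self.fun = fun
        self.stepsize = stepsize
        self.maxiter = maxiter

    def run(self, init_params):
        grad_fun = jax.grad(self.fun)

        @jax.jit
        def step(params):
            grads = grad_fun(params)
            # tree_map applies the update leaf by leaf, so scalars, arrays
            # and nested containers all go through the same code path.
            return tree_util.tree_map(
                lambda p, g: p - self.stepsize * g, params, grads)

        params = init_params
        for _ in range(self.maxiter):
            params = step(params)
        return params


def check_solver_on_pytrees(make_solver):
    """Hypothetical common test: run one solver on several pytree kinds."""

    # Sum of squared distances to 3.0 over every leaf; minimum at 3.0.
    def fun(params):
        return sum(jnp.sum((leaf - 3.0) ** 2)
                   for leaf in tree_util.tree_leaves(params))

    for init in (0.0, jnp.zeros(3), {"a": 0.0, "b": jnp.zeros(2)}):
        sol = make_solver(fun).run(init)
        # The solution must have the same tree structure as the input...
        assert tree_util.tree_structure(sol) == tree_util.tree_structure(init)
        for s, i in zip(tree_util.tree_leaves(sol), tree_util.tree_leaves(init)):
            # ...the same leaf shapes (a scalar must stay a scalar)...
            assert jnp.shape(s) == jnp.shape(i)
            # ...and it must actually reach the minimizer.
            assert jnp.allclose(s, 3.0, atol=1e-4)


check_solver_on_pytrees(ToyGradientDescent)
```

To run this over JAXopt's own solvers, each solver would need a small wrapper, because JAXopt's `run` returns an `OptStep` whose `.params` field holds the solution. Parametrizing the check over a list of such factories, for example with `pytest.mark.parametrize`, would turn it into the shared test proposed here.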