Open BalzaniEdoardo opened 4 months ago
First of all, thanks for the very nice package.
In this PR I am generalizing jaxopt.prox.prox_lasso so that it handles a scalar l1reg when the function is jit-compiled.
In particular, the following call used to raise an exception:
    import numpy as np
    from jaxopt import prox
    import jax

    rng = np.random.RandomState(0)
    x = (rng.rand(20) * 2 - 1, rng.rand(20) * 2 - 1)
    jax.jit(prox.prox_lasso)(x, 0.5)
Best regards, Edoardo
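For readers who want to see the idea in isolation, here is a rough sketch of one way a lasso prox could accept a scalar regularizer under jit. It is not the diff from this PR and not jaxopt's implementation: `prox_lasso_sketch` and `soft_threshold` are hypothetical names, and the approach (inspecting the pytree structure of `l1reg` instead of its Python type) is an assumption chosen because, inside `jax.jit`, a Python float argument is replaced by a tracer.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util


def soft_threshold(u, v, scaling=1.0):
    # Elementwise soft-thresholding: sign(u) * max(|u| - v * scaling, 0).
    return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)


def prox_lasso_sketch(x, l1reg=1.0, scaling=1.0):
    # Hypothetical helper, not jaxopt's prox_lasso.
    # Assumption: decide how to apply l1reg from its pytree structure,
    # which stays the same whether l1reg is a Python float or a tracer.
    if tree_util.treedef_is_leaf(tree_util.tree_structure(l1reg)):
        # Single leaf (scalar or array): apply it to every leaf of x.
        return tree_util.tree_map(
            lambda u: soft_threshold(u, l1reg, scaling), x
        )
    # Otherwise l1reg is expected to have the same structure as x.
    return tree_util.tree_map(
        lambda u, v: soft_threshold(u, v, scaling), x, l1reg
    )


rng = np.random.RandomState(0)
x = (rng.rand(20) * 2 - 1, rng.rand(20) * 2 - 1)

# Scalar regularizer shared by both leaves, under jit.
out_scalar = jax.jit(prox_lasso_sketch)(x, 0.5)

# Per-leaf regularizers with the same structure as x, under jit.
out_per_leaf = jax.jit(prox_lasso_sketch)(x, (0.1, 0.9))
```

Checking the tree structure avoids any test on the Python type of `l1reg`, so the same branch is taken with and without jit.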