google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

fix lasso with scalar l1reg #604

Open BalzaniEdoardo opened 4 months ago

BalzaniEdoardo commented 4 months ago

First of all, thanks for the very nice package.

In this PR I generalize jaxopt.prox.prox_lasso so that it accepts a scalar l1reg when jit-compiled, including when x is a pytree such as a tuple of arrays.

In particular, the following snippet used to raise an exception:

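```python
import numpy as np
from jaxopt import prox
import jax

rng = np.random.RandomState(0)
# x is a pytree (a tuple of two arrays); l1reg is the scalar 0.5
x = (rng.rand(20) * 2 - 1, rng.rand(20) * 2 - 1)
# Before this PR, this jit-compiled call raised an exception.
jax.jit(prox.prox_lasso)(x, 0.5)
```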
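One plausible way to handle this case is sketched below. The idea is to treat any single pytree leaf passed as l1reg (a Python float, a 0-d array, or the tracer it becomes under jax.jit) as a scalar and broadcast it over every leaf of x before applying soft-thresholding. The function name prox_lasso_sketch is hypothetical, and this is not necessarily the change made in this PR; jaxopt's actual implementation may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util


def prox_lasso_sketch(x, l1reg=None, scaling=1.0):
    """Hypothetical jit-safe soft-thresholding of a pytree x (not jaxopt's code)."""
    if l1reg is None:
        l1reg = 1.0
    # A single leaf (Python float, 0-d array, or a tracer under jit) is
    # broadcast to every leaf of x, so it no longer has to match x's structure.
    if tree_util.treedef_is_leaf(tree_util.tree_structure(l1reg)):
        scalar = l1reg
        l1reg = tree_util.tree_map(lambda y: scalar * jnp.ones_like(y), x)

    def soft_threshold(u, v):
        # sign(u) * max(|u| - v * scaling, 0), applied elementwise
        return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)

    # Pair each leaf of x with the matching leaf of l1reg.
    return tree_util.tree_map(soft_threshold, x, l1reg)


rng = np.random.RandomState(0)
x = (rng.rand(20) * 2 - 1, rng.rand(20) * 2 - 1)
out = jax.jit(prox_lasso_sketch)(x, 0.5)  # scalar l1reg under jit
```

Checking for a pytree leaf rather than a Python scalar is the key design choice here: under jit the scalar arrives as a tracer, which is still a leaf, so the same code path covers both the traced and non-traced calls.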

Best regards, Edoardo