patrickpei opened this issue 3 years ago
The current implementation is just a wrapper around the Numba-based code, so unfortunately it can't be used together with jit.
There is ongoing work on a Numba/JAX bridge, which would make it possible to call Numba code from within a jitted JAX function, but it may take some time to land in JAX master.
CC @josipd
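To make the limitation concrete: under jax.jit the function receives abstract tracers instead of concrete arrays, so anything that needs real NumPy values (which is what a Numba-compiled function consumes) fails at trace time. A minimal sketch, using a hypothetical NumPy-only helper `numpy_only_cumsum` as a stand-in for the Numba kernel:

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_only_cumsum(x):
    # Hypothetical stand-in for a Numba/NumPy kernel: it needs concrete values.
    return np.cumsum(np.asarray(x))


x = jnp.arange(3.0)

# Eager call: x holds concrete values, so converting it to NumPy works.
eager_result = numpy_only_cumsum(x)

# Jitted call: x is an abstract tracer during tracing, so np.asarray raises.
try:
    jax.jit(numpy_only_cumsum)(x)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True

print(eager_result, jit_failed)
```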
Sorry to jump in, but that's actually easier than you might think. See the code below for an example:
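from functools import partial
import numba as nb
from jax import jit, ShapeDtypeStruct
from jax.experimental import host_callback
@nb.jit
def some_numba_stuff(x):
    # Runs on the host, on concrete NumPy values.
    return x
@partial(jit, backend="gpu")  # requires a GPU backend; drop backend=... to run on CPU
def some_jax_stuff(x):
    # Send x to the host, run the Numba function there, and declare the output shape/dtype.
    y = host_callback.call(some_numba_stuff, x, result_shape=ShapeDtypeStruct(x.shape, x.dtype))
    z = 2 * y  # regular JAX ops can consume the callback's result
    return z
print(some_jax_stuff(5.))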
Because host_callback is still experimental, I don't expect @mblondel would want it in his code, but depending on what @patrickpei is up to, that would probably do the trick. Then again, I guess you would have to work from your own fork and modify the NumPy ops wrapper in place. Decisions, decisions :)
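In more recent JAX releases, jax.experimental.host_callback has been deprecated, and jax.pure_callback is the usual way to call host-side NumPy (or Numba) code from inside a jitted function. The sketch below is not jaxopt's implementation: it wraps a plain NumPy pool-adjacent-violators (PAV) solver for the same decreasing L2 problem, argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2. The names `isotonic_l2_pav_numpy` and `isotonic_l2_jit` are hypothetical; a Numba-compiled version of the same loop could be passed to the callback the same way. Note that pure_callback has no differentiation rule of its own, so gradients would need a jax.custom_jvp or jax.custom_vjp wrapper.

```python
import numpy as np
import jax
import jax.numpy as jnp


def isotonic_l2_pav_numpy(y):
    """Hypothetical host-side PAV for argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2."""
    y = np.asarray(y, dtype=np.float64)
    sums, counts = [], []  # running sum and size of each pooled block
    for value in y:
        sums.append(value)
        counts.append(1)
        # Pool while the last two blocks violate the non-increasing constraint.
        while len(sums) > 1 and sums[-2] / counts[-2] < sums[-1] / counts[-1]:
            s, c = sums.pop(), counts.pop()
            sums[-1] += s
            counts[-1] += c
    # Each block takes its mean; expand the means back to the original positions.
    return np.repeat(np.array(sums) / np.array(counts), counts)


@jax.jit
def isotonic_l2_jit(y):
    # pure_callback needs the output shape/dtype up front, since y is a tracer here.
    out_spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    return jax.pure_callback(
        lambda arr: isotonic_l2_pav_numpy(arr).astype(arr.dtype), out_spec, y
    )


y = jnp.array([10.0, 7.0, 6.5, 6.8, 7.0, 8.0, 9.0])
v = isotonic_l2_jit(y)
print(v)
```

Unlike the projected-gradient approach, PAV returns the exact solution in a single pass, at the cost of leaving the device for the host on every call.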
Or... you could just replace the Numba implementation of isotonic regression with a JAX one. It can't use PAV (pool adjacent violators) and the solutions are not exact, but it works:
import jax
import jax.numpy as jnp
import jaxopt as jo
def projection_non_negative_after0(x: jnp.ndarray, hyperparams=None) -> jnp.ndarray:
    # Leave x[0] (the first value v_1) free and clip the decrements x[1:] at zero.
    return x.at[1:].set(jnp.maximum(x[1:], 0.0))
def constrain_param(param):
    # v_1 = param[0] and v_i = v_1 - (param[1] + ... + param[i-1]); non-increasing when param[1:] >= 0.
    return param.at[1:].set(param[0] - jnp.cumsum(param[1:]))
def isotonic_opt_l2(y: jnp.ndarray) -> jnp.ndarray:
    """Solves an isotonic regression problem with L2 loss using projected gradient descent.
    Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2.
    Args:
      y: input to isotonic regression, a 1d-array.
    """
    def loss_fn(param):
        return jnp.sum((y - constrain_param(param)) ** 2) / 2
    solver = jo.ProjectedGradient(
        fun=loss_fn, maxiter=10000, projection=projection_non_negative_after0
    )
    # Start from v_1 = y[0] with all decrements at zero.
    sol = solver.run(y.at[1:].set(0.0)).params
    # Map the parameters back to v (the first value is sol[0], not y[0]).
    return constrain_param(sol)
def isotonic_opt_kl(y: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Solves an isotonic regression problem with KL divergence using projected gradient descent.
    Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{y-v}, 1> + <e^w, v>.
    Args:
      y: input to isotonic optimization, a 1d-array.
      w: input to isotonic optimization, a 1d-array.
    """
    def loss_fn(param):
        constr = constrain_param(param)
        return jnp.sum(jnp.exp(y - constr)) + jnp.sum(jnp.exp(w) * constr)
    solver = jo.ProjectedGradient(
        fun=loss_fn, maxiter=10000, projection=projection_non_negative_after0
    )
    sol = solver.run(y.at[1:].set(0.0)).params
    return constrain_param(sol)
y = jnp.array([10.0, 7.0, 6.5, 6.8, 7.0, 8.0, 9.0])
print(jax.jit(isotonic_opt_l2)(y))
print(jax.jit(isotonic_opt_kl)(y, jnp.ones_like(y)))
You can JIT this without a problem. I'm not sure about the correctness of the KL solution, though...
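One way to check the KL solution is to compare it against an exact solver. PAV also applies to separable convex objectives like this one, provided each pooled block B gets the minimizer of the loss restricted to that block. For the KL objective, setting the derivative of sum_{i in B} (e^{y_i - v} + e^{w_i} v) to zero gives v_B = log sum_{i in B} e^{y_i} - log sum_{i in B} e^{w_i}; for L2 it is simply the block mean. The NumPy sketch below uses hypothetical helper names (`pav_decreasing`, `isotonic_l2_exact`, `isotonic_kl_exact`) and is meant only as a testing reference:

```python
import numpy as np


def pav_decreasing(y, block_value):
    """Reference PAV for argmin_{v_1 >= ... >= v_n} of a separable convex loss.

    block_value(start, end) must return the minimizer of the loss restricted to
    positions start..end-1 when they all share a single value.
    """
    blocks = []  # each entry is [start, end, value]
    for i in range(len(y)):
        blocks.append([i, i + 1, block_value(i, i + 1)])
        # Pool while the last two blocks violate v_prev >= v_last.
        while len(blocks) > 1 and blocks[-2][2] < blocks[-1][2]:
            start = blocks[-2][0]
            end = blocks.pop()[1]
            blocks[-1] = [start, end, block_value(start, end)]
    v = np.empty(len(y))
    for start, end, value in blocks:
        v[start:end] = value
    return v


def isotonic_l2_exact(y):
    y = np.asarray(y, dtype=np.float64)
    return pav_decreasing(y, lambda s, e: y[s:e].mean())


def isotonic_kl_exact(y, w):
    # Block minimizer of sum_i exp(y_i - v) + exp(w_i) * v.
    y = np.asarray(y, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    return pav_decreasing(
        y, lambda s, e: np.logaddexp.reduce(y[s:e]) - np.logaddexp.reduce(w[s:e])
    )


y = np.array([10.0, 7.0, 6.5, 6.8, 7.0, 8.0, 9.0])
print(isotonic_l2_exact(y))
print(isotonic_kl_exact(y, np.ones_like(y)))
```

Comparing these with np.allclose (using a loose tolerance) against the outputs of isotonic_opt_l2(y) and isotonic_opt_kl(y, jnp.ones_like(y)) should show whether the projected-gradient solutions have converged.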
Have you solved this? It just doesn't work within JAX!
Thanks for the work on this! Here's an example: