google-research / fast-soft-sort

Fast Differentiable Sorting and Ranking
Apache License 2.0
562 stars 45 forks

Unable to jit jax ops #9

Open patrickpei opened 3 years ago

patrickpei commented 3 years ago

Thanks for the work on this! Here's an example:

import jax.numpy as jnp
from fast_soft_sort.jax_ops import soft_rank
from jax import grad, jit

@jit
def f1(x):
    x = x.reshape(1, 3)
    y = soft_rank(x)[0]
    return y.mean()

x = jnp.array([1.0, 2.0, 3.0])
f2 = grad(f1)
f2(x)
Traceback (most recent call last):
  File "test.py", line 15, in <module>
    f2(x)
  File "test.py", line 9, in f1
    y = soft_rank(x)[0]
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 80, in soft_rank
    return jnp.vstack([func(val) for val in values])
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 80, in <listcomp>
    return jnp.vstack([func(val) for val in values])
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 35, in _func_fwd
    values = np.array(values)
jax._src.traceback_util.FilteredStackTrace: Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

mblondel commented 3 years ago

The current implementation is just a wrapper around the numba-based code, so unfortunately it can't be used together with jit...

There is ongoing work on a Numba / JAX bridge, which would make it possible to use Numba code from within a jitted JAX function, but it may take some time to land in JAX master.
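
In the meantime, as far as I can tell the wrapper is fine under grad alone (the numpy code then sees concrete values), so keeping soft_rank outside of the jitted part is a possible workaround. A quick sketch, reusing the snippet from the first comment:

import jax.numpy as jnp
from fast_soft_sort.jax_ops import soft_rank
from jax import grad

def f1(x):
    # soft_rank expects a 2D (batch, n) array, hence the reshape.
    return soft_rank(x.reshape(1, 3))[0].mean()

x = jnp.array([1.0, 2.0, 3.0])
print(grad(f1)(x))  # works as long as f1 is not jitted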

CC @josipd

AdrienCorenflos commented 3 years ago

Sorry to jump in, but that's actually easier than you seem to think. See the code below as an example:

from functools import partial
import numba as nb

from jax import jit, ShapeDtypeStruct
from jax.experimental import host_callback

@nb.jit
def some_numba_stuff(x):
    return x

@partial(jit, backend="gpu")
def some_jax_stuff(x):
    # Call out to the host (and hence to numba) from inside a jitted function.
    y = host_callback.call(some_numba_stuff, x, result_shape=ShapeDtypeStruct(x.shape, x.dtype))
    z = 2 * y
    return z

print(some_jax_stuff(5.))

Because host_callback is still experimental, I don't expect @mblondel would want it to live in his code, but depending on what @patrickpei is up to, that would probably do the trick. You would then have to work with your own fork and modify the numpy ops wrapper in place. Decisions, decisions :)
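
One caveat for @patrickpei's snippet: it also takes grad, and host_callback.call on its own has no gradient rule. A sketch of the pattern you would need is to pair a host callback with jax.custom_vjp and route the backward pass through a second callback (jax.pure_callback is the newer, non-experimental counterpart of host_callback.call). Here host_cumsum, numba_cumsum and cumsum_vjp are made-up stand-ins for the numba forward and backward passes, not anything from fast_soft_sort:

import jax
import jax.numpy as jnp
import numba as nb
import numpy as np

@nb.njit
def numba_cumsum(x):
    # Stand-in for the numba-backed forward pass.
    return np.cumsum(x)

def cumsum_vjp(g):
    # Stand-in for the hand-written backward pass (the VJP of cumsum).
    return np.cumsum(g[::-1])[::-1]

@jax.custom_vjp
def host_cumsum(x):
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numba_cumsum, out_shape, x)

def host_cumsum_fwd(x):
    # Only shape/dtype information is needed as residuals for this backward pass.
    return host_cumsum(x), (x.shape, x.dtype)

def host_cumsum_bwd(res, g):
    shape, dtype = res
    gx = jax.pure_callback(cumsum_vjp, jax.ShapeDtypeStruct(shape, dtype), g)
    return (gx,)

host_cumsum.defvjp(host_cumsum_fwd, host_cumsum_bwd)

x = jnp.arange(4.0)
print(jax.jit(host_cumsum)(x))                      # jit is fine
print(jax.grad(lambda v: host_cumsum(v).sum())(x))  # and so is grad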

ingmarschuster commented 1 year ago

Or... you just replace the Numba implementation of isotonic regression with a JAX one. You can't use PAV, and the solutions are not perfectly accurate, but it works:

import jax.numpy as jnp
import jaxopt as jo

def projection_non_negative_after0(x: jnp.ndarray, hyperparams=None) -> jnp.ndarray:
    # Keep x[0] free and project the remaining entries onto the non-negative orthant.
    return x.at[1:].set(jnp.where(x[1:] < 0.0, 0.0, x[1:]))

def constrain_param(param):
    # Reparameterize a non-increasing sequence: v_1 = param[0], and each later entry
    # subtracts the cumulative sum of the non-negative decrements param[1:].
    return param.at[1:].set(param[0] - jnp.cumsum(param[1:]))

def isotonic_opt_l2(y: jnp.ndarray) -> jnp.ndarray:
    """Solves an isotonic regression problem with L2 loss using projected gradient descent.

    Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2.

    Args:
      y: input to isotonic regression, a 1d-array.
    """

    def loss_fn(param):
        return jnp.sum((y - constrain_param(param)) ** 2) / 2

    solver = jo.ProjectedGradient(
        fun=loss_fn, maxiter=10000, projection=projection_non_negative_after0
    )
    sol = solver.run(y.at[1:].set(0.0)).params
    # Map the optimized parameters back to the monotone sequence.
    return constrain_param(sol)

def isotonic_opt_kl(y: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Solves an isotonic regression problem with KL divergence using projected gradient descent.

    Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{y-v}, 1> + <e^w, v>.

    Args:
      y: input to isotonic optimization, a 1d-array.
      w: input to isotonic optimization, a 1d-array.
    """

    def loss_fn(param):
        constr = constrain_param(param)
        return jnp.sum(jnp.exp(y - constr)) + jnp.sum(jnp.exp(w) * constr)

    solver = jo.ProjectedGradient(
        fun=loss_fn, maxiter=10000, projection=projection_non_negative_after0
    )
    sol = solver.run(y.at[1:].set(0.0)).params
    # Map the optimized parameters back to the monotone sequence.
    return constrain_param(sol)

y = jnp.array([10.0, 7.0, 6.5, 6.8, 7.0, 8.0, 9.0])
print(isotonic_opt_l2(y))
print(isotonic_opt_kl(y, jnp.ones_like(y)))

You can JIT this without a problem. I'm not sure about the correctness of the KL solution, though...
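
For example, jitting the two solvers directly (jaxopt solvers are designed to be jit-compatible):

from jax import jit

print(jit(isotonic_opt_l2)(y))
print(jit(isotonic_opt_kl)(y, jnp.ones_like(y)))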

marianogabitto commented 10 months ago

Have you solved this? It just does not work within JAX!