jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Wrong result when Jitting a function #23866

Closed: mtagliazucchi closed this issue 1 hour ago

mtagliazucchi commented 1 day ago

Description

Hi everyone,

I noticed some very strange behaviour in a function when it is jitted.

Here's the code:

import jax
import jax.numpy as jnp

def get_neff(weights, mu, Ndraw=None):
    nsamples = weights.shape[-1]
    if Ndraw is None:
        Ndraw = nsamples
    s2 = jnp.sum(weights**2, axis=-1) / Ndraw**2
    sig2 = s2 - mu**2 / Ndraw

    return jax.lax.cond(sig2 == 0.0, lambda _: 0.0, lambda _: mu**2 / sig2, operand=None)

weights = jnp.ones(5000)
mu = jnp.mean(weights)

print(get_neff(weights, mu))           # 0.0
print(jax.jit(get_neff)(weights, mu))  # -1.8519084246547628e+20

The non-jitted function prints the correct result (0.0), while the jitted function appears to take the second branch of lax.cond. The problem goes away if I change the condition of lax.cond to sig2 <= 0.0, but I'm not sure this condition holds in general.
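A minimal plain-Python illustration (ordinary IEEE-754 double-precision floats, not the jitted code above) of why `sig2 <= 0.0` may not hold in general: two quantities that are algebraically equal can differ by a rounding residue, and that residue can be positive as well as negative. A positive residue would pass a `<= 0.0` guard and still send the computation down the `mu**2 / sig2` branch. The values `a` and `b` here are illustrative only.

```python
# Plain Python floats (IEEE-754 double precision); no JAX involved.
a = 0.1 * 3  # algebraically 0.3, but rounds to 0.30000000000000004
b = 0.3
residue = a - b  # would be exactly 0.0 in exact arithmetic

print(residue)         # 5.551115123125783e-17, i.e. 2**-54 (positive)
print(residue <= 0.0)  # False: a `<= 0.0` guard would not catch this
print(1.0 / residue)   # 1.8014398509481984e+16, i.e. 2**54, instead of 0.0
```

Which sign the residue takes in the jitted `get_neff` depends on how the compiler orders and fuses the arithmetic, so a one-sided check only covers part of the cases.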

Can someone explain why this occurs and how to solve it? Note that I need to call get_neff from within another jitted function.

System info (python version, jaxlib version, accelerator, etc.)

System info:

jax:    0.3.25
jaxlib: 0.3.25
numpy:  1.26.3
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

superbobry commented 1 day ago

jit doesn't guarantee that the intermediates and the outputs will be exactly equal to the ones without jit. I think you might need to use jnp.abs(sig2) <= eps, for some small tolerance eps, to be robust to that.
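A minimal sketch of the tolerance check suggested above, assuming a relative tolerance is acceptable for this use case. The name `get_neff_tol` and the `rtol` parameter (default `1e-5`) are hypothetical choices, not part of JAX or the original code. Scaling the threshold by `s2`, the size of the terms being subtracted, means it follows the magnitude of the data rather than relying on a fixed absolute `eps`.

```python
import jax
import jax.numpy as jnp


def get_neff_tol(weights, mu, Ndraw=None, rtol=1e-5):
    # Hypothetical variant of get_neff: treat sig2 as zero when it is within a
    # relative tolerance of s2 (rtol is an assumed default, tune as needed),
    # instead of testing for exact equality with 0.0.
    nsamples = weights.shape[-1]
    if Ndraw is None:
        Ndraw = nsamples
    s2 = jnp.sum(weights**2, axis=-1) / Ndraw**2
    sig2 = s2 - mu**2 / Ndraw

    # sig2 is a difference of two terms that can cancel; rounding (which may
    # differ with and without jit) can leave a tiny residue of either sign.
    is_zero = jnp.abs(sig2) <= rtol * jnp.abs(s2)

    return jax.lax.cond(
        is_zero,
        lambda m, s: jnp.zeros_like(s),  # degenerate case: report 0.0
        lambda m, s: m**2 / s,           # regular case: mu**2 / sig2
        mu,
        sig2,
    )


weights = jnp.ones(5000)
mu = jnp.mean(weights)
print(get_neff_tol(weights, mu))           # 0.0
print(jax.jit(get_neff_tol)(weights, mu))  # 0.0
```

Since the function only uses array operations and jax.lax.cond, it can also be called from inside another jitted function. The trade-off of any tolerance is that a genuinely tiny but nonzero sig2 below `rtol * s2` is also reported as 0.0.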

mtagliazucchi commented 1 hour ago

Thanks! This is helpful.