eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

ValueError when using @jax.jit decorator for objective function with pmwd #28

Closed: pointeee closed this issue 5 months ago

pointeee commented 6 months ago

Hi, I am writing a sampling code using pmwd. The objective function in my code (actually a negative log probability) is called frequently, so I tried adding @jax.jit to speed it up. However, I always get an error like ValueError: invalid literal for int() with base 10: 'int16[8,3]' when executing the nbody function. Here is an illustrative example:

from pmwd import (
    Configuration, Cosmology,
    SimpleLCDM,
    boltzmann,
    white_noise, linear_modes,
    lpt,
    nbody,
    scatter,
) 

import jax
import jax.numpy as jnp
from jax import jit

def model(modes, cosmo, conf):
    modes = linear_modes(modes, cosmo, conf, None, False)
    ptcl, obsvbl = lpt(modes, cosmo, conf)
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
    dens = scatter(ptcl, conf)
    return dens

# @jit  # uncommenting this decorator triggers the error
def obj(modes, cosmo, conf):
    dens = model(modes, cosmo, conf)
    return jnp.sum(dens)

conf = Configuration(ptcl_spacing=4, ptcl_grid_shape=(2,)*3, lpt_order=1, float_dtype=jnp.float64,
                     a_start=0.01, a_nbody_maxstep=1)

p0 = jnp.array([0.3, 0.8])
cosmo = Cosmology.from_sigma8(conf, Omega_m=0.3, sigma8=0.8, n_s=0.96, Omega_b=0.05, h=0.7)
cosmo = boltzmann(cosmo, conf)

obj_func = lambda z: obj(z, cosmo, conf)

vng = jax.value_and_grad(obj_func, argnums=(0))
data_ = jnp.array([[[1,1], [1,1]], [[1, 1], [1, 1]]], dtype=jnp.float64)

vng(data_)

With @jax.jit enabled, the script above fails with the following error:

(jax-test) [user@cluster ~]$ python jit_issue.py 
Traceback (most recent call last):
  File "/home/user/jit_issue.py", line 39, in <module>
    vng(data_)
  File "/home/user/jit_issue.py", line 34, in <lambda>
    obj_func = lambda z: obj(z, cosmo, conf)
  File "/home/user/jit_issue.py", line 24, in obj
    dens = model(modes, cosmo, conf)
  File "/home/user/jit_issue.py", line 18, in model
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: invalid literal for int() with base 10: 'int16[8,3]'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/jit_issue.py", line 39, in <module>
    vng(data_)
  File "/home/user/jit_issue.py", line 34, in <lambda>
    obj_func = lambda z: obj(z, cosmo, conf)
  File "/home/user/work/pmwd/pmwd/tree_util.py", line 97, in tree_unflatten
    return cls(**dict(zip(children_names, children)),
  File "<string>", line 9, in __init__
  File "/home/user/work/pmwd/pmwd/particles.py", line 72, in __post_init__
    else jnp.asarray(value, dtype=dtype))
  File "/home/user/.conda/envs/jax-test/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2206, in asarray
    return array(a, dtype=dtype, copy=False, order=order)  # type: ignore
  File "/home/user/.conda/envs/jax-test/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2152, in array
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
ValueError: invalid literal for int() with base 10: 'int16[8,3]'

I am using jax-0.4.23 [cpu] and pmwd-0.1.dev122+g1e1c634.d20240303 (GitHub commit 1e1c634). Is this expected behavior?


Update (Mar 6): I ran the script above in an environment with the GPU build of JAX installed, and the error persists, so I don't think the problem is specific to the CPU version.
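
The traceback above notes that JAX removed its internal frames and that setting JAX_TRACEBACK_FILTERING=off restores them. A minimal sketch of doing that from inside the script rather than from the shell, assuming the variable is set before JAX is first imported:

```python
import os

# Set the variable before importing JAX so it is read at start-up.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import jax  # noqa: E402  (import deliberately placed after the env setup)
import jax.numpy as jnp  # noqa: E402

# Any exception raised inside a transformed function from here on should
# include JAX's internal frames in its traceback.
filtering = os.environ["JAX_TRACEBACK_FILTERING"]
```

The same effect can be had by prefixing the command line, e.g. JAX_TRACEBACK_FILTERING=off python jit_issue.py.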

eelregit commented 5 months ago

Hi Chenze, could you try ptcl_spacing = 1.0 instead and see if the error persists?

pointeee commented 5 months ago

Yes, with ptcl_spacing = 1.0 it reports the same error: ValueError: invalid literal for int() with base 10: 'int16[8,3]'

eelregit commented 5 months ago

I don't know why that happens and the error message is vague.

In the current version you don't need to jit it yourself, because most of the work is already jitted inside pmwd. This may change in the future, though.
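
As a rough illustration of that point, the sketch below differentiates an un-jitted objective that only glues together steps which are themselves jitted. The functions toy_evolve and toy_density are hypothetical stand-ins, not pmwd functions; they are assumed here only to show the pattern.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_evolve(modes):
    # Hypothetical stand-in for an already-jitted library step.
    return jnp.tanh(modes) * 2.0


@jax.jit
def toy_density(field):
    # Another hypothetical, already-jitted step.
    return field ** 2


def objective(modes):
    # Not jitted itself: thin Python glue around the compiled steps.
    return jnp.sum(toy_density(toy_evolve(modes)))


# value_and_grad works on the un-jitted wrapper; the expensive parts
# still run as compiled code because the inner functions are jitted.
value_and_grad = jax.value_and_grad(objective)
modes = jnp.ones((2, 2, 2))
value, grad = value_and_grad(modes)
```

In a setup like this the outer function adds only a small Python overhead per call, so wrapping it in another jit is optional rather than required.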

pointeee commented 5 months ago

Thanks for the reply. I believe it is related to how Particles.__post_init__ in particles.py deals with non-tracing arrays when getting jitted. But since jit is optional here, I think we can close the issue.

eelregit commented 5 months ago

> I believe it is related to how Particles.__post_init__ in particles.py deals with non-tracing arrays when getting jitted.

FYI, Particles is instantiated in lpt(), which is already jitted.
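
For context on the __post_init__ hypothesis: the traceback shows tree_unflatten calling the class constructor, which then runs jnp.asarray on each child. JAX's documentation on custom pytrees warns that transformations may rebuild a pytree around leaves that are not arrays (tracers, placeholder objects, or other pytrees), so coercion in a constructor can fail at that point. The sketch below shows one commonly suggested workaround, bypassing the constructor during unflattening. ToyParticles is a hypothetical class written for this illustration; it is not pmwd's Particles, and whether this is what actually goes wrong in pmwd is not established in this thread.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class ToyParticles:
    """Hypothetical stand-in for a particle container (not pmwd's Particles)."""

    def __init__(self, pmid, disp):
        # Coercion like this is convenient for users, but it assumes
        # that every argument really is array-like.
        self.pmid = jnp.asarray(pmid, dtype=jnp.int16)
        self.disp = jnp.asarray(disp, dtype=jnp.float32)

    def tree_flatten(self):
        return (self.pmid, self.disp), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so JAX can rebuild the container around tracers
        # or placeholder leaves without triggering the coercion above.
        obj = object.__new__(cls)
        obj.pmid, obj.disp = children
        return obj


@jax.jit
def total_disp(ptcl):
    return jnp.sum(ptcl.disp)


ptcl = ToyParticles(jnp.zeros((8, 3)), jnp.ones((8, 3)))

# Rebuilding around arbitrary placeholder leaves no longer calls jnp.asarray.
leaves, treedef = tree_flatten(ptcl)
placeholder = tree_unflatten(treedef, [object(), object()])

# The container can also be passed through an outer jit.
result = total_disp(ptcl)
```

An alternative with the same intent is to keep the constructor call in tree_unflatten but skip the coercion whenever a child is not array-like.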