eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

nbody vjp broken at jitting time #34

Closed anvvalade closed 3 days ago

anvvalade commented 4 weeks ago

Cannot evaluate jax.vjp on the jitted version of nbody, getting

ValueError: invalid literal for int() with base 10: 'int16[64,3]'

Here is a minimal example to reproduce the error:

import jax
import pmwd

conf = pmwd.Configuration(
    ptcl_spacing=1,
    ptcl_grid_shape=(4,) * 3,
    mesh_shape=2,
    a_start=1 / 20,
    a_nbody_maxstep=(1 - 1 / 20) / 5,
    cosmo_dtype="float32",
    float_dtype="float32",
    pmid_dtype="int16",
)

cosmo = pmwd.SimpleLCDM(conf)
cosmo = pmwd.boltzmann(cosmo, conf)

modes = pmwd.linear_modes(
    pmwd.white_noise(0, conf),
    cosmo,
    conf,
    None,
    False,
)

def run(modes, cosmo, conf):
    return pmwd.nbody(*pmwd.lpt(modes, cosmo, conf), cosmo, conf)[0]

jitted_run = jax.jit(run)

print("VJP not jitted")
jax.vjp(run, modes, cosmo, conf)
print("VJP jitted")
jax.vjp(jitted_run, modes, cosmo, conf)

Running on python 3.9, with:

pmwd                                  0.1.dev124+gca91e43.d20241028
jax                                   0.4.30
jaxlib                                0.4.30

Full error output:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/.conda/envs/py39/bin/ipython:8
      7 sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
----> 8 sys.exit(start_ipython())

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/__init__.py:129, in start_ipython()
    128 from IPython.terminal.ipapp import launch_new_instance
--> 129 return launch_new_instance(argv=argv, **kwargs)

File ~/.conda/envs/py39/lib/python3.9/site-packages/traitlets/config/application.py:1077, in launch_instance()
   1076 app.initialize(argv)
-> 1077 app.start()

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/terminal/ipapp.py:317, in start()
    316     self.log.debug("Starting IPython's mainloop...")
--> 317     self.shell.mainloop()
    318 else:

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/terminal/interactiveshell.py:887, in mainloop()
    886 try:
--> 887     self.interact()
    888     break

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/terminal/interactiveshell.py:880, in interact()
    879 if code:
--> 880     self.run_cell(code, store_history=True)

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3048, in run_cell()
   3047 try:
-> 3048     result = self._run_cell(
   3049         raw_cell, store_history, silent, shell_futures, cell_id
   3050     )
   3051 finally:

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3103, in _run_cell()
   3102 try:
-> 3103     result = runner(coro)
   3104 except BaseException as e:

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3308, in run_cell_async()
   3305 interactivity = "none" if silent else self.ast_node_interactivity
-> 3308 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3309        interactivity=interactivity, compiler=compiler, result=result)
   3311 self.last_execution_succeeded = not has_raised

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3490, in run_ast_nodes()
   3489     asy = compare(code)
-> 3490 if await self.run_code(code, result, async_=asy):
   3491     return True

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3550, in run_code()
   3549     else:
-> 3550         exec(code_obj, self.user_global_ns, self.user_ns)
   3551 finally:
   3552     # Reset our crash handler in place

Cell In[4], line 1
----> 1 get_ipython().run_line_magic('run', 'crash_nbody.py')

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2456, in run_line_magic()
   2455 with self.builtin_trap:
-> 2456     result = fn(*args, **kwargs)
   2458 # The code below prevents the output from being displayed
   2459 # when using magics with decorator @output_can_be_silenced
   2460 # when the last Python token in the expression is a ';'.

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/magics/execution.py:849, in run()
    847         else:
    848             # regular execution
--> 849             run()
    851 if 'i' in opts:

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/magics/execution.py:834, in run()
    833 def run():
--> 834     runner(filename, prog_ns, prog_ns,
    835             exit_ignore=exit_ignore)

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2905, in safe_execfile()
   2904     glob, loc = (where + (None, ))[:2]
-> 2905     py3compat.execfile(
   2906         fname, glob, loc,
   2907         self.compile if shell_futures else None)
   2908 except SystemExit as status:
   2909     # If the call was made with 0 or None exit status (sys.exit(0)
   2910     # or sys.exit() ), don't bother showing a traceback, as both of
   (...)
   2916     # For other exit status, we show the exception unless
   2917     # explicitly silenced, but only in short form.

File ~/.conda/envs/py39/lib/python3.9/site-packages/IPython/utils/py3compat.py:55, in execfile()
     54 compiler = compiler or compile
---> 55 exec(compiler(f.read(), fname, "exec"), glob, loc)

File ~/work/code/scripts/tests/crash_nbody.py:38
     37 print("VJP jitted")
---> 38 jax.vjp(jitted_run, modes, cosmo, conf)

File ~/work/code/scripts/tests/crash_nbody.py:29, in run()
     28 def run(modes, cosmo, conf):
---> 29     return pmwd.nbody(*pmwd.lpt(modes, cosmo, conf), cosmo, conf)[0]

JaxStackTraceBeforeTransformation: ValueError: invalid literal for int() with base 10: 'int16[64,3]'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
File ~/work/code/scripts/tests/crash_nbody.py:38
     36 jax.vjp(run, modes, cosmo, conf)
     37 print("VJP jitted")
---> 38 jax.vjp(jitted_run, modes, cosmo, conf)

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/api.py:2167, in vjp(***failed resolving arguments***)
   2165 del reduce_axes
   2166 check_callable(fun)
-> 2167 return _vjp(
   2168     lu.wrap_init(fun), *primals, has_aux=has_aux)

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/api.py:2176, in _vjp(fun, has_aux, *primals)
   2174 if not has_aux:
   2175   flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 2176   out_primals, vjp = ad.vjp(flat_fun, primals_flat)
   2177   out_tree = out_tree()
   2178 else:

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:143, in vjp(traceable, primals, has_aux)
    141 def vjp(traceable, primals, has_aux=False):
    142   if not has_aux:
--> 143     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    144   else:
    145     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:132, in linearize(traceable, *primals, **kwargs)
    130 _, in_tree = tree_flatten(((primals, primals), {}))
    131 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 132 jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
    133 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
    134 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/profiler.py:335, in annotate_function.<locals>.wrapper(*args, **kwargs)
    332 @wraps(func)
    333 def wrapper(*args, **kwargs):
    334   with TraceAnnotation(name, **decorator_kwargs):
--> 335     return func(*args, **kwargs)
    336   return wrapper

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py:777, in trace_to_jaxpr_nounits(fun, pvals, instantiate)
    775 with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
    776   fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 777   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    778   assert not env
    779   del main, fun, env

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/linear_util.py:192, in WrappedFun.call_wrapped(self, *args, **kwargs)
    189 gen = gen_static_args = out_store = None
    191 try:
--> 192   ans = self.f(*args, **dict(self.params, **kwargs))
    193 except:
    194   # Some transformations yield from inside context managers, so we have to
    195   # interrupt them before reraising the exception. Otherwise they will only
    196   # get garbage-collected at some later time, running their cleanup tasks
    197   # only after this exception is handled, which can corrupt the global
    198   # state.
    199   while stack:

    [... skipping hidden 25 frame]

File ~/.conda/envs/py39/lib/python3.9/site-packages/pmwd/tree_util.py:100, in pytree_dataclass.<locals>.tree_unflatten(aux_data, children)
     98 def tree_unflatten(aux_data, children):
     99     print("TREE UNFLATTEN\n", f"{aux_data=}", "\n", f"{children=}")
--> 100     return cls(
    101         **dict(zip(children_names, children)), **dict(zip(aux_data_names, aux_data))
    102     )

File <string>:9, in __init__(self, conf, pmid, disp, vel, acc, attr)

File ~/.conda/envs/py39/lib/python3.9/site-packages/pmwd/particles.py:75, in Particles.__post_init__(self)
     70     value = tree_map(lambda x: jnp.asarray(x, dtype=dtype), value)
     71 else:
     72     value = (
     73         value
     74         if value is None or is_float0_array(value)
---> 75         else jnp.asarray(value, dtype=dtype)
     76     )
     77 object.__setattr__(self, name, value)

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:3289, in asarray(a, dtype, order, copy)
   3287 if dtype is not None:
   3288   dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 3289 return array(a, dtype=dtype, copy=bool(copy), order=order)

File ~/.conda/envs/py39/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:3198, in array(object, dtype, copy, order, ndmin)
   3191 out: ArrayLike
   3193 if all(not isinstance(leaf, Array) for leaf in leaves):
   3194   # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   3195   # containing large integers; see discussion in
   3196   # https://github.com/google/jax/pull/6047. More correct would be to call
   3197   # coerce_to_array on each leaf, but this may have performance implications.
-> 3198   out = np.asarray(object, dtype=dtype)
   3199 elif isinstance(object, Array):
   3200   assert object.aval is not None

ValueError: invalid literal for int() with base 10: 'int16[64,3]'

Some debug printing from Particles.__post_init__ and tree_util.tree_unflatten, showing the last pass before the crash and the pass that crashes:

TREE UNFLATTEN
 aux_data=(Configuration(ptcl_spacing=1, ptcl_grid_shape=(4, 4, 4), mesh_shape=(8, 8, 8), cosmo_dtype=dtype('float32'), pmid_dtype=dtype('int16'), float_dtype=dtype('float32'), k_pivot_Mpc=0.05, T_cmb=2.7255, M=1.98847e+40, L=3.0856775815e+22, T=3.0856775815e+17, transfer_fit=True, transfer_fit_nowiggle=False, transfer_lgk_min=-4, transfer_lgk_max=3, transfer_lgk_maxstep=0.0078125, growth_rtol=0.00034526698300124393, growth_atol=0.00034526698300124393, growth_inistep=(1, None), lpt_order=2, a_start=0.05, a_stop=1, a_lpt_maxstep=0.0078125, a_nbody_maxstep=0.19, symp_splits=((0, 0.5), (1, 0.5)), chunk_size=16777216),)
 children=(Traced<ShapedArray(int16[64,3])>with<DynamicJaxprTrace(level=3/0)>, Traced<ShapedArray(float32[64,3])>with<DynamicJaxprTrace(level=3/0)>, Traced<ShapedArray(float32[64,3])>with<DynamicJaxprTrace(level=3/0)>, Traced<ShapedArray(float32[64,3])>with<DynamicJaxprTrace(level=3/0)>, None)
__post__init__
 name='pmid'
 value=Traced<ShapedArray(int16[64,3])>with<DynamicJaxprTrace(level=3/0)>
 dtype=dtype('int16')
__post__init__
 name='disp'
 value=Traced<ShapedArray(float32[64,3])>with<DynamicJaxprTrace(level=3/0)>
 dtype=dtype('float32')
__post__init__
 name='vel'
 value=Traced<ShapedArray(float32[64,3])>with<DynamicJaxprTrace(level=3/0)>
 dtype=dtype('float32')
__post__init__
 name='acc'
 value=Traced<ShapedArray(float32[64,3])>with<DynamicJaxprTrace(level=3/0)>
 dtype=dtype('float32')
__post__init__
 name='attr'
 value=None
 dtype=dtype('float32')
TREE UNFLATTEN
 aux_data=(Configuration(ptcl_spacing=1, ptcl_grid_shape=(4, 4, 4), mesh_shape=(8, 8, 8), cosmo_dtype=dtype('float32'), pmid_dtype=dtype('int16'), float_dtype=dtype('float32'), k_pivot_Mpc=0.05, T_cmb=2.7255, M=1.98847e+40, L=3.0856775815e+22, T=3.0856775815e+17, transfer_fit=True, transfer_fit_nowiggle=False, transfer_lgk_min=-4, transfer_lgk_max=3, transfer_lgk_maxstep=0.0078125, growth_rtol=0.00034526698300124393, growth_atol=0.00034526698300124393, growth_inistep=(1, None), lpt_order=2, a_start=0.05, a_stop=1, a_lpt_maxstep=0.0078125, a_nbody_maxstep=0.19, symp_splits=((0, 0.5), (1, 0.5)), chunk_size=16777216),)
 children=('int16[64,3]', 'float32[64,3]', 'float32[64,3]', 'float32[64,3]', None)
__post__init__
 name='pmid'
 value='int16[64,3]'                   # SHOULD BE ARRAY
 dtype=dtype('int16')
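
The final ValueError can be reproduced without pmwd or JAX. The sketch below assumes only what the debug output above shows: the child passed in is the string 'int16[64,3]' rather than an array, and it ends up in np.asarray with an integer dtype. The variable names are illustrative.

```python
import numpy as np

# The string that appeared in children= in the crashing pass above.
placeholder = "int16[64,3]"

try:
    # Same kind of call that jnp.asarray falls back to for non-array input.
    np.asarray(placeholder, dtype=np.int16)
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

NumPy tries to parse the string as an integer, which is where the "invalid literal for int()" wording comes from.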

eelregit commented 4 weeks ago

Interesting. @modichirag told me about the same error message on Slack. Thanks for the example. I wonder if this is another problem that comes and goes with JAX updates. I am busy with job stuff right now, but will take a look in 2 weeks at the latest.

anvvalade commented 4 weeks ago

Thanks for the quick answer. May your job stuff go smoothly, so I don't spend too long in front of my computer frantically pressing F5! ;)

anvvalade commented 1 week ago

@eelregit, have you had a chance to look at the issue? I have deadlines coming up and would love to present work based on your code (it'd be a great advertisement for us all ;) )

eelregit commented 1 week ago

Ok, I should have taken a closer look 2 weeks ago. Two things:

  1. Current master uses a for loop for the nbody time integration, with manually jitted components, so the benefit of an outer jit is small and you can probably do without it.
  2. The usual order is jit(grad()), not the other way around, grad(jit()).

So could you simply try removing the jit in your current blackjax pipeline?
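
As a minimal sketch of the two orderings mentioned in point 2, the snippet below uses a toy function as a stand-in for the simulation; `loss` and `x` are hypothetical names, not part of pmwd. For a plain function like this both orders give the same gradient; the difference is only in what gets compiled.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Toy stand-in for a forward model followed by a scalar reduction.
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.linspace(0.0, 1.0, 8)

# grad(jit(f)): differentiate through an already-jitted function.
grad_of_jit = jax.grad(jax.jit(loss))(x)

# jit(grad(f)): build the gradient function first, then compile all of it.
jit_of_grad = jax.jit(jax.grad(loss))(x)
```

With jit on the outside, the forward and backward passes are traced and compiled together, which is the pattern suggested above.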

anvvalade commented 6 days ago

No problem, you couldn't know!

  1. This could very easily be replaced by a lax.scan function, right? Would that solve the issue?
  2. Indeed that seems to work

I was trying to use tensorflow-probability.mcmc.nuts but it seems that using the NUTS of blackjax solves the issue. I'll try further.

As to removing the jit, it seems like a deal breaker on my end: I need the jitting inside and outside of the posterior pdf, as it takes too long without. I could try to jit parts manually, but that would probably mean tampering with the Monte Carlo libraries if I want the integration to be jitted too... Bad idea!

Do I understand correctly that, besides these workarounds, you do not know exactly what causes the bug?

eelregit commented 6 days ago

Will probably use while_loop from lax in the future; scan won't work for some features.
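
For illustration, here is a minimal sketch of a Python for loop versus lax.scan for a fixed number of time steps. The `step` function is a toy kick-drift update for a harmonic oscillator, not pmwd's integrator, and all names are hypothetical. Note that lax.while_loop on its own does not support reverse-mode differentiation in JAX, so a library that uses it would have to supply its own backward rule.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(state, dt):
    # Toy kick-drift update (hypothetical, not pmwd's scheme).
    x, v = state
    v = v - x * dt
    x = x + v * dt
    return (x, v)


def integrate_for(x0, dts):
    # Python loop: unrolled step by step when traced.
    state = (x0, jnp.zeros_like(x0))
    for dt in dts:
        state = step(state, dt)
    return state[0]


def integrate_scan(x0, dts):
    # lax.scan: one compiled loop body, reverse-mode differentiable.
    def body(state, dt):
        return step(state, dt), None

    state, _ = lax.scan(body, (x0, jnp.zeros_like(x0)), dts)
    return state[0]


x0 = jnp.ones(4)
dts = jnp.full(5, 0.1)

g_for = jax.grad(lambda x: integrate_for(x, dts).sum())(x0)
g_scan = jax.jit(jax.grad(lambda x: integrate_scan(x, dts).sum()))(x0)
```

scan needs the number of steps to be known at trace time, which is one reason it may not fit every feature.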

I don't think the problem is from pmwd, because 'int16[64,3]' is a jax repr. I think they didn't expect the grad(jit()) order and don't have a test for it. In general I cannot think of a use case like that, so maybe everything is fine now, right?
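
One possible way to make a pytree dataclass tolerate such placeholder children is sketched below. This is an assumption, not a fix confirmed in this thread: the JAX documentation on custom pytrees notes that constructors may be called with non-array leaves, so the sketch converts only real arrays and leaves anything else untouched. `Ptcl` is a hypothetical stand-in, not pmwd's Particles class.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Ptcl:
    # Hypothetical container with a single integer-array field.
    pmid: object

    def __post_init__(self):
        value = self.pmid
        # Convert only array inputs (tracers also pass the jax.Array check);
        # placeholders such as strings or None are stored as they are.
        if isinstance(value, (jax.Array, np.ndarray)):
            object.__setattr__(self, "pmid", jnp.asarray(value, dtype="int16"))

    def tree_flatten(self):
        return (self.pmid,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


p = Ptcl(np.arange(3))
leaves, treedef = tree_flatten(p)

# Unflattening with a string leaf no longer reaches jnp.asarray.
q = tree_unflatten(treedef, ["int16[3]"])

# Normal use under jit is unchanged.
r = jax.jit(lambda ptcl: ptcl)(p)
```

The trade-off is that invalid inputs are no longer rejected at construction time, so any validation would have to live elsewhere.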

eelregit commented 6 days ago

BTW I'm surprised that HMC computations themselves are enough to dominate without jitting. They shouldn't be that heavy, right?

anvvalade commented 3 days ago

Sorry for the delay; I needed to run some tests before coming back to you. It turns out you were right (thanks): the issue was not with TensorFlow, but with a jit decorator I had somewhere, which I just needed to remove to avoid the crash. Thanks a lot for that. I am now running into other, funnier crashes.

You are also probably right that the HMC computations themselves are not too heavy, especially when nbody is involved. But it would mean I'd also have to manually jit every other function of the code, and there are quite a few!

I'll mark this issue as solved! Thanks.

eelregit commented 3 days ago

Thanks for letting me know. Do you still have crashes if you do jit(grad(posterior(...)))?

And happy to take a look if you can elaborate (or share code about) which jit caused the problem and what the funnier crashes are.

BTW, there was an Omega_m gradient problem that came with a JAX update last year and went away with another one sometime this year. I never figured out why, but you can use the latest version of JAX to avoid it.