jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

KeyError: dtype([('float0', 'V')]) raised when passing PRNG keys to odeint function #13369

Open A-Alaa opened 1 year ago

A-Alaa commented 1 year ago

Description

I am trying to pass PRNG keys to a function that is integrated by odeint. Here is a simplified example that reproduces the problem:

```python
from functools import partial
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.experimental.ode import odeint
import jax

W = jnp.eye(10, dtype=jnp.float64)
b = jnp.ones((10,), dtype=jnp.float64)
x = jnp.linspace(jnp.zeros((10,)), 5 * jnp.ones((10,)), 10)
t = jnp.linspace(0, 10, 10)

def sample_W(W, prng_key):
    return jax.random.normal(prng_key, (10,)) + W

def f(params, x, t, prng_key):
    W, b = params
    W_s = sample_W(W, prng_key)
    return x @ W_s + b

def loss(params, prng_key, x, t):
    x_hat = odeint(partial(f, params), x[0], t, prng_key)
    return jnp.mean((x_hat - x[-1])**2)

@jax.jit
def update(params, prng_key, x, t, eta=1e-3):
    grads = jax.grad(loss)(params, prng_key, x, t)
    return tree_map(lambda p,g: p - eta * g,  params, grads)

prng_key = jax.random.PRNGKey(0)
params = (W, b)
for i in range(100):
    (prng_key,) = jax.random.split(prng_key, 1)
    params = update(params, prng_key, x, t, 1e-3)
```

Which raises the following exception:

KeyError: dtype([('float0', 'V')])

```
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
[... runpy, ipykernel, traitlets, tornado, asyncio and IPython frames omitted ...]

Cell In [1], line 35
     34     (prng_key,) = jax.random.split(prng_key, 1)
---> 35     params = update(params, prng_key, x, t, 1e-3)

Cell In [1], line 28, in update(***failed resolving arguments***)
     26 @jax.jit
     27 def update(params, prng_key, x, t, eta=1e-3):
---> 28     grads = jax.grad(loss)(params, prng_key, x, t)
     29     return tree_map(lambda p,g: p - eta * g,  params, grads)

Cell In [1], line 23, in loss(***failed resolving arguments***)
     22 def loss(params, prng_key, x, t):
---> 23     x_hat = odeint(partial(f, params), x[0], t, prng_key)
     24     return jnp.mean((x_hat - x[-1])**2)

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:179, in odeint(***failed resolving arguments***)
    178 converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 179 return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:185, in _odeint_wrapper(***failed resolving arguments***)
    184 func = ravel_first_arg(func, unravel)
--> 185 out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
    186 return jax.vmap(unravel)(out)

JaxStackTraceBeforeTransformation: KeyError: dtype([('float0', 'V')])

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In [1], line 35
     33 for i in range(100):
     34     (prng_key,) = jax.random.split(prng_key, 1)
---> 35     params = update(params, prng_key, x, t, 1e-3)

    [... skipping hidden 11 frame]

Cell In [1], line 28, in update(params, prng_key, x, t, eta)
     26 @jax.jit
     27 def update(params, prng_key, x, t, eta=1e-3):
---> 28     grads = jax.grad(loss)(params, prng_key, x, t)
     29     return tree_map(lambda p,g: p - eta * g,  params, grads)

    [... skipping hidden 19 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:259, in _odeint_rev(func, rtol, atol, mxstep, hmax, res, g)
    256     return (y_bar, t0_bar, args_bar), t_bar
    258   init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
--> 259   (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
    260       scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
    261   ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
    262   return (y_bar, ts_bar, *args_bar)

    [... skipping hidden 9 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:249, in _odeint_rev.<locals>.scan_fun(carry, i)
    247     t0_bar = t0_bar - t_bar
    248     # Run augmented system backwards to previous observation
--> 249     _, y_bar, t0_bar, args_bar = odeint(
    250         aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
    251         jnp.array([-ts[i], -ts[i - 1]]),
    252         *args, rtol=rtol, atol=atol, mxstep=mxstep, hmax=hmax)
    253     y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
    254     # Add gradient from current output

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:179, in odeint(func, y0, t, rtol, atol, mxstep, hmax, *args)
    176     raise TypeError(f"t must be an array of floats, but got {t}.")
    178   converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 179   return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)

    [... skipping hidden 5 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:185, in _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args)
    183   y0, unravel = ravel_pytree(y0)
    184   func = ravel_first_arg(func, unravel)
--> 185   out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
    186   return jax.vmap(unravel)(out)

    [... skipping hidden 6 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:216, in _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
    213     y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype))
    214     return carry, y_target
--> 216   f0 = func_(y0, ts[0])
    217   dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
    218   interp_coeff = jnp.array([y0] * 5)

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:190, in _odeint.<locals>.<lambda>(y, t)
    188 @partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3, 4))
    189 def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
--> 190   func_ = lambda y, t: func(y, t, *args)
    192   def scan_fun(carry, target_t):
    194     def cond_fun(state):

    [... skipping hidden 1 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:54, in ravel_first_arg_(unravel, y_flat, *args)
     52   y = unravel(y_flat)
     53   ans = yield (y,) + args, {}
---> 54   ans_flat, _ = ravel_pytree(ans)
     55   yield ans_flat

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/_src/flatten_util.py:49, in ravel_pytree(pytree)
     30   """Ravel (flatten) a pytree of arrays down to a 1D array.
     31 
     32   Args:
   (...)
     46 
     47   """
     48   leaves, treedef = tree_flatten(pytree)
---> 49   flat, unravel_list = _ravel_list(leaves)
     50   unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
     51   return flat, unravel_pytree

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/_src/flatten_util.py:56, in _ravel_list(lst)
     54   if not lst: return jnp.array([], jnp.float32), lambda _: []
     55   from_dtypes = [dtypes.dtype(l) for l in lst]
---> 56   to_dtype = dtypes.result_type(*from_dtypes)
     57   sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
     58   indices = np.cumsum(sizes)

    [... skipping hidden 3 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/_src/dtypes.py:394, in <genexpr>(.0)
    392 N = set(nodes)
    393 UB = _lattice_upper_bounds[jax_numpy_dtype_promotion]
--> 394 CUB = set.intersection(*(UB[n] for n in N))
    395 LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
    396 if len(LUB) == 1:

KeyError: dtype([('float0', 'V')])
```

The issue in the provided example is actually solved by the following modification:

```python
def loss(params, prng_key, x, t):
    x_hat = odeint(lambda x, t: f(params, x, t, prng_key), x[0], t)
    return jnp.mean((x_hat - x[-1])**2)
```

However, my application is more complicated than this, and the above solution does not apply there. If I understand why the problem is raised and why the modification above solves it, then I think I can address the issue in my program.
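A minimal sketch of where the `float0` dtype comes from, assuming a JAX version in which `jax.random.PRNGKey` still returns a raw `uint32` key: `jax.vjp` gives integer-valued inputs a placeholder cotangent of dtype `jax.dtypes.float0`. The toy `dynamics` function is hypothetical and only mimics the `func(y, t, *args)` calling convention that `odeint` uses.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy dynamics using odeint's func(y, t, *args) convention,
# with a raw PRNG key passed as one of the extra arguments.
def dynamics(y, t, key):
    return y * jax.random.normal(key, y.shape) + t

y = jnp.ones(3)
t = jnp.float32(0.0)
key = jax.random.PRNGKey(0)  # legacy raw key: uint32 array of shape (2,)

# Pull a cotangent back to *all* inputs, which is what a custom VJP over
# (y, t, *args) has to do for every entry of *args.
_, dynamics_vjp = jax.vjp(dynamics, y, t, key)
y_bar, t_bar, key_bar = dynamics_vjp(jnp.ones(3))

print(key.dtype)                           # uint32
print(key_bar.dtype == jax.dtypes.float0)  # True: integer inputs get float0 cotangents
```

Reading the traceback, the failure seems to follow this path: `odeint`'s backward pass (`_odeint_rev`) integrates an augmented system whose output includes cotangents for every entry of `*args`, so the key's `float0` cotangent lands in the pytree that `ravel_pytree` flattens, and `dtypes.result_type` has no promotion rule for `float0`, hence the `KeyError`. Closing over the key, as in the modification above, appears to keep it out of `*args`, so no cotangent is requested for it.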

What jax/jaxlib version are you using?

jax==0.3.25, jaxlib==0.3.25+cuda11.cudnn82

Which accelerator(s) are you using?

GPU

Additional system info

Linux

NVIDIA GPU info

No response

Update on 23/11/2022

It seems impossible to pass a PRNG key through odeint's *args, so I have changed the approach of my program. Now I sample from the probability distributions beforehand and pass the random samples (floats) through *args, instead of sampling inside the function integrated by odeint.
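A minimal sketch of this workaround, adapted from the reproducer above rather than taken from the reporter's actual program: the random draw happens in `loss`, and only the resulting float array `noise` travels through `odeint`'s `*args`. The shorter time grid and the two-way `jax.random.split` are assumptions made to keep the example small.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def f(params, x, t, noise):
    # `noise` is a plain float array drawn outside odeint, so odeint's
    # backward pass gets an ordinary float cotangent for it.
    W, b = params
    W_s = W + noise
    return x @ W_s + b

def loss(params, prng_key, x, t):
    # Sample once here, outside the integrated function.
    noise = jax.random.normal(prng_key, (10,))
    x_hat = odeint(partial(f, params), x[0], t, noise)
    return jnp.mean((x_hat - x[-1]) ** 2)

@jax.jit
def update(params, prng_key, x, t, eta=1e-3):
    grads = jax.grad(loss)(params, prng_key, x, t)
    return jax.tree_util.tree_map(lambda p, g: p - eta * g, params, grads)

params = (jnp.eye(10), jnp.ones((10,)))
x = jnp.linspace(jnp.zeros((10,)), 5 * jnp.ones((10,)), 10)
t = jnp.linspace(0.0, 1.0, 10)  # assumed shorter horizon than the reproducer's [0, 10]

prng_key = jax.random.PRNGKey(0)
prng_key, subkey = jax.random.split(prng_key)
params = update(params, subkey, x, t)
```

Because `f` in the reproducer reuses the same `prng_key` on every call, it already evaluates the same `W_s` at every time step, so drawing the sample once before `odeint` should describe the same dynamics. A training loop would repeat the last two lines.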

apaszke commented 1 year ago

Why do you need to pass the key to odeint? Is the dynamics function stochastic? If you're looking for ways to integrate SDEs then you might want to take a look at Diffrax.

I think the best fix for this issue would be to add a better error message. Otherwise it's probably intended that it doesn't work.
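A better error message could, for example, come from a user-side guard like the one below. `checked_odeint` is a hypothetical wrapper, not part of JAX, and the message text is illustrative; it simply rejects non-floating-point leaves in `*args` before calling `jax.experimental.ode.odeint`.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def checked_odeint(func, y0, t, *args, **kwargs):
    """Hypothetical wrapper (not part of JAX): call odeint only if every leaf
    of *args is floating point, otherwise raise an explicit TypeError instead
    of an opaque KeyError from dtype promotion."""
    for leaf in jax.tree_util.tree_leaves(args):
        dtype = jnp.result_type(leaf)
        if not jnp.issubdtype(dtype, jnp.inexact):
            raise TypeError(
                f"odeint *args must be floating point, but got a leaf of dtype {dtype}. "
                "Non-float values such as PRNG keys have no useful cotangent; close over "
                "them in func or pass precomputed float samples instead."
            )
    return odeint(func, y0, t, *args, **kwargs)

key = jax.random.PRNGKey(0)
ts = jnp.linspace(0.0, 1.0, 3)
try:
    checked_odeint(lambda y, t, k: -y, jnp.ones(2), ts, key)
except TypeError as err:
    print(err)
```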

A-Alaa commented 1 year ago

> Why do you need to pass the key to odeint? Is the dynamics function stochastic? If you're looking for ways to integrate SDEs then you might want to take a look at Diffrax.
>
> I think the best fix for this issue would be to add a better error message. Otherwise it's probably intended that it doesn't work.

It is still an ODE problem in my actual program, but with L0 regularisation for sparsity over the dynamics. The idea appeared recently here and was implemented in PyTorch here. It is inspired by this work, where it was originally applied to feed-forward networks.

I have managed to work around this and avoid passing PRNG keys, as I mentioned at the end of my original post:

> It seems impossible to pass a PRNG to odeint's *args, so I have changed the approach of my program. Now I sample from probability distributions and pass the random samples (floats) to *args instead of sampling inside the odeint.

Thank you!

froystig commented 1 year ago

Yeah, it would be good to improve this error, especially if `ode` graduates from experimental. Glad you found a workaround, and thanks for filing!