A-Alaa opened this issue 1 year ago
Why do you need to pass the key to odeint? Is the dynamics function stochastic? If you're looking for ways to integrate SDEs then you might want to take a look at Diffrax.
I think the best fix for this issue would be to add a better error message. Otherwise it's probably intended that it doesn't work.
It is still an ODE problem in my actual program, but with L0-regularisation for sparsity over the dynamics. The idea appeared recently here and was implemented in PyTorch here. It is inspired by this work, where it was originally applied to feed-forward networks.
I have managed to work around this and avoid passing PRNG keys, as I mentioned in the update at the end of the issue:

> It seems impossible to pass a PRNG key to `odeint`'s `args`, so I have changed the approach of my program. Now I sample from probability distributions and pass the random samples (floats) to `args` instead of sampling inside `odeint`.
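A minimal sketch of this workaround, assuming `jax.experimental.ode.odeint` and a hypothetical linear dynamics `dynamics(y, t, W, gates)` whose weights are masked by stochastic gates. For illustration only, the gates are assumed to follow the hard-concrete distribution commonly used for L0 regularisation; `sample_gates`, `loss` and `update` are hypothetical names, not the code from this issue. The key is consumed inside `loss` before the solver runs, so `odeint` only ever receives float arrays.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def sample_gates(key, log_alpha, beta=2 / 3, gamma=-0.1, zeta=1.1):
    # Hard-concrete sample (an assumed gate distribution); returns floats in [0, 1].
    u = jax.random.uniform(key, log_alpha.shape, minval=1e-6, maxval=1 - 1e-6)
    s = jax.nn.sigmoid((jnp.log(u) - jnp.log1p(-u) + log_alpha) / beta)
    return jnp.clip(s * (zeta - gamma) + gamma, 0.0, 1.0)


def dynamics(y, t, W, gates):
    # Hypothetical linear dynamics; only float arrays arrive through odeint's args.
    return (W * gates) @ y


def loss(params, key, y0, ts, target):
    # The PRNG key is used here, outside odeint, and never reaches the solver.
    gates = sample_gates(key, params["log_alpha"])
    ys = odeint(dynamics, y0, ts, params["W"], gates)
    return jnp.mean((ys[-1] - target) ** 2)


@jax.jit
def update(params, key, y0, ts, target, eta=1e-3):
    # Differentiate w.r.t. params only (argnums=0); the key is not differentiated.
    grads = jax.grad(loss)(params, key, y0, ts, target)
    return jax.tree_util.tree_map(lambda p, g: p - eta * g, params, grads)


key = jax.random.PRNGKey(0)
params = {"W": -0.5 * jnp.eye(2), "log_alpha": jnp.zeros((2, 2))}
y0, ts, target = jnp.ones(2), jnp.linspace(0.0, 1.0, 5), jnp.zeros(2)
for _ in range(3):
    key, subkey = jax.random.split(key)
    params = update(params, subkey, y0, ts, target)
```

Because the gates are computed from `params["log_alpha"]` inside `loss`, gradients still reach `log_alpha` through the float `gates` argument of `odeint`; only the integer key stays outside the solver.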
Thank you!
Yeah, it would be good to improve this error, especially if `ode` graduates from `experimental`. Glad you found a workaround, and thanks for filing!
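A sketch of the kind of early check suggested here, written as a hypothetical user-side wrapper `checked_odeint` (not part of JAX). It inspects the leaves of `args` and raises a descriptive `TypeError` for non-floating dtypes, such as a legacy `uint32` PRNG key. Otherwise the backward pass would fail with the opaque `KeyError: dtype([('float0', 'V')])`. It assumes `jax.experimental.ode.odeint` and legacy raw keys from `jax.random.PRNGKey`.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def checked_odeint(func, y0, t, *args, **kwargs):
    """Hypothetical wrapper: reject non-floating args before calling odeint."""
    for leaf in jax.tree_util.tree_leaves(args):
        dtype = jnp.result_type(leaf)
        if not jnp.issubdtype(dtype, jnp.inexact):
            raise TypeError(
                f"odeint args must have floating-point dtypes to be differentiable, "
                f"got {dtype}. Sample random values outside odeint and pass them as floats."
            )
    return odeint(func, y0, t, *args, **kwargs)


def decay(y, t, rate):
    # Hypothetical dynamics used to exercise the wrapper.
    return -rate * y
```

This check is stricter than necessary, because it also rejects integer args that are never differentiated. A check built into `odeint` would likely only need to fire when gradients are requested.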
Description
I am trying to pass `PRNGKey`s to a function that is integrated by `odeint`. Here is a simplified example reproducing the problem:

It raises the following exception:

`KeyError: dtype([('float0', 'V')])`
```
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
[... runpy, ipykernel, traitlets, tornado, asyncio and IPython frames omitted ...]

Cell In [1], line 35
     34 (prng_key,) = jax.random.split(prng_key, 1)
---> 35 params = update(params, prng_key, x, t, 1e-3)

Cell In [1], line 28, in update(***failed resolving arguments***)
     26 @jax.jit
     27 def update(params, prng_key, x, t, eta=1e-3):
---> 28     grads = jax.grad(loss)(params, prng_key, x, t)
     29     return tree_map(lambda p,g: p - eta * g, params, grads)

Cell In [1], line 23, in loss(***failed resolving arguments***)
     22 def loss(params, prng_key, x, t):
---> 23     x_hat = odeint(partial(f, params), x[0], t, prng_key)
     24     return jnp.mean((x_hat - x[-1])**2)

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:179, in odeint(***failed resolving arguments***)
    178 converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 179 return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:185, in _odeint_wrapper(***failed resolving arguments***)
    184 func = ravel_first_arg(func, unravel)
--> 185 out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
    186 return jax.vmap(unravel)(out)

JaxStackTraceBeforeTransformation: KeyError: dtype([('float0', 'V')])

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In [1], line 35
     33 for i in range(100):
     34     (prng_key,) = jax.random.split(prng_key, 1)
---> 35     params = update(params, prng_key, x, t, 1e-3)

    [... skipping hidden 11 frame]

Cell In [1], line 28, in update(params, prng_key, x, t, eta)
     26 @jax.jit
     27 def update(params, prng_key, x, t, eta=1e-3):
---> 28     grads = jax.grad(loss)(params, prng_key, x, t)
     29     return tree_map(lambda p,g: p - eta * g, params, grads)

    [... skipping hidden 19 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:259, in _odeint_rev(func, rtol, atol, mxstep, hmax, res, g)
    256     return (y_bar, t0_bar, args_bar), t_bar
    258 init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
--> 259 (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
    260     scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
    261 ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
    262 return (y_bar, ts_bar, *args_bar)

    [... skipping hidden 9 frame]

File ~/GP/env/causal-dev/lib/python3.8/site-packages/jax/experimental/ode.py:249, in _odeint_rev.
```

The issue in the provided example is actually solved by the following modification:
However, my application is more complicated than this, and the above solution is not applicable there. If I understand why the problem is raised and why the modification above solves it, I think I can address the issue in my program.
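One plausible reading of the traceback, offered as an interpretation rather than a confirmed diagnosis, is as follows. The failure happens in `_odeint_rev`, the custom backward pass of `odeint`, which pulls cotangents back to every entry of `args`. A legacy `PRNGKey` is a `uint32` array. JAX gives integer inputs cotangents of the special `float0` dtype, which is the dtype named in the `KeyError`. Meanwhile, the scan carry in the traceback is initialised with `tree_map(jnp.zeros_like, args)`, which yields `uint32` zeros for the key. The sketch below uses a hypothetical `dynamics` function with the `(y, t, *args)` signature. It only shows that `jax.vjp` produces a `float0` cotangent for a key argument and does not reproduce the internals of `odeint`.

```python
import jax
import jax.numpy as jnp


def dynamics(y, t, key):
    # Hypothetical dynamics with odeint's (y, t, *args) signature; the key is
    # accepted but unused, which is enough to show the dtype mismatch.
    return -y


key = jax.random.PRNGKey(0)  # legacy raw key: a uint32 array of shape (2,)
y, t = jnp.ones(3), jnp.array(0.0)

# Pull a cotangent back through all arguments, similar to what the backward
# pass of odeint does for args according to the traceback.
_, vjp_fn = jax.vjp(dynamics, y, t, key)
y_bar, t_bar, key_bar = vjp_fn(jnp.ones(3))

zero_carry = jnp.zeros_like(key)  # what tree_map(jnp.zeros_like, args) gives for the key
print(key.dtype, zero_carry.dtype, key_bar.dtype)  # the last one is float0
```

Passing float samples instead of the key sidesteps this, because float args receive ordinary float cotangents that match a float zero carry.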
What jax/jaxlib version are you using?
jax==0.3.25, jaxlib==0.3.25+cuda11.cudnn82
Which accelerator(s) are you using?
GPU
Additional system info
Linux
NVIDIA GPU info
No response
Update on 23/11/2022
It seems impossible to pass a PRNG key to `odeint`'s `*args`, so I have changed the approach of my program. Now I sample from probability distributions and pass the random samples (floats) to `*args` instead of sampling inside `odeint`.