Open SamuelBrand1 opened 4 months ago
I've raised an issue with numpyro, so if a solution or utility already exists, it should get flagged there: https://github.com/pyro-ppl/numpyro/issues/1833
Some more context on the `XlaRuntimeError` and `max_steps` errors:
Some failures appear early:
Running chain 3: 0%| | 0/2000 [02:29<?, ?it/s]
jax.pure_callback failed
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
return callback(*args)
File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 65, in __call__
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
File "/usr/local/lib/python3.10/site-packages/equinox/_errors.py", line 70, in raises
raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.
Sometimes the failure appears late (but possibly still during warm-up?):
Running chain 1: 45%|####5 | 900/2000 [2:19:32<1:31:54, 5.01s/it]
jax.pure_callback failed
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
return callback(*args)
File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 65, in __call__
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
File "/usr/local/lib/python3.10/site-packages/equinox/_errors.py", line 70, in raises
raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.
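To help answer the "still warm up?" question, the progress-bar line printed just before each failure can be parsed for the chain number and iteration count. The following is a minimal sketch, not an existing numpyro utility: `parse_progress` and `in_warmup` are hypothetical helper names, the regular expression is fitted only to the two progress lines quoted above, and it assumes the counter covers warm-up plus sampling iterations. The `num_warmup` value has to come from the run configuration, which is not shown in this thread.

```python
import re
from typing import Optional

# Matches lines such as
#   "Running chain 1: 45%|####5 | 900/2000 [2:19:32<1:31:54, 5.01s/it]"
# and captures (chain index, iterations done, total iterations).
PROGRESS_RE = re.compile(r"Running chain (\d+):\s*\d+%\|[^|]*\|\s*(\d+)/(\d+)")


def parse_progress(line: str) -> Optional[dict]:
    """Extract chain and iteration counts from a progress-bar line, if present."""
    match = PROGRESS_RE.search(line)
    if match is None:
        return None
    chain, done, total = (int(group) for group in match.groups())
    return {"chain": chain, "iteration": done, "total": total}


def in_warmup(iteration: int, num_warmup: int) -> bool:
    """True if the iteration falls within the first `num_warmup` iterations.

    Assumption: the progress counter starts at 0 and counts warm-up
    iterations before sampling iterations.
    """
    return iteration < num_warmup
```

For example, `parse_progress` applied to the second log line above gives chain 1 at iteration 900 of 2000; whether that is still warm-up depends on the `num_warmup` used for the run.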
Sometimes only one or two of the four chains fail; sometimes all four fail. Regardless, the fit ends with the following errors:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/run_task.py", line 264, in <module>
runner.process_state(state, jobid, jobid_in_path=True)
File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/run_task.py", line 221, in process_state
inferer.infer(
File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/inferer_smh.py", line 59, in infer
self.inference_algo.run(
File "/usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 678, in run
states_flat = tree_map(
File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 321, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 321, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 680, in <lambda>
lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]),
File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 787, in reshape
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 150, in _reshape
return lax.reshape(a, newshape, None)
File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 892, in reshape
return reshape_p.bind(
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 940, in _array_shard_arg
return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 908, in shard_sharded_device_array_slow_path
return pxla.shard_arg(x._value, sharding, canonicalize=False)
File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 640, in _value
npy_value[ind] = arr._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError
Tagging wontfix, as its parent issue #244 is also tagged that way for now.
From a face-to-face discussion: there are occasional stochastic failures in the NUTS sampling procedure. The characteristics of these failures are:
These suggest either sampling into a numerically unstable region of parameter space during warm-up, some kind of numerical instability associated with vaccination rates, or both.
A good first step towards isolating the problem would be a stress test utility, e.g. like this. It's worth investigating the existing numpyro utilities before rolling our own solution.
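As a starting point for discussion, a stress test could be as simple as repeating the fit over many seeds and recording which seeds fail and with which exception. The sketch below is an illustration under stated assumptions, not an existing numpyro utility: `stress_test`, `summarise`, `TrialResult`, and `toy_fit` are all hypothetical names, and `toy_fit` is a stand-in for a function that would build the model and call `MCMC.run` with a PRNG key derived from the seed.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Iterable, Optional


@dataclass
class TrialResult:
    seed: int
    ok: bool
    error_type: Optional[str] = None
    error_message: Optional[str] = None


def stress_test(fit: Callable[[int], object], seeds: Iterable[int]) -> list:
    """Call `fit(seed)` once per seed and record whether it raised."""
    results = []
    for seed in seeds:
        try:
            fit(seed)
        except Exception as exc:  # deliberately broad: every failure mode is of interest
            results.append(TrialResult(seed, False, type(exc).__name__, str(exc)))
        else:
            results.append(TrialResult(seed, True))
    return results


def summarise(results: list) -> dict:
    """Count failures by exception type and list the failing seeds for replay."""
    failures = [r for r in results if not r.ok]
    return {
        "n_trials": len(results),
        "n_failed": len(failures),
        "by_error": dict(Counter(r.error_type for r in failures)),
        "failing_seeds": [r.seed for r in failures],
    }


def toy_fit(seed: int) -> int:
    """Hypothetical stand-in for a real fit; fails deterministically on some seeds."""
    if seed % 4 == 3:
        raise RuntimeError("The maximum number of solver steps was reached.")
    return seed
```

Calling `summarise(stress_test(toy_fit, range(8)))` reports which seeds failed, so a failing seed can be replayed in isolation. A real version would probably also want to record the parameter values at the point of failure, which this sketch does not attempt.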