CDCgov / DynODE

CDC/CFA/Predict/Scenarios ODE Modeling Framework
Apache License 2.0

Adding stress testing utility for model inference to help isolate cause of stochastic failure in inference #197

Open SamuelBrand1 opened 4 months ago

SamuelBrand1 commented 4 months ago

From a face-to-face discussion: there are occasional stochastic failures in the NUTS sampling procedure. The characteristics of these failures are:

These suggest either sampling into a numerically unstable portion of parameter space during warm-up, some kind of numerical instability associated with vaccination rates, or both.

A good first step towards isolating the problem would be a stress test utility, e.g. like this. It's worth investigating the existing numpyro utilities before rolling our own solution.
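As a rough illustration of what such a utility could look like, here is a minimal sketch in plain Python. It is not an existing numpyro or DynODE utility: `stress_test`, `TrialResult`, `failing_seeds`, and `toy_fit` are all hypothetical names, and `toy_fit` is only a stand-in for a real inference run. The idea is to call the fit once per seed, catch any exception rather than aborting, and record which seeds fail and with what error so that a failing run can be reproduced in isolation.

```python
import time
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional


@dataclass
class TrialResult:
    """Outcome of one fit attempt (hypothetical record type)."""
    seed: int
    ok: bool
    seconds: float
    error_type: Optional[str] = None
    error_message: Optional[str] = None


def stress_test(fit_fn: Callable[[int], object], seeds: Iterable[int]) -> List[TrialResult]:
    """Call fit_fn(seed) once per seed, recording failures instead of aborting."""
    results = []
    for seed in seeds:
        start = time.perf_counter()
        try:
            fit_fn(seed)
            results.append(TrialResult(seed, True, time.perf_counter() - start))
        except Exception as exc:  # deliberately broad: every failure mode is of interest
            results.append(
                TrialResult(
                    seed,
                    False,
                    time.perf_counter() - start,
                    type(exc).__name__,
                    str(exc),
                )
            )
    return results


def failing_seeds(results: List[TrialResult]) -> List[int]:
    """Seeds whose fit raised, in the order they were tried."""
    return [r.seed for r in results if not r.ok]


# Toy stand-in for a real inference run (assumption for illustration only):
# it fails deterministically on seeds divisible by 4.
def toy_fit(seed: int) -> None:
    if seed % 4 == 0:
        raise RuntimeError("The maximum number of solver steps was reached.")


results = stress_test(toy_fit, range(1, 11))
bad = failing_seeds(results)
print(f"{len(bad)}/{len(results)} runs failed; failing seeds: {bad}")
```

In real use, `fit_fn` would presumably wrap a single inference run seeded from `seed`, so that any seed reported as failing can be re-run on its own to check whether the failure is reproducible.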

SamuelBrand1 commented 4 months ago

I've raised an issue with numpyro, so if a solution or utility already exists it will get flagged there. https://github.com/pyro-ppl/numpyro/issues/1833

kokbent commented 4 months ago

Some more context on XLARuntimeError and max_step error:

Some failures appear early:

Running chain 3:   0%|          | 0/2000 [02:29<?, ?it/s]jax.pure_callback failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/usr/local/lib/python3.10/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

Sometimes it's late (but still in warm-up?):

Running chain 1:  45%|####5     | 900/2000 [2:19:32<1:31:54,  5.01s/it]jax.pure_callback failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/usr/local/lib/python3.10/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

Sometimes only one or two of the four chains fail, sometimes all four fail. Regardless, the fit ends with the following errors:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/run_task.py", line 264, in <module>
    runner.process_state(state, jobid, jobid_in_path=True)
  File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/run_task.py", line 221, in process_state
    inferer.infer(
  File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/inferer_smh.py", line 59, in infer
    self.inference_algo.run(
  File "/usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 678, in run
    states_flat = tree_map(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 321, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 321, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 680, in <lambda>
    lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]),
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 787, in reshape
    return a.reshape(newshape, order=order)  # type: ignore[call-overload,union-attr]
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 150, in _reshape
    return lax.reshape(a, newshape, None)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 892, in reshape
    return reshape_p.bind(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 940, in _array_shard_arg
    return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 908, in shard_sharded_device_array_slow_path
    return pxla.shard_arg(x._value, sharding, canonicalize=False)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 640, in _value
    npy_value[ind] = arr._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError

arik-shurygin commented 2 weeks ago

Tagging as wontfix, since its parent issue #244 is also tagged wontfix for now.