google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

No error message when `odeint` fails #7326

Open charlesm93 opened 3 years ago

charlesm93 commented 3 years ago


When `odeint` fails to meet the target precision set by the relative and absolute tolerances (`rtol` and `atol`), no error or warning message is issued. Here's a toy example demonstrating the problem:

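import jax.numpy as jnp
import numpy as np
from jax.experimental.ode import odeint

# Evaluation times; odeint treats t[0] = 0.5 as the initial time.
t = np.array([0.5, 0.75, 1, 1.25, 1.5, 2, 3, 4, 5, 6])
# Initial state: everything starts in the first compartment.
y0 = np.array([100.0, 0.0])

# Rate constants k1 and k2.
theta = np.array([1.5, 1])

def system(state, time, theta):
    """Two-compartment linear ODE: y1' = -k1*y1, y2' = k1*y1 - k2*y2."""
    k1 = theta[0]
    k2 = theta[1]
    return jnp.array([
        -k1 * state[0],
        k1 * state[0] - k2 * state[1],
    ])

states = odeint(system, y0, t, theta, rtol=1e-6, atol=1e-6, mxstep=1000)
print(states[-1])  # final state, at t = 6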

With `mxstep=1000`, the final state returned by `odeint` is

[2.61259228e-02 1.14765358e+00]

Now set `mxstep=1`. `odeint` then returns

[ 1720.9906   -3944.0195  ]

which is patently wrong (the second compartment can never become negative in this system). Yet no warning message suggests that this solution might be unreliable.
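Because this system is linear, it has a closed-form solution, which makes it easy to confirm which of the two results is right. The sketch below evaluates that solution at the same times, taking `t[0] = 0.5` as the initial time as `odeint` does. The helper name `two_compartment_exact` is hypothetical, not part of JAX.

```python
import numpy as np

def two_compartment_exact(t, y0, k1, k2, t0):
    """Closed-form solution of y1' = -k1*y1, y2' = k1*y1 - k2*y2 (requires k1 != k2)."""
    tau = np.asarray(t, dtype=float) - t0  # time elapsed since the initial time
    y10, y20 = y0
    y1 = y10 * np.exp(-k1 * tau)
    # Homogeneous part for y2 plus the response driven by the decaying y1.
    y2 = (y20 * np.exp(-k2 * tau)
          + k1 * y10 / (k2 - k1) * (np.exp(-k1 * tau) - np.exp(-k2 * tau)))
    return np.stack([y1, y2], axis=-1)

t = np.array([0.5, 0.75, 1, 1.25, 1.5, 2, 3, 4, 5, 6])
exact = two_compartment_exact(t, y0=(100.0, 0.0), k1=1.5, k2=1.0, t0=t[0])
print(exact[-1])  # close to the mxstep=1000 result quoted above
```

Both exact components stay nonnegative for all times, so the negative value returned with `mxstep=1` cannot be a valid approximation.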

shoyer commented 3 years ago

I think the best we could do here is return NaN instead of a plausible-looking but wrong answer, since it isn't possible to raise errors from within XLA.
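A minimal sketch of the NaN approach, assuming a hypothetical fixed-step Euler integrator (`integrate_or_nan`). It is not JAX's adaptive `odeint`, only an illustration of the control flow. The loop runs inside `lax.while_loop` under `jit`, so it cannot raise. Instead, if the step budget `mxstep` runs out before `t1` is reached, the returned state is replaced with NaN.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(0, 5))
def integrate_or_nan(f, y0, t0, t1, h, mxstep):
    """Hypothetical fixed-step Euler integrator that returns NaN on step-limit failure."""
    t0 = jnp.asarray(t0, dtype=jnp.float32)
    t1 = jnp.asarray(t1, dtype=jnp.float32)

    def cond(carry):
        t, _, n = carry
        # Stop when t1 is reached OR the step budget is exhausted.
        return (t < t1) & (n < mxstep)

    def body(carry):
        t, y, n = carry
        t_next = jnp.minimum(t + h, t1)  # clamp so the last step lands exactly on t1
        return t_next, y + (t_next - t) * f(y), n + 1

    t_end, y_end, _ = lax.while_loop(cond, body, (t0, y0, jnp.int32(0)))
    # No exceptions inside XLA: signal failure through the value itself.
    return jnp.where(t_end >= t1, y_end, jnp.nan)

def decay(y):
    return -y

y0 = jnp.array([1.0, 2.0])
ok = integrate_or_nan(decay, y0, 0.0, 1.0, 0.01, 1000)  # about 100 steps needed
bad = integrate_or_nan(decay, y0, 0.0, 1.0, 0.01, 10)   # budget too small -> NaN
```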

charlesm93 commented 3 years ago

Returning NaN would certainly be an improvement.

I do find warning and error messages extremely helpful, especially when writing intricate code. Specifying a model with TensorFlow Probability on JAX becomes quite challenging when I need to hunt down where in the code a NaN comes from. By contrast, a language such as Stan provides detailed information about what caused the issue (e.g. "ode integrator reached the maximum number of steps", "cannot evaluate gamma density for a negative variable", etc.). Then again, if this is a constraint imposed by XLA, I'm not sure what the best course of action might be.

shoyer commented 3 years ago

Well, we can certainly imagine passing status codes around explicitly. We just can't use "exceptions", since those don't exist in XLA HLO.
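A minimal sketch of passing a status code explicitly, again with a hypothetical fixed-step Euler integrator rather than `odeint`. The jitted function returns `(state, status)`, and a plain Python wrapper outside `jit` inspects the status and raises. The status values and the names `integrate_with_status` and `integrate_checked` are illustrative assumptions, not JAX APIs.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

SUCCESS, MAX_STEPS_REACHED = 0, 1  # hypothetical status codes

@partial(jax.jit, static_argnums=(0, 5))
def integrate_with_status(f, y0, t0, t1, h, mxstep):
    """Fixed-step Euler loop that reports why it stopped instead of raising."""
    t0 = jnp.asarray(t0, dtype=jnp.float32)
    t1 = jnp.asarray(t1, dtype=jnp.float32)

    def cond(carry):
        t, _, n = carry
        return (t < t1) & (n < mxstep)

    def body(carry):
        t, y, n = carry
        t_next = jnp.minimum(t + h, t1)
        return t_next, y + (t_next - t) * f(y), n + 1

    t_end, y_end, _ = lax.while_loop(cond, body, (t0, y0, jnp.int32(0)))
    status = jnp.where(t_end >= t1, SUCCESS, MAX_STEPS_REACHED)
    return y_end, status

def integrate_checked(f, y0, t0, t1, h, mxstep):
    """Outside jit, ordinary Python exceptions are available again."""
    y, status = integrate_with_status(f, y0, t0, t1, h, mxstep)
    if int(status) == MAX_STEPS_REACHED:
        raise RuntimeError(f"integrator reached mxstep={mxstep} before t1={t1}")
    return y

def decay(y):
    return -y
```

The error is raised only after the compiled computation has returned, which is the one place where a Stan-style message can be produced. The cost is a host synchronization on `status`.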

froystig commented 3 years ago

How about simply documenting this more clearly? The caller sets both the maximum step count and the tolerance, so it might be reasonable for them to expect a "whichever happens first" policy, provided we make that clear.