StanfordASL / hj_reachability

Hamilton-Jacobi reachability analysis in JAX.
MIT License

Higher-Order Upwind Schemes Failing #15

Closed Shono1 closed 1 month ago

Shono1 commented 5 months ago

I recently installed this library and have run into an issue when running the starter code with GPU-accelerated JAX. For all solver resolutions except low, an operand dimension mismatch error is thrown from `_diff_coefficients`. Here's the traceback from running `hj.step` in the unmodified quickstart.ipynb:

`---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 3
      1 time = 0.
      2 target_time = -2.8
----> 3 target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)

    [... skipping hidden 11 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/solver.py:76, in step(solver_settings, dynamics, grid, time, values, target_time, progress_bar)
     73         bar.update_to(jnp.abs(t - bar.reference_time))
     74     return t, v
---> 76 return jax.lax.while_loop(lambda time_values: jnp.abs(target_time - time_values[0]) > 0, sub_step,
     77                           (time, values))[1]

    [... skipping hidden 9 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/solver.py:71, in step.<locals>.sub_step(time_values)
     70 def sub_step(time_values):
---> 71     t, v = solver_settings.time_integrator(solver_settings, dynamics, grid, *time_values, target_time)
     72     if bar is not False:
     73         bar.update_to(jnp.abs(t - bar.reference_time))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/time_integration.py:50, in third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time)
     49 def third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):
---> 50     time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time)
     51     time_step = time_1 - time
     52     _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step)

    [... skipping hidden 11 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/time_integration.py:21, in euler_step(solver_settings, dynamics, grid, time, values, time_step, max_time_step)
     19 time_direction = jnp.sign(max_time_step) if time_step is None else jnp.sign(time_step)
     20 signed_hamiltonian = lambda *args, **kwargs: time_direction * dynamics.hamiltonian(*args, **kwargs)
---> 21 left_grad_values, right_grad_values = grid.upwind_grad_values(solver_settings.upwind_scheme, values)
     22 dissipation_coefficients = solver_settings.artificial_dissipation_scheme(dynamics.partial_max_magnitudes,
     23                                                                          grid.states, time, values,
     24                                                                          left_grad_values, right_grad_values)
     25 dvalues_dt = -solver_settings.hamiltonian_postprocessor(time_direction * utils.multivmap(
     26     lambda state, value, left_grad_value, right_grad_value, dissipation_coefficients:
     27     (lax_friedrichs_numerical_hamiltonian(signed_hamiltonian, state, time, value,
     28                                           left_grad_value, right_grad_value, dissipation_coefficients)),
     29     np.arange(grid.ndim))(grid.states, values, left_grad_values, right_grad_values, dissipation_coefficients))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:89, in Grid.upwind_grad_values(self, upwind_scheme, values)
     87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
     88     """Returns `(left_grad_values, right_grad_values)`."""
---> 89     left_derivatives, right_derivatives = zip(*[
     90         utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
     91                         np.array([j
     92                                   for j in range(self.ndim)
     93                                   if j != i]))(values)
     94         for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
     95     ])
     96     return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:90, in <listcomp>(.0)
     87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
     88     """Returns `(left_grad_values, right_grad_values)`."""
     89     left_derivatives, right_derivatives = zip(*[
---> 90         utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
     91                         np.array([j
     92                                   for j in range(self.ndim)
     93                                   if j != i]))(values)
     94         for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
     95     ])
     96     return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

    [... skipping hidden 6 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:90, in Grid.upwind_grad_values.<locals>.<lambda>(values)
     87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
     88     """Returns `(left_grad_values, right_grad_values)`."""
     89     left_derivatives, right_derivatives = zip(*[
---> 90         utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
     91                         np.array([j
     92                                   for j in range(self.ndim)
     93                                   if j != i]))(values)
     94         for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
     95     ])
     96     return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/finite_differences/upwind_first.py:42, in weighted_essentially_non_oscillatory(eno_order, values, spacing, boundary_condition)
     37 if eno_order == 1:
     38     return (diffs[:-1], diffs[1:])
     40 substencil_approximations = tuple(
     41     _unrolled_correlate(diffs[i:len(diffs) - eno_order + i], c)
---> 42     for (i, c) in enumerate(_diff_coefficients(eno_order)))
     43 diffs2 = diffs[1:] - diffs[:-1]
     44 smoothness_indicators = [
     45     sum(
     46         _unrolled_correlate(diffs2[i + j:len(diffs2) - eno_order + i + 1], L[j:, j])**2
     47         for j in range(eno_order - 1))
     48     for (i, L) in enumerate(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)))
     49 ]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/finite_differences/upwind_first.py:167, in _diff_coefficients(k, stencil)
    164     elif k != stencil.shape[-1] - 1:
    165         raise ValueError("`k` must match `stencil.shape[-1] - 1` if both arguments are provided; got "
    166                          f"{(k, stencil.shape[-1] - 1)}.")
--> 167 return np.linalg.solve(
    168     np.diff(poly.polyvander(stencil, k), axis=-2)[..., 1:].swapaxes(-1, -2),
    169     np.eye(k)[(np.newaxis,) * (stencil.ndim - 1) + (0,)])

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/numpy/linalg/_linalg.py:410, in solve(a, b)
    407 signature = 'DD->D' if isComplexType(t) else 'dd->d'
    408 with errstate(call=_raise_linalgerror_singular, invalid='call',
    409               over='ignore', divide='ignore', under='ignore'):
--> 410     r = gufunc(a, b, signature=signature)
    412 return wrap(r.astype(result_t, copy=False))

ValueError: solve: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m,n)->(m,n) (size 1 is different from 3)`

When I knock the resolution down to low, however, there is no error. I also had success running high and very high resolution calculations on the CPU only version of JAX.

Update: I started a new CPU only environment and it actually threw the same error as I was getting with the GPU one. I'm presently unsure how I ever got this to work.

My current setup:

I've also applied the fix suggested by @mattkiim in #14.

Let me know if you need any more info or need me to run any tests.

gonultasbu commented 4 months ago

Having the same issue.

I downgraded to python==3.8, jax==0.4.10, and jaxlib==0.4.10, and the error seems to be gone for now.

schmrlng commented 1 month ago

Ah, surprisingly this isn't a JAX version issue but a NumPy one; see https://github.com/numpy/numpy/issues/26421, https://github.com/numpy/numpy/issues/15349, or the NumPy 2.0.0 Release Notes for context. I'm guessing that when it was working you were on some NumPy 1.XX version, and upgrading to NumPy 2.XX at some point broke the package as you observed.
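For anyone hitting the same error elsewhere, here is a minimal sketch of the behavior change, assuming only NumPy is installed. In NumPy 2, `np.linalg.solve(a, b)` treats `b` as a stack of vectors only when `b` is exactly 1-D; NumPy 1 also did so whenever `b.ndim == a.ndim - 1`. The helper name `solve_stacked_vectors` and the shapes are made up for illustration (they only echo the "size 1 is different from 3" message above), and this is not necessarily how #16 fixes it.

```python
import numpy as np


def solve_stacked_vectors(a, b):
    """Solve a[i] @ x[i] = b[i], where b holds vectors (not matrices).

    Hypothetical helper: appending an explicit trailing axis turns each vector
    into an (m, 1) column, which NumPy 1.x and 2.x interpret the same way.
    """
    return np.linalg.solve(a, b[..., np.newaxis])[..., 0]


# Illustrative shapes: three 3x3 systems and a single right-hand-side vector
# carrying a leading broadcast axis of size 1.
a = np.stack([np.eye(3) * (i + 1) for i in range(3)])  # shape (3, 3, 3)
b = np.array([[1.0, 0.0, 0.0]])                        # shape (1, 3)

# On NumPy 2.x, np.linalg.solve(a, b) reads b as one (1, 3) matrix and raises
# a core-dimension mismatch; on NumPy 1.x it was read as a stack of vectors.
x = solve_stacked_vectors(a, b)
print(x.shape)
```

The explicit column axis keeps the call independent of which rule the installed NumPy applies, at the cost of one extra indexing step on the result.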

Apologies for the significant delay in fixing this issue. I'm not sure if you're still working with this package (hopefully your REU went well!), but after #16 this should now work with both NumPy 1 and 2.