StanfordASL / hj_reachability

Hamilton-Jacobi reachability analysis in JAX.
MIT License

Higher-Order Upwind Schemes Failing #15

Open Shono1 opened 2 months ago

Shono1 commented 2 months ago

I recently installed this library and ran into an issue when running the starter code with GPU-accelerated JAX. For all solver resolutions except low, an operand dimension mismatch error is thrown from `_diff_coefficients`. Here's the traceback from running `hj.step` in the unmodified quickstart.ipynb:

`---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 3
      1 time = 0.
      2 target_time = -2.8
----> 3 target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)

    [... skipping hidden 11 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/solver.py:76, in step(solver_settings, dynamics, grid, time, values, target_time, progress_bar)
     73         bar.update_to(jnp.abs(t - bar.reference_time))
     74     return t, v
---> 76 return jax.lax.while_loop(lambda time_values: jnp.abs(target_time - time_values[0]) > 0, sub_step,
     77                           (time, values))[1]

    [... skipping hidden 9 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/solver.py:71, in step.<locals>.sub_step(time_values)
     70 def sub_step(time_values):
---> 71     t, v = solver_settings.time_integrator(solver_settings, dynamics, grid, *time_values, target_time)
     72     if bar is not False:
     73         bar.update_to(jnp.abs(t - bar.reference_time))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/time_integration.py:50, in third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time)
     49 def third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):
---> 50     time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time)
     51     time_step = time_1 - time
     52     _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step)

    [... skipping hidden 11 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/time_integration.py:21, in euler_step(solver_settings, dynamics, grid, time, values, time_step, max_time_step)
     19 time_direction = jnp.sign(max_time_step) if time_step is None else jnp.sign(time_step)
     20 signed_hamiltonian = lambda *args, **kwargs: time_direction * dynamics.hamiltonian(*args, **kwargs)
---> 21 left_grad_values, right_grad_values = grid.upwind_grad_values(solver_settings.upwind_scheme, values)
     22 dissipation_coefficients = solver_settings.artificial_dissipation_scheme(dynamics.partial_max_magnitudes,
     23                                                                          grid.states, time, values,
     24                                                                          left_grad_values, right_grad_values)
     25 dvalues_dt = -solver_settings.hamiltonian_postprocessor(time_direction * utils.multivmap(
     26     lambda state, value, left_grad_value, right_grad_value, dissipation_coefficients:
     27     (lax_friedrichs_numerical_hamiltonian(signed_hamiltonian, state, time, value,
     28                                           left_grad_value, right_grad_value, dissipation_coefficients)),
     29     np.arange(grid.ndim))(grid.states, values, left_grad_values, right_grad_values, dissipation_coefficients))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:89, in Grid.upwind_grad_values(self, upwind_scheme, values)
     87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
     88     """Returns `(left_grad_values, right_grad_values)`."""
---> 89     left_derivatives, right_derivatives = zip(*[
     90         utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
     91                         np.array([j
     92                                   for j in range(self.ndim)
     93                                   if j != i]))(values)
     94         for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
     95     ])
     96     return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:90, in <listcomp>(.0)
     87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
     88     """Returns `(left_grad_values, right_grad_values)`."""
     89     left_derivatives, right_derivatives = zip(*[
---> 90         utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
     91                         np.array([j
     92                                   for j in range(self.ndim)
     93                                   if j != i]))(values)
     94         for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
     95     ])
     96     return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

    [... skipping hidden 6 frame]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:90, in Grid.upwind_grad_values.<locals>.<lambda>(values)
     87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
     88     """Returns `(left_grad_values, right_grad_values)`."""
     89     left_derivatives, right_derivatives = zip(*[
---> 90         utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
     91                         np.array([j
     92                                   for j in range(self.ndim)
     93                                   if j != i]))(values)
     94         for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
     95     ])
     96     return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/finite_differences/upwind_first.py:42, in weighted_essentially_non_oscillatory(eno_order, values, spacing, boundary_condition)
     37 if eno_order == 1:
     38     return (diffs[:-1], diffs[1:])
     40 substencil_approximations = tuple(
     41     _unrolled_correlate(diffs[i:len(diffs) - eno_order + i], c)
---> 42     for (i, c) in enumerate(_diff_coefficients(eno_order)))
     43 diffs2 = diffs[1:] - diffs[:-1]
     44 smoothness_indicators = [
     45     sum(
     46         _unrolled_correlate(diffs2[i + j:len(diffs2) - eno_order + i + 1], L[j:, j])**2
     47         for j in range(eno_order - 1))
     48     for (i, L) in enumerate(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)))
     49 ]

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/finite_differences/upwind_first.py:167, in _diff_coefficients(k, stencil)
    164     elif k != stencil.shape[-1] - 1:
    165         raise ValueError("`k` must match `stencil.shape[-1] - 1` if both arguments are provided; got "
    166                          f"{(k, stencil.shape[-1] - 1)}.")
--> 167 return np.linalg.solve(
    168     np.diff(poly.polyvander(stencil, k), axis=-2)[..., 1:].swapaxes(-1, -2),
    169     np.eye(k)[(np.newaxis,) * (stencil.ndim - 1) + (0,)])

File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/numpy/linalg/_linalg.py:410, in solve(a, b)
    407 signature = 'DD->D' if isComplexType(t) else 'dd->d'
    408 with errstate(call=_raise_linalgerror_singular, invalid='call',
    409               over='ignore', divide='ignore', under='ignore'):
--> 410     r = gufunc(a, b, signature=signature)
    412 return wrap(r.astype(result_t, copy=False))

ValueError: solve: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m,n)->(m,n) (size 1 is different from 3)`

When I knock the resolution down to low, however, there is no error. I also had success running high and very high resolution calculations on the CPU-only version of JAX.

Update: I started a new CPU-only environment and it threw the same error I was getting with the GPU one. I'm currently unsure how I ever got this to work.
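One possible explanation, not confirmed anywhere in this thread: the last frame of the traceback is in `numpy/linalg/_linalg.py`, a module path used by NumPy 2.x, and NumPy 2.0 changed `np.linalg.solve` so that `b` is treated as a stack of vectors only when it is exactly 1-D. That would fit the error appearing on both CPU and GPU, and the `(m,m),(m,n)->(m,n)` signature in the message. The sketch below uses a made-up batched matrix `a` (an assumption, not the library's actual stencil matrix) together with the same `np.eye(k)[...]` indexing pattern seen in the traceback, and shows a call form that behaves the same on NumPy 1.x and 2.x.

```python
import numpy as np

k = 3

# Hypothetical stand-in for the batched left-hand side: a stack of k
# invertible (k, k) matrices, shape (3, 3, 3). Not the library's real data.
a = np.stack([(i + 1) * np.eye(k) for i in range(k)])

# Same indexing pattern as in the traceback for a 2-D stencil: shape (1, k).
b = np.eye(k)[(np.newaxis,) * 1 + (0,)]

# Direct call: NumPy 1.x reads `b` as a stack of vectors, while NumPy 2.x
# reads any non-1-D `b` as a stack of matrices and rejects these shapes.
try:
    x_direct = np.linalg.solve(a, b)
    print("direct call succeeded, result shape:", x_direct.shape)
except ValueError as err:
    print("direct call failed:", err)

# Version-independent form: make each vector an explicit (k, 1) column,
# solve, then drop the trailing axis again.
x = np.linalg.solve(a, b[..., np.newaxis])[..., 0]
print("column-vector call result shape:", x.shape)
```

If this is indeed the cause, the same explicit column-vector form applied inside `_diff_coefficients` would be one way to patch a local copy; that is a guess about a fix, not something the maintainers have suggested here.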

My current setup:

I've also applied the fix suggested by @mattkiim in #14.

Let me know if you need any more info or need me to run any tests.

gonultasbu commented 1 month ago

I'm having the same issue.

I downgraded to python==3.8, jax==0.4.10 and jaxlib==0.4.10, and the error seems to be gone for now.
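A hedged note on why the downgrade might help: NumPy releases that support Python 3.8 are all 1.x, so a Python 3.8 environment cannot pick up NumPy 2.x. If the NumPy major version is what matters (an assumption, not verified in this thread), a quick check like the following shows which side of that line an environment is on. The helper name `is_numpy_2_or_newer` is made up for this sketch.

```python
import numpy as np


def is_numpy_2_or_newer(version: str) -> bool:
    """Return True if a NumPy version string has major version 2 or higher."""
    major = int(version.split(".")[0])
    return major >= 2


# Report the installed NumPy version and which behaviour to expect.
print("NumPy version:", np.__version__)
if is_numpy_2_or_newer(np.__version__):
    print("NumPy 2.x: np.linalg.solve treats non-1-D `b` as a stack of matrices.")
else:
    print("NumPy 1.x: np.linalg.solve still accepts a stack of vectors for `b`.")
```

If the check reports NumPy 2.x, pinning `numpy<2` in the existing environment may be a lighter workaround than downgrading Python and JAX, though that has not been tested here.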