I've recently installed this library and have run into an issue when trying to run the starter code with GPU-accelerated JAX. For all solver resolutions except low, I get an operand dimension mismatch error thrown from `_diff_coefficients`. Here's the traceback from trying to run `hj.step` on the unmodified `quickstart.ipynb`:
`---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[4], line 3
1 time = 0.
2 target_time = -2.8
----> 3 target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)
[... skipping hidden 11 frame]
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/solver.py:76, in step(solver_settings, dynamics, grid, time, values, target_time, progress_bar)
73 bar.update_to(jnp.abs(t - bar.reference_time))
74 return t, v
---> 76 return jax.lax.while_loop(lambda time_values: jnp.abs(target_time - time_values[0]) > 0, sub_step,
77 (time, values))[1]
[... skipping hidden 9 frame]
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/solver.py:71, in step.<locals>.sub_step(time_values)
70 def sub_step(time_values):
---> 71 t, v = solver_settings.time_integrator(solver_settings, dynamics, grid, *time_values, target_time)
72 if bar is not False:
73 bar.update_to(jnp.abs(t - bar.reference_time))
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/time_integration.py:50, in third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time)
49 def third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):
---> 50 time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time)
51 time_step = time_1 - time
52 _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step)
[... skipping hidden 11 frame]
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/time_integration.py:21, in euler_step(solver_settings, dynamics, grid, time, values, time_step, max_time_step)
19 time_direction = jnp.sign(max_time_step) if time_step is None else jnp.sign(time_step)
20 signed_hamiltonian = lambda *args, **kwargs: time_direction * dynamics.hamiltonian(*args, **kwargs)
---> 21 left_grad_values, right_grad_values = grid.upwind_grad_values(solver_settings.upwind_scheme, values)
22 dissipation_coefficients = solver_settings.artificial_dissipation_scheme(dynamics.partial_max_magnitudes,
23 grid.states, time, values,
24 left_grad_values, right_grad_values)
25 dvalues_dt = -solver_settings.hamiltonian_postprocessor(time_direction * utils.multivmap(
26 lambda state, value, left_grad_value, right_grad_value, dissipation_coefficients:
27 (lax_friedrichs_numerical_hamiltonian(signed_hamiltonian, state, time, value,
28 left_grad_value, right_grad_value, dissipation_coefficients)),
29 np.arange(grid.ndim))(grid.states, values, left_grad_values, right_grad_values, dissipation_coefficients))
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:89, in Grid.upwind_grad_values(self, upwind_scheme, values)
87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
88 """Returns `(left_grad_values, right_grad_values)`."""
---> 89 left_derivatives, right_derivatives = zip(*[
90 utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
91 np.array([j
92 for j in range(self.ndim)
93 if j != i]))(values)
94 for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
95 ])
96 return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:90, in <listcomp>(.0)
87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
88 """Returns `(left_grad_values, right_grad_values)`."""
89 left_derivatives, right_derivatives = zip(*[
---> 90 utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
91 np.array([j
92 for j in range(self.ndim)
93 if j != i]))(values)
94 for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
95 ])
96 return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))
[... skipping hidden 6 frame]
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/grid.py:90, in Grid.upwind_grad_values.<locals>.<lambda>(values)
87 def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
88 """Returns `(left_grad_values, right_grad_values)`."""
89 left_derivatives, right_derivatives = zip(*[
---> 90 utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
91 np.array([j
92 for j in range(self.ndim)
93 if j != i]))(values)
94 for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
95 ])
96 return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/finite_differences/upwind_first.py:42, in weighted_essentially_non_oscillatory(eno_order, values, spacing, boundary_condition)
37 if eno_order == 1:
38 return (diffs[:-1], diffs[1:])
40 substencil_approximations = tuple(
41 _unrolled_correlate(diffs[i:len(diffs) - eno_order + i], c)
---> 42 for (i, c) in enumerate(_diff_coefficients(eno_order)))
43 diffs2 = diffs[1:] - diffs[:-1]
44 smoothness_indicators = [
45 sum(
46 _unrolled_correlate(diffs2[i + j:len(diffs2) - eno_order + i + 1], L[j:, j])**2
47 for j in range(eno_order - 1))
48 for (i, L) in enumerate(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)))
49 ]
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/hj_reachability/finite_differences/upwind_first.py:167, in _diff_coefficients(k, stencil)
164 elif k != stencil.shape[-1] - 1:
165 raise ValueError("`k` must match `stencil.shape[-1] - 1` if both arguments are provided; got "
166 f"{(k, stencil.shape[-1] - 1)}.")
--> 167 return np.linalg.solve(
168 np.diff(poly.polyvander(stencil, k), axis=-2)[..., 1:].swapaxes(-1, -2),
169 np.eye(k)[(np.newaxis,) * (stencil.ndim - 1) + (0,)])
File ~/reu/python_implementations/.venv/lib/python3.10/site-packages/numpy/linalg/_linalg.py:410, in solve(a, b)
407 signature = 'DD->D' if isComplexType(t) else 'dd->d'
408 with errstate(call=_raise_linalgerror_singular, invalid='call',
409 over='ignore', divide='ignore', under='ignore'):
--> 410 r = gufunc(a, b, signature=signature)
412 return wrap(r.astype(result_t, copy=False))
ValueError: solve: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m,n)->(m,n) (size 1 is different from 3)`
When I knock the resolution down to low, however, there is no error. I also had success running high and very high resolution calculations on the CPU-only version of JAX.
Update: I started a new CPU-only environment, and it threw the same error I was getting with the GPU one. I'm presently unsure how I ever got this to work.
My current setup:
Ubuntu 20.04 (Focal Fossa)
Python 3.10
JAX / jaxlib 0.4.30
CUDA 12.2
I've also implemented the fix suggested by @mattkiim in #14.
Let me know if you need any more info or need me to run any tests.
I've recently installed this library and have run into an issue when trying to run the starter code
with GPU-accelerated JAX. For all solver resolutions except low, I get an operand dimension mismatch error thrown from `_diff_coefficients`. Here's the traceback from trying to run `hj.step` on the unmodified `quickstart.ipynb` (the full traceback is reproduced above).
When I knock the resolution down to low, however, there is no error.
I also had success running high and very high resolution calculations on the CPU-only version of JAX.
Update: I started a new CPU-only environment, and it threw the same error I was getting with the GPU one. I'm presently unsure how I ever got this to work.
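My current setup: same as listed above (Ubuntu 20.04, Python 3.10, CUDA 12.2, and the JAX/jaxlib version given there).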
I've also implemented the fix suggested by @mattkiim in #14.
Let me know if you need any more info or need me to run any tests.