StanfordASL / hj_reachability

Hamilton-Jacobi reachability analysis in JAX.
MIT License

getting `nan` in solution for a double integrator #4

Status: Closed (dev10110 closed this 9 months ago)

dev10110 commented 10 months ago

Hi,

I'm trying to use this library to compute the backward reachable tube of an obstacle. I was able to run `examples/air3d.ipynb` successfully, but I'm having trouble defining my own system: the values keep turning into `nan`.

The dynamics are $\ddot x = u$ and the objective is to avoid a wall at $x_1 = 1.5$.
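Written in first-order form with state $(x_1, x_2) = (x, \dot x)$, the `DoubleIntegrator` class in the code is control- and disturbance-affine. In the equation below, $f$, $g$, and $h$ are only labels for what `open_loop_dynamics`, `control_jacobian`, and `disturbance_jacobian` return. $\bar d$ stands for `max_disturbance`, which is 0 by default:

$$
\begin{bmatrix} \dot x_1 \\ \dot x_2 \end{bmatrix}
= \underbrace{\begin{bmatrix} x_2 \\ 0 \end{bmatrix}}_{f(x)}
+ \underbrace{\begin{bmatrix} 0 \\ 1 \end{bmatrix}}_{g(x)} u
+ \underbrace{\begin{bmatrix} 0 \\ 1 \end{bmatrix}}_{h(x)} d,
\qquad |u| \le 2, \quad |d| \le \bar d .
$$

The initial values $V_0(x) = 1.5 - x_1$ are non-positive exactly on the region $x_1 \ge 1.5$, at or beyond the wall. They do not depend on $x_2$ at all.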

Here is my code:

```python
import jax
import jax.numpy as jnp
import numpy as np

from IPython.display import HTML
import matplotlib.animation as anim
import matplotlib.pyplot as plt
import plotly.graph_objects as go

import hj_reachability as hj

print(jax.default_backend())
# 'gpu' (I'm running this on an x86 machine with CUDA)

class DoubleIntegrator(hj.ControlAndDisturbanceAffineDynamics):

    def __init__(self,
                 max_acceleration=2.0,
                 max_disturbance=0.0,
                 control_mode="min",
                 disturbance_mode="max",
                 control_space=None,
                 disturbance_space=None):

        # Control: scalar acceleration u in [-max_acceleration, max_acceleration].
        if control_space is None:
            control_space = hj.sets.Box(jnp.array([-max_acceleration]),
                                        jnp.array([max_acceleration]))

        # Disturbance: scalar acceleration d with |d| <= max_disturbance.
        if disturbance_space is None:
            disturbance_space = hj.sets.Ball(jnp.zeros(1), max_disturbance)

        super().__init__(control_mode, disturbance_mode, control_space, disturbance_space)

    def open_loop_dynamics(self, state, time):
        # state = [x, v]; without inputs, x' = v and v' = 0.
        v = state[1]
        return jnp.array([v, 0.])

    def control_jacobian(self, state, time):
        # The control enters only the velocity equation: v' += u.
        return jnp.array([
            [0.],
            [1.],
        ])

    def disturbance_jacobian(self, state, time):
        # The disturbance enters the same channel as the control: v' += d.
        return jnp.array([
            [0.],
            [1.],
        ])

dynamics = DoubleIntegrator()

grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    hj.sets.Box(lo=np.array([-2., -2.]), hi=np.array([2., 2.])),
    (40, 40),
    periodic_dims=[])

# The objective is to avoid a wall at x = 1.5 meters.
values = 1.5 - grid.states[:, :, 0]

grid.states.shape  # == (40, 40, 2)
values.shape       # == (40, 40)

solver_settings = hj.SolverSettings.with_accuracy(
    "very_high",
    hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)
time = 0.
target_time = -1.0
target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)
```

Any advice on how to debug this, or on where I've gone wrong?

schmrlng commented 9 months ago

Sorry for the delayed response; it seems I haven't been keeping up with my GitHub notifications. In case you're still working on this (or on JAX `nan` problems in general), there's the `jax_debug_nans` configuration option, which can help you find the root cause. The `jax.lax.while_loop` in `hj.step` obscures things a bit, but it's basically just wrapping

```python
hj.time_integration.euler_step(solver_settings, dynamics, grid, time, values, target_time)
```

which, in this case (with that debug option on, running on CPU on my machine), yields a stack trace like:

```
...
File ~/Dropbox/code/hj_reachability/hj_reachability/dynamics.py:43, in hamiltonian()
     42 del value  # unused
---> 43 control, disturbance = self.optimal_control_and_disturbance(state, time, grad_value)
     44 return grad_value @ self(state, control, disturbance, time)

File ~/Dropbox/code/hj_reachability/hj_reachability/dynamics.py:80, in optimal_control_and_disturbance()
     78     disturbance_direction = -disturbance_direction
     79 return (self.control_space.extreme_point(control_direction),
---> 80         self.disturbance_space.extreme_point(disturbance_direction))

File ~/Dropbox/code/hj_reachability/hj_reachability/sets.py:66, in extreme_point()
     65 """Computes the point `x` in the set such that the dot product `x @ direction` is greatest."""
---> 66 return self.center + self.radius * direction / jnp.linalg.norm(direction)
...
FloatingPointError: invalid value (nan) encountered in jit(div)
```
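Here is a minimal pure-JAX sketch of the same failure, assuming only `jax` is installed. `naive_ball_extreme_point` is a hypothetical standalone copy of the `sets.py:66` line from the trace. The sketch also assumes, as the traced `optimal_control_and_disturbance` suggests, that the disturbance direction is the value gradient times `disturbance_jacobian`. Because the initial values `1.5 - x` do not depend on the velocity, that direction is exactly zero. A zero radius (`max_disturbance=0.0`) does not help, since the division has already produced `nan` by then.

```python
import jax
import jax.numpy as jnp

# Make JAX raise FloatingPointError as soon as an operation produces nan.
jax.config.update("jax_debug_nans", True)


def initial_values(state):
    # Same initial values as in the issue: non-positive for x >= 1.5.
    return 1.5 - state[0]


# The disturbance enters the velocity equation, as in DoubleIntegrator.
disturbance_jacobian = jnp.array([[0.], [1.]])


def naive_ball_extreme_point(center, radius, direction):
    # Hypothetical standalone copy of the normalization shown in the trace.
    return center + radius * direction / jnp.linalg.norm(direction)


state = jnp.array([0.5, 1.0])
grad_value = jax.grad(initial_values)(state)                # [-1., 0.]
disturbance_direction = grad_value @ disturbance_jacobian  # exactly zero

try:
    naive_ball_extreme_point(jnp.zeros(1), 0.0, disturbance_direction)
    raised = False
except FloatingPointError:
    # norm([0.]) == 0, so the division is 0 / 0 = nan.
    raised = True

print(grad_value, disturbance_direction, raised)
```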

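So it's basically a bug caused by not-very-careful normalization in the implementation of `hj.sets.Ball`. When `direction` is the zero vector, `jnp.linalg.norm(direction)` is 0, and the division computes 0/0 = `nan`. It's fixed in #8.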
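If you're stuck on a version without the fix, here is a minimal sketch of one way to guard the normalization. `safe_ball_extreme_point` and its `eps` threshold are hypothetical, and this is not necessarily what #8 does. The idea is that when the direction is zero, every point of the ball maximizes `x @ direction`, so returning the center is a valid answer. The inner `jnp.where` makes the division use 1 instead of 0 in that case, so no `nan` is ever created.

```python
import jax.numpy as jnp


def safe_ball_extreme_point(center, radius, direction, eps=1e-12):
    """Point of the ball {x : |x - center| <= radius} maximizing x @ direction.

    Hypothetical guarded variant of the normalization in hj.sets.Ball.
    """
    norm = jnp.linalg.norm(direction)
    nonzero = norm > eps
    # Divide by 1 instead of 0 when the direction vanishes, so no nan is created.
    safe_norm = jnp.where(nonzero, norm, 1.0)
    unit = jnp.where(nonzero, direction / safe_norm, jnp.zeros_like(direction))
    # For a (near-)zero direction every point of the ball is optimal; use the center.
    return center + radius * unit


print(safe_ball_extreme_point(jnp.zeros(1), 0.0, jnp.zeros(1)))
print(safe_ball_extreme_point(jnp.array([1., 2.]), 3.0, jnp.array([0., 2.])))
```

The first call is the degenerate case from this issue and returns the center instead of `nan`. The second returns `center + radius * direction / |direction|`, the usual support point.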

dev10110 commented 9 months ago

Oh wow, that's awesome! Thanks for providing the fix too :)