StanfordASL / hj_reachability

Hamilton-Jacobi reachability analysis in JAX.
MIT License

`interpolate` method from the `Grid` class raises `IndexError` #5

Closed: bschnitzler closed this issue 9 months ago

bschnitzler commented 10 months ago

When interpolating an array over a `Grid` with the `interpolate` method, an `IndexError` is raised if the interpolation point lies at or beyond the grid's upper bounds.

Expected behavior

Extrapolate the data array when the interpolation point lies outside the grid boundaries.

Actual behavior

Raises an IndexError

Steps to reproduce the behavior

The following example code reproduces the issue:

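import hj_reachability as hj
import numpy as np

# Lower and upper corners of the domain: the unit square [0, 1] x [0, 1].
bl = np.array((0., 0.))
tr = np.array((1., 1.))
# 10 x 10 grid over the box, with NumPy (not JAX) grid values.
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(bl, tr), (10, 10))
values = np.random.random((10, 10))

# Interpolating at the upper corner of the domain raises IndexError.
grid.interpolate(values, np.array((1., 1.)))
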
schmrlng commented 9 months ago

Sorry for the delayed response, and thanks for finding this! For grid values of type `jax.Array` rather than `np.ndarray` (the typical use case, at least for me), it seems that JAX's out-of-bounds indexing behavior has been masking the problem and making the buggy `jnp.clip` superfluous. Upon reflection, extrapolation probably shouldn't be allowed (or at least should be explicitly controlled by the user), so in addition to fixing the `jnp.clip` bounds I've changed the extrapolation behavior to return NaNs in #9.
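
For anyone following along, here is a minimal 1D sketch of the general failure mode and of the "clamp the index, return NaN outside the bounds" behavior described above. It is plain Python and is not the code from hj_reachability or from #9; the function names `naive_interpolate_1d` and `interpolate_1d` are hypothetical, and the assumption is a uniform grid with linear interpolation. Python lists raise `IndexError` on out-of-bounds access in the same way NumPy arrays do, so they stand in for the `np.ndarray` case.

```python
import math


def naive_interpolate_1d(values, lo, hi, x):
    """Hypothetical naive version: no clamping of the lower index."""
    n = len(values)
    position = (x - lo) / (hi - lo) * (n - 1)  # fractional index of x
    lower = int(math.floor(position))
    weight = position - lower
    # At x == hi, lower == n - 1, so values[lower + 1] is out of bounds.
    return (1.0 - weight) * values[lower] + weight * values[lower + 1]


def interpolate_1d(values, lo, hi, x):
    """Hypothetical fixed version: NaN outside [lo, hi], clamped inside."""
    n = len(values)
    if not (lo <= x <= hi):
        return math.nan  # no silent extrapolation
    position = (x - lo) / (hi - lo) * (n - 1)
    # Clamp to n - 2 so that lower + 1 is always a valid index,
    # even when x sits exactly on the upper bound.
    lower = min(int(math.floor(position)), n - 2)
    weight = position - lower
    return (1.0 - weight) * values[lower] + weight * values[lower + 1]


values = [float(i) ** 2 for i in range(10)]  # 10 samples on [0, 1]

try:
    naive_interpolate_1d(values, 0.0, 1.0, 1.0)
    naive_raised = False
except IndexError:
    naive_raised = True

at_upper_bound = interpolate_1d(values, 0.0, 1.0, 1.0)  # equals values[-1]
outside = interpolate_1d(values, 0.0, 1.0, 1.5)         # NaN
```

Clamping the lower index to `n - 2` (rather than `n - 1`) is the detail that keeps the upper neighbor in range; returning NaN for points outside the domain makes out-of-bounds queries visible instead of silently extrapolating.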