StanfordASL / hj_reachability

Hamilton-Jacobi reachability analysis in JAX.
MIT License

Grid bound handling in interpolate #13

Open ChoiJangho opened 4 months ago


Hi Ed,

I think the `interpolate` function in `grid.py` is not safe against out-of-bounds grid indices. For instance, a state on `domain.hi` (or outside `[domain.lo, domain.hi]`) can map to an index outside the grid's shape range. I made the following changes and they fixed the bugs I was seeing, but I'm not sure this is the ideal change.
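A minimal pure-Python sketch of the index arithmetic along one non-periodic axis shows the boundary case. It assumes, without confirmation from the library, that a non-periodic axis with `n` nodes includes both endpoints, so `spacing = (hi - lo) / (n - 1)`. `bracketing_indices` is a hypothetical helper, and its clip to `[0, n]` models a clip whose upper bound is the axis length rather than the last valid index `n - 1`.

```python
import math


def bracketing_indices(state, lo, hi, n):
    """Hypothetical helper: the two node indices bracketing `state` on one non-periodic axis."""
    spacing = (hi - lo) / (n - 1)  # assumption: n nodes, including both endpoints lo and hi
    position = (state - lo) / spacing
    index_lo = math.floor(position)
    index_hi = index_lo + 1

    def clip(i):
        # Upper bound is n (the axis length), not n - 1 (the last valid index).
        return min(max(i, 0), n)

    return clip(index_lo), clip(index_hi)


# Five nodes at 0, 1, 2, 3, 4: valid indices are 0..4.
print(bracketing_indices(2.5, 0.0, 4.0, 5))  # (2, 3): interior state, both indices valid
print(bracketing_indices(4.0, 0.0, 4.0, 5))  # (4, 5): state on `hi`, index 5 is out of range
```

With NumPy, indexing an axis of length 5 at index 5 raises `IndexError`; JAX instead clamps out-of-bounds indices in gather-style indexing by default, so the same off-by-one can go unnoticed.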

    # Method of `Grid` in grid.py; relies on that module's imports (functools, jax.numpy as jnp, numpy as np).
    def interpolate(self, values, state):
        """Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`."""
        # Check whether `state` lies inside the domain along every non-periodic dimension.
        # Note: this Python `if` needs concrete values; it fails under `jax.jit` when `state` is traced.
        out_of_range_status = jnp.logical_or(state < self.domain.lo, state > self.domain.hi)
        if jnp.any(jnp.logical_and(out_of_range_status, jnp.logical_not(self._is_periodic_dim))):
            raise ValueError("state is out of the domain")

        position = (state - self.domain.lo) / jnp.array(self.spacings)
        index_lo = jnp.floor(position).astype(jnp.int32)
        # On non-periodic dimensions, a state on the upper boundary gives index_lo == shape - 1;
        # move it back to shape - 2 so that index_hi == shape - 1 stays in range (weight_hi becomes 1).
        # Periodic dimensions are left alone: their last cell wraps around to index 0 below.
        index_lo = jnp.where(
            jnp.logical_and(jnp.logical_not(self._is_periodic_dim), index_lo >= np.array(self.shape) - 1),
            np.array(self.shape) - 2, index_lo)
        index_hi = index_lo + 1

        weight_hi = position - index_lo
        weight_lo = 1 - weight_hi
        # Wrap indices on periodic dimensions; clip to the last valid index (shape - 1) otherwise.
        index_lo, index_hi = tuple(
            jnp.where(self._is_periodic_dim, index % np.array(self.shape), jnp.clip(index, 0, np.array(self.shape) - 1))
            for index in (index_lo, index_hi))
        # Outer product of the per-dimension weights: one weight per corner of the enclosing cell.
        weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))
        # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity.
        return jnp.sum(
            weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] *
            values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim)))
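For experimenting outside the `Grid` class, here is a standalone NumPy sketch of the same multilinear interpolation with the bound handling discussed above, except that the last-cell clamp is applied only to non-periodic axes: on a periodic axis, the cell between the last node and node 0 is a real cell that should wrap. The function name and its explicit `lo`/`spacings`/`is_periodic` arguments are hypothetical stand-ins for the `Grid` attributes, values are assumed to be scalar per node, and NumPy replaces `jax.numpy` only to keep the example dependency-light.

```python
import functools

import numpy as np


def interpolate(values, state, lo, spacings, is_periodic):
    """Hypothetical standalone multilinear interpolation on a regular grid.

    values: array with one scalar per node, shape == grid shape.
    state, lo, spacings: length-ndim arrays (query point, first node, node spacing).
    is_periodic: length-ndim booleans; periodic axes wrap, others must contain `state`.
    """
    shape = np.array(values.shape)
    is_periodic = np.asarray(is_periodic, dtype=bool)
    position = (np.asarray(state, dtype=float) - lo) / spacings

    # Reject states outside the node range on non-periodic axes.
    outside = (position < 0) | (position > shape - 1)
    if np.any(outside & ~is_periodic):
        raise ValueError("state is out of the domain")

    index_lo = np.floor(position).astype(np.int64)
    # Non-periodic axes only: a state on the upper boundary uses the last cell
    # (index_lo = shape - 2, weight_hi = 1) instead of a nonexistent cell past it.
    index_lo = np.where(~is_periodic & (index_lo >= shape - 1), shape - 2, index_lo)
    index_hi = index_lo + 1
    weight_hi = position - index_lo
    weight_lo = 1 - weight_hi

    # Wrap periodic axes; clip the rest to the last valid index, shape - 1.
    index_lo, index_hi = (
        np.where(is_periodic, index % shape, np.clip(index, 0, shape - 1))
        for index in (index_lo, index_hi))

    # Outer product of per-axis weights -> one weight per corner of the cell.
    weight = functools.reduce(np.multiply, np.ix_(*np.stack([weight_lo, weight_hi], -1)))
    corners = values[np.ix_(*np.stack([index_lo, index_hi], -1))]
    return float(np.sum(weight * corners))


# Bilinear check on f(x, y) = x + 2y, sampled at x = 0..4, y = 0..2 (both non-periodic).
xs, ys = np.meshgrid(np.arange(5.0), np.arange(3.0), indexing="ij")
f = xs + 2 * ys
print(interpolate(f, [2.5, 0.5], np.zeros(2), np.ones(2), [False, False]))  # 3.5
print(interpolate(f, [4.0, 2.0], np.zeros(2), np.ones(2), [False, False]))  # 8.0 (upper corner)

# Periodic axis with 4 nodes and period 4: state 3.5 lies between node 3 and node 0.
print(interpolate(np.array([0.0, 10.0, 20.0, 30.0]), [3.5], np.zeros(1), np.ones(1), [True]))  # 15.0
```

If the clamp were applied to every axis, the periodic query would use `index_lo = 2` with `weight_hi = 1.5` and extrapolate to 35 instead of interpolating between nodes 3 and 0 to get 15.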