I think the function `interpolate` in `grid.py` is not safe against out-of-grid-bounds errors. For instance, `domain_lo` or `domain_hi` can fall outside the range of the grid shape.
I made the following changes and they fixed the bugs, but I'm not sure this is the ideal change.
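To see where an out-of-range index comes from, here is a minimal 1-D sketch of the `floor`-based bracketing that `interpolate` performs. It uses plain numpy, and the grid (5 nodes on [0, 1], non-periodic) and the helper name `bracketing_indices` are hypothetical, chosen only for illustration.

```python
import numpy as np

# Hypothetical 1-D, non-periodic grid: 5 nodes on [0, 1], node spacing 0.25.
lo, hi, n = 0.0, 1.0, 5
spacing = (hi - lo) / (n - 1)

def bracketing_indices(state):
    """Return (index_lo, index_hi) computed as in the unpatched code (no clamping)."""
    position = (state - lo) / spacing
    index_lo = int(np.floor(position))
    return index_lo, index_lo + 1

print(bracketing_indices(0.6))   # (2, 3): both are valid node indices
print(bracketing_indices(1.0))   # (4, 5): index_hi == n, one past the last node
print(bracketing_indices(-0.1))  # (-1, 0): index_lo is negative
```

A state exactly on the upper boundary already pushes `index_hi` past the last node, and states outside the domain push either index out of range.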
```python
# Assumes the module-level imports already in grid.py: functools, numpy as np, jax.numpy as jnp.
def interpolate(self, values, state):
    """Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`."""
    # Check whether `state` lies inside the domain.
    out_of_range_status = jnp.logical_or(state < self.domain.lo, state > self.domain.hi)
    # Raise if any non-periodic dimension is out of range. This Python `if` needs a concrete
    # value, so it only works outside `jax.jit`; under tracing the bool conversion fails.
    if jnp.any(jnp.logical_and(out_of_range_status, jnp.logical_not(self._is_periodic_dim))):
        raise ValueError("state is out of the domain")
    position = (state - self.domain.lo) / jnp.array(self.spacings)
    index_lo = jnp.floor(position).astype(jnp.int32)
    # In non-periodic dimensions, a state on the upper boundary gives index_lo == shape - 1;
    # clamp it to shape - 2 so index_hi stays a valid node (weight_hi then becomes 1).
    # Periodic dimensions are left alone: index_hi == shape wraps to node 0 below.
    shape = np.array(self.shape)
    index_lo = jnp.where(self._is_periodic_dim, index_lo, jnp.minimum(index_lo, shape - 2))
    index_hi = index_lo + 1
    weight_hi = position - index_lo
    weight_lo = 1 - weight_hi
    # Wrap periodic indices; clip non-periodic ones to the last valid node (shape - 1).
    index_lo, index_hi = tuple(
        jnp.where(self._is_periodic_dim, index % shape, jnp.clip(index, 0, shape - 1))
        for index in (index_lo, index_hi))
    weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))
    # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity.
    return jnp.sum(
        weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] *
        values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim)))
```
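The clamping idea can be checked in isolation with a 1-D sketch. It assumes (as a hypothesis about the grid layout, not something confirmed here) that a non-periodic grid has nodes at both `lo` and `hi`, while a periodic grid excludes `hi`, so node `n` wraps back to node 0. Under that assumption the clamp to `n - 2` should apply only to non-periodic dimensions: for a periodic state between the last node and `hi`, clamping would turn interpolation into extrapolation. The function `interp_1d` is hypothetical and written in numpy for brevity.

```python
import numpy as np

def interp_1d(values, state, lo, hi, periodic):
    """Linearly interpolate `values` on a hypothetical 1-D grid over [lo, hi].

    Assumption: a non-periodic grid has nodes at lo and hi (spacing (hi - lo) / (n - 1));
    a periodic grid excludes hi (spacing (hi - lo) / n), so node n wraps to node 0.
    """
    n = len(values)
    if not periodic and not (lo <= state <= hi):
        raise ValueError("state is out of the domain")
    spacing = (hi - lo) / (n if periodic else n - 1)
    position = (state - lo) / spacing
    index_lo = int(np.floor(position))
    if not periodic:
        # Keep index_hi = index_lo + 1 a valid node; at state == hi, weight_hi becomes 1.
        index_lo = min(index_lo, n - 2)
    index_hi = index_lo + 1
    weight_hi = position - index_lo
    weight_lo = 1.0 - weight_hi
    if periodic:
        # Wrap both indices, so the last node blends with node 0.
        index_lo, index_hi = index_lo % n, index_hi % n
    return weight_lo * values[index_lo] + weight_hi * values[index_hi]

# Non-periodic: a state on the upper boundary returns the last node's value.
print(interp_1d(np.array([0.0, 1.0, 2.0, 3.0, 4.0]), 1.0, 0.0, 1.0, periodic=False))  # 4.0
# Periodic: halfway between the last node (3.0) and the wrapped node 0 (0.0).
print(interp_1d(np.array([0.0, 1.0, 2.0, 3.0]), 3.5, 0.0, 4.0, periodic=True))  # 1.5
```

Clamping `index_lo` to `n - 2` in the periodic case above would give weights -0.5 and 1.5 on nodes 2 and 3, returning 3.5 instead of 1.5.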