adam-coogan / jaxinterp2d

Bilinear interpolation on grids with jax
MIT License
12 stars 3 forks source link

Crash if x and y do not have the same number of elements #3

Open jecampagne opened 2 years ago

jecampagne commented 2 years ago

Hi, I get a crash if "x" and "y" have different sizes. For instance:

xp=jnp.linspace(-1.,1.,20)  # N
yp=jnp.linspace(-1.,1.,10)  # M

X, Y = np.meshgrid(xp, yp)

def func(x,y):
    return jnp.exp(-0.5*(x**2+y**2))

Z = func(X,Y)
plt.contourf(xp,yp,Z)  #### ici fp MxN

fp = Z.T    # N x M

x_star = jnp.array([-0.5,0.1])
y_star = jnp.array([0.7,-0.7,0.0])

ut.interp2d(x_star, y_star,xp,yp,fp)     # ut is just a link to your code

Then I get:

--------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

File python3.8/site-packages/jax/_src/util.py:219, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    218 else:
--> 219   return cached(config._trace_context(), *args, **kwargs)

File /python3.8/site-packages/jax/_src/util.py:212, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    210 @functools.lru_cache(max_size)
    211 def cached(_, *args, **kwargs):
--> 212   return f(*args, **kwargs)

File /python3.8/site-packages/jax/_src/lax/lax.py:126, in _broadcast_shapes_cached(*shapes)
    124 @cache()
    125 def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
--> 126   return _broadcast_shapes_uncached(*shapes)

File /python3.8/site-packages/jax/_src/lax/lax.py:142, in _broadcast_shapes_uncached(*shapes)
    141 if result_shape is None:
--> 142   raise ValueError("Incompatible shapes for broadcasting: {}"
    143                    .format(tuple(shape_list)))
    144 return result_shape

ValueError: Incompatible shapes for broadcasting: ((2,), (3,))

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)

     55 # Using Wikipedia's notation (https://en.wikipedia.org/wiki/Bilinear_interpolation)
---> 56 z_11 = zp[ix - 1, iy - 1]
     57 z_21 = zp[ix, iy - 1]
     58 z_12 = zp[ix - 1, iy]

File /jax/_src/numpy/lax_numpy.py:3544, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   3541   return _getslice(arr, start, stop)
   3543 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 3544 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   3545                unique_indices, mode, fill_value)

File/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3553, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   3550 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   3551             unique_indices, mode, fill_value):
   3552   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 3553   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   3554   y = arr
   3556   if fill_value is not None:

File /python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3723, in _index_to_gather(x_shape, idx, normalize_indices)
   3716 for idx_pos, i in enumerate(idx):
   3717   # Handle the advanced indices here if:
   3718   # * the advanced indices were not contiguous and we are the start.
   3719   # * we are at the position of the first advanced index.
   3720   if (advanced_indexes is not None and
   3721       (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
   3722        not advanced_axes_are_contiguous and idx_pos == 0)):
-> 3723     advanced_indexes = broadcast_arrays(*advanced_indexes)
   3724     shape = advanced_indexes[0].shape
   3725     ndim = len(shape)

    [... skipping hidden 14 frame]

File /python3.8/site-packages/jax/_src/numpy/util.py:338, in _broadcast_arrays(*args)
    334 if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
    335   # TODO(mattjj): remove the array(arg) here
    336   return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
    337           for arg in args]
--> 338 result_shape = lax.broadcast_shapes(*shapes)
    339 return [_broadcast_to(arg, result_shape) for arg in args]

    [... skipping hidden 1 frame]

File/python3.8/site-packages/jax/_src/lax/lax.py:142, in _broadcast_shapes_uncached(*shapes)
    140 result_shape = _try_broadcast_shapes(shape_list)
    141 if result_shape is None:
--> 142   raise ValueError("Incompatible shapes for broadcasting: {}"
    143                    .format(tuple(shape_list)))
    144 return result_shape

ValueError: Incompatible shapes for broadcasting: ((2,), (3,))
jecampagne commented 2 years ago

Well, it might be that passing x and y of different sizes is not well defined. Does $\mathrm{fp}_{i,j} = f(x_i, y_j)$ mean running over all $i$ for each fixed $j$, giving $n_i \times n_j$ values? Or, when $n_i = n_j$, does it mean $\mathrm{fp}_{i} = f(x_i, y_i)$ with a single index $i$, giving only $n_i$ values?

Well... both use cases can occur.

YouJiacheng commented 2 years ago

@jecampagne `xp` and `yp` can have different sizes. `x_star` and `y_star` must have the same size, since the coordinates of point i are `(x_star[i], y_star[i])`! If you want mesh semantics, you can use `meshgrid` and flatten the output:

x_star = jnp.array([-0.5,0.1])
y_star = jnp.array([0.7,-0.7,0.0])
x_star, y_star = np.meshgrid(x_star, y_star)
x_star = x_star.reshape((-1,))
y_star = y_star.reshape((-1,))
jecampagne commented 2 years ago

Hum, sounds good :)