f0uriest / interpax

Interpolation and function approximation with JAX
MIT License

extrap does not take float #16

Closed: AlexGKim closed this issue 10 months ago.

AlexGKim commented 10 months ago

The API documentation for Interpolator2D says that extrap can be a float, but passing a float raises an error because the code calls len() on it.

https://github.com/f0uriest/interpax/blob/9633d303673b15ce4a1eb31252c3c2d4e1ef71d4/interpax/_spline.py#L1107

Here is code that reproduces the error:

from interpax import Interpolator2D
import jax.numpy as jnp

x = jnp.linspace(0,10, 10)
y = jnp.linspace(0,8, 8)
z = jnp.zeros((10,8))+1.
interpol = Interpolator2D(x, y, z, extrap=0.)
interpol( 4.5, 5.3)

which raises:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/jax/_src/core.py:1605, in ShapedArray._len(self, ignored_tracer)
   1604 try:
-> 1605   return self.shape[0]
   1606 except IndexError as err:

IndexError: tuple index out of range

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[13], line 8
      6 z = jnp.zeros((10,8))+1.
      7 interpol = Interpolator2D(x, y, z, extrap=0.)
----> 8 interpol( 4.5, 5.3)

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/interpax/_spline.py:237, in Interpolator2D.__call__(self, xq, yq, dx, dy)
    222 def __call__(self, xq: jax.Array, yq: jax.Array, dx: int = 0, dy: int = 0):
    223     """Evaluate the interpolated function or its derivatives.
    224 
    225     Parameters
   (...)
    235         Interpolated values.
    236     """
--> 237     return interp2d(
    238         xq,
    239         yq,
    240         self.x,
    241         self.y,
    242         self.f,
    243         self.method,
    244         (dx, dy),
    245         self.extrap,
    246         self.period,
    247         **self.derivs,
    248     )

    [... skipping hidden 12 frame]

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/interpax/_spline.py:648, in interp2d(xq, yq, x, y, f, method, derivative, extrap, period, **kwargs)
    646 periodx, periody = _parse_ndarg(period, 2)
    647 derivative_x, derivative_y = _parse_ndarg(derivative, 2)
--> 648 lowx, highx, lowy, highy = _parse_extrap(extrap, 2)
    650 if periodx is not None:
    651     xq, x, f, fx, fy, fxy = _make_periodic(xq, x, periodx, 0, f, fx, fy, fxy)

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/interpax/_spline.py:1107, in _parse_extrap(extrap, n)
   1105 if isbool(extrap):  # same for lower,upper in all dimensions
   1106     return tuple(extrap for _ in range(2 * n))
-> 1107 elif len(extrap) == 2 and jnp.isscalar(extrap[0]):  # same l,h for all dimensions
   1108     return tuple(e for _ in range(n) for e in extrap)
   1109 elif len(extrap) == n and all(len(extrap[i]) == 2 for i in range(n)):

    [... skipping hidden 1 frame]

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/jax/_src/core.py:1607, in ShapedArray._len(self, ignored_tracer)
   1605   return self.shape[0]
   1606 except IndexError as err:
-> 1607   raise TypeError("len() of unsized object") from err

TypeError: len() of unsized object
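The failing branch in _parse_extrap only special-cases bools (isbool), so a float falls through to len(extrap). Below is a minimal sketch of parsing logic that checks for any scalar before calling len(). The function name parse_extrap is hypothetical. It is modeled on the three branches visible in the traceback and is not the actual fix from #18. It also assumes extrap is still a plain Python value when parsed. The ShapedArray frames suggest it may already be a 0-d traced JAX array at that point, in which case a shape-based check (for example jnp.ndim(extrap) == 0) would likely be needed instead of isinstance.

```python
import numbers


def parse_extrap(extrap, n):
    """Expand extrap into a flat tuple (low_1, high_1, ..., low_n, high_n).

    Hypothetical helper modeled on the branches shown in the traceback;
    not interpax's actual implementation.
    """
    # A single bool or number (bool is a numbers.Number) applies to both
    # bounds in every dimension. Checking this first means len() is never
    # called on a scalar, which is what triggers the reported error.
    if isinstance(extrap, numbers.Number):
        return tuple(extrap for _ in range(2 * n))
    # One (low, high) pair shared by all dimensions.
    if len(extrap) == 2 and isinstance(extrap[0], numbers.Number):
        return tuple(e for _ in range(n) for e in extrap)
    # One (low, high) pair per dimension.
    if len(extrap) == n and all(len(pair) == 2 for pair in extrap):
        return tuple(e for pair in extrap for e in pair)
    raise ValueError(f"Unrecognized extrap={extrap!r} for {n} dimensions")
```

With this ordering, parse_extrap(0., 2) expands to four copies of 0., matching the lowx, highx, lowy, highy unpacking in interp2d.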
f0uriest commented 10 months ago

Thanks, this should be fixed by #18 and will be in the next release.