@oyamad wow thank you - love your work.
@albop would you be able to review and release an updated version if you're happy with this?
Thank you very much @oyamad and all others involved (@mmcky and @kp992). The edits look clean enough. I'll merge it now and make a new release.
Hi @oyamad, thank you for the PR.
There are some tests which actually need some modifications to detect the failures. See the following diff:
diff --git a/interpolation/multilinear/tests/test_multilinear.py b/interpolation/multilinear/tests/test_multilinear.py
index 1139b72..4f8766e 100644
--- a/interpolation/multilinear/tests/test_multilinear.py
+++ b/interpolation/multilinear/tests/test_multilinear.py
@@ -115,8 +115,9 @@ def test_mlinterp():
     pp = np.random.random((2000, 2))
     res0 = mlinterp((x1, x2), y, pp)
+    assert res0 is not None
     res0 = mlinterp((x1, x2), y, (0.1, 0.2))
-
+    assert res0 is not None
 def test_multilinear():
     # flat flexible api
@@ -124,6 +125,7 @@ def test_multilinear():
     for t in tests:
         tt = [typeof(e) for e in t]
         rr = interp(*t)
+        assert rr is not None
         try:
             print(f"{tt}: {rr.shape}")
Running pytest:
% pytest interpolation/multilinear
============================= test session starts ==============================
platform darwin -- Python 3.10.12, pytest-8.1.0, pluggy-1.4.0
rootdir: /Users/kpl/repos/interpolation.py
configfile: pyproject.toml
plugins: anyio-4.0.0
collected 3 items
interpolation/multilinear/tests/test_multilinear.py .FF [100%]
=================================== FAILURES ===================================
________________________________ test_mlinterp _________________________________
    def test_mlinterp():
        # simple multilinear interpolation api
        import numpy as np
        from interpolation import mlinterp
        # from interpolation.multilinear.mlinterp import mlininterp, mlininterp_vec
        x1 = np.linspace(0, 1, 10)
        x2 = np.linspace(0, 1, 20)
        y = np.random.random((10, 20))
        z1 = np.linspace(0, 1, 30)
        z2 = np.linspace(0, 1, 30)
        pp = np.random.random((2000, 2))
        res0 = mlinterp((x1, x2), y, pp)
>       assert res0 is not None
E       assert None is not None
interpolation/multilinear/tests/test_multilinear.py:118: AssertionError
_______________________________ test_multilinear _______________________________
    def test_multilinear():
        # flat flexible api
        for t in tests:
            tt = [typeof(e) for e in t]
            rr = interp(*t)
>           assert rr is not None
E           assert None is not None
interpolation/multilinear/tests/test_multilinear.py:128: AssertionError
=========================== short test summary info ============================
FAILED interpolation/multilinear/tests/test_multilinear.py::test_mlinterp - assert None is not None
FAILED interpolation/multilinear/tests/test_multilinear.py::test_multilinear - assert None is not None
========================= 2 failed, 1 passed in 0.71s ==========================
@kp992 Good catch. Those functions with @overload (in particular, the exported interp and mlinterp) have to be called in an njit-ted function:
https://github.com/EconForge/interpolation.py/blob/705cbce6c37f8605e00d503a4d7ff9516512ce78/interpolation/multilinear/mlinterp.py#L44-L49
https://github.com/EconForge/interpolation.py/blob/705cbce6c37f8605e00d503a4d7ff9516512ce78/interpolation/multilinear/mlinterp.py#L220-L225
Isn't it likely that some user code that depends on this library calls these from a non-njit-ted function? If so, this PR will break that code...
Do we have to do something like this?
def _interp(*args):
    pass
@overload(_interp)
def ol_interp(*args):
    aa = args[0].types
    it = detect_types(aa)
    if it.d == 1 and it.eval == "point":
        it = itt(it.d, it.values, "cartesian")
    source = make_mlinterp(it, "__mlinterp")
    import ast
    tree = ast.parse(source)
    code = compile(tree, "<string>", "exec")
    eval(code, globals())
    return __mlinterp
@njit
def interp(*args):
    return _interp(args)
Ah, I was too quick to merge the PR. @oyamad: what you show is exactly the pattern I used for eval_spline and some other functions.
@njit
def fun(*args):
    return fun_(*args)
class fun_():
    pass
@overload(fun_)
def o_fun(a,b):
    def body(a,b):
        return (a+b)
    return body
I just spent 10 min googling to see whether there was a better way, but to no avail. It is mentioned in this discussion https://github.com/numba/numba/issues/8897 but I don't see any better solution.
Resolve #110
All tests pass in my environment (with Python 3.12.2, Numba 0.59.0).