EconForge / interpolation.py

BSD 2-Clause "Simplified" License

Replace `generated_jit` with `overload` #112

Closed oyamad closed 4 months ago

oyamad commented 4 months ago

Resolve #110

All tests pass in my environment (Python 3.12.2, Numba 0.59.0).

mmcky commented 4 months ago

@oyamad wow thank you - love your work.

mmcky commented 4 months ago

@albop would you be able to review and release an updated version if you're happy with this?

albop commented 4 months ago

Thank you very much @oyamad and all others involved (@mmcky and @kp992). The edits look clean enough. I'll merge it now and make a new release.

kp992 commented 4 months ago

Hi @oyamad, thank you for the PR.

Some of the tests need modifications before they can actually detect the failures. See the following diff:

diff --git a/interpolation/multilinear/tests/test_multilinear.py b/interpolation/multilinear/tests/test_multilinear.py
index 1139b72..4f8766e 100644
--- a/interpolation/multilinear/tests/test_multilinear.py
+++ b/interpolation/multilinear/tests/test_multilinear.py
@@ -115,8 +115,9 @@ def test_mlinterp():
     pp = np.random.random((2000, 2))

     res0 = mlinterp((x1, x2), y, pp)
+    assert res0 is not None
     res0 = mlinterp((x1, x2), y, (0.1, 0.2))
-
+    assert res0 is not None

 def test_multilinear():
     # flat flexible api
@@ -124,6 +125,7 @@ def test_multilinear():
     for t in tests:
         tt = [typeof(e) for e in t]
         rr = interp(*t)
+        assert rr is not None

         try:
             print(f"{tt}: {rr.shape}")

Running pytest:

% pytest interpolation/multilinear
============================= test session starts ==============================
platform darwin -- Python 3.10.12, pytest-8.1.0, pluggy-1.4.0
rootdir: /Users/kpl/repos/interpolation.py
configfile: pyproject.toml
plugins: anyio-4.0.0
collected 3 items                                                              

interpolation/multilinear/tests/test_multilinear.py .FF                  [100%]

=================================== FAILURES ===================================
________________________________ test_mlinterp _________________________________

    def test_mlinterp():
        # simple multilinear interpolation api

        import numpy as np
        from interpolation import mlinterp

        # from interpolation.multilinear.mlinterp import mlininterp, mlininterp_vec
        x1 = np.linspace(0, 1, 10)
        x2 = np.linspace(0, 1, 20)
        y = np.random.random((10, 20))

        z1 = np.linspace(0, 1, 30)
        z2 = np.linspace(0, 1, 30)

        pp = np.random.random((2000, 2))

        res0 = mlinterp((x1, x2), y, pp)
>       assert res0 is not None
E       assert None is not None

interpolation/multilinear/tests/test_multilinear.py:118: AssertionError
_______________________________ test_multilinear _______________________________

    def test_multilinear():
        # flat flexible api

        for t in tests:
            tt = [typeof(e) for e in t]
            rr = interp(*t)
>           assert rr is not None
E           assert None is not None

interpolation/multilinear/tests/test_multilinear.py:128: AssertionError
=========================== short test summary info ============================
FAILED interpolation/multilinear/tests/test_multilinear.py::test_mlinterp - assert None is not None
FAILED interpolation/multilinear/tests/test_multilinear.py::test_multilinear - assert None is not None
========================= 2 failed, 1 passed in 0.71s ==========================

oyamad commented 4 months ago

@kp992 Good catch. Functions decorated with @overload (in particular, the exported interp and mlinterp) have to be called from an njit-ted function: https://github.com/EconForge/interpolation.py/blob/705cbce6c37f8605e00d503a4d7ff9516512ce78/interpolation/multilinear/mlinterp.py#L44-L49 https://github.com/EconForge/interpolation.py/blob/705cbce6c37f8605e00d503a4d7ff9516512ce78/interpolation/multilinear/mlinterp.py#L220-L225

It seems likely that some user code depending on this library calls these from a non-njit-ted function. If so, this PR will break that code...

oyamad commented 4 months ago

Do we have to do something like this?

def _interp(*args):
    pass

@overload(_interp)
def ol_interp(*args):
    aa = args[0].types

    it = detect_types(aa)
    if it.d == 1 and it.eval == "point":
        it = itt(it.d, it.values, "cartesian")
    source = make_mlinterp(it, "__mlinterp")
    import ast

    tree = ast.parse(source)
    code = compile(tree, "<string>", "exec")
    eval(code, globals())
    return __mlinterp

@njit
def interp(*args):
    return _interp(args)

albop commented 4 months ago

Ah, I was too quick to merge the PR. @oyamad: what you show is exactly the pattern I used for eval_spline and some other functions.

@njit
def fun(*args):

    return fun_(*args)

class fun_():
    pass

@overload(fun_)
def o_fun(a,b):

    def body(a,b):
        return (a+b)

    return body

I just spent 10 minutes googling to see whether there was a better way, but to no avail. It is mentioned in this discussion https://github.com/numba/numba/issues/8897 but I don't see any better solution.