EconForge / interpolation.py

BSD 2-Clause "Simplified" License

`mlinterp` problems with global variables + numba #34

Closed natashawatkins closed 5 years ago

natashawatkins commented 5 years ago

Seems that mlinterp has problems when being jitted and passed global variables...

[screenshot: code snippet and traceback, 2018-10-09 2:39 pm]

Interestingly, moving grid = (x1, x2) inside the function works (x1 and x2 being global variables)

Also had an issue with trying to jit a function inside a function with a class argument...

[screenshot: code snippet and traceback, 2018-10-09 3:06 pm]
albop commented 5 years ago

Thank you! This seems to be a numba issue. I have made a smaller example here (btw, it's easier to replicate if you copy-paste code instead of an image):

import numpy as np
from numba import jit

x1 = np.linspace(0, 1, 100)
x2 = np.linspace(0, 1, 100)
grid = (x1, x2)

# this works:
@jit(nopython=True)
def dosomething(x):
    return x + x1.sum()

print(dosomething(0.1))   # returns: 50.1
x1[0] += 10
print(dosomething(0.1))   # still returns: 50.1
                          # (a frozen copy of x1 is made in the njitted space)

# this doesn't work: it looks like there is an issue with freezing a
# global tuple of arrays
@jit(nopython=True)
def dosomething(x):
    return x + grid[0].sum()

dosomething(0.2)  # fails with a horrible error message

# it works with basic (scalar) tuples
basic_tuple = (0.1, 0.2)

@jit(nopython=True)
def dosomething(x):
    return x + basic_tuple[1]

dosomething(0.3)

# it doesn't work with lists of arrays either
basic_list = [x1, x2]

@jit(nopython=True)
def dosomething(x):
    return x + basic_list[0].sum()

dosomething(0.4)  # also fails
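A minimal sketch of one workaround (the same idea suggested further down the thread): pass the tuple of arrays as an explicit argument instead of capturing it as a global, so numba types it per call rather than trying to freeze it. The try/except fallback is only an assumption so the snippet runs even where numba is not installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # fallback: identity decorator so the example still runs without numba
    def njit(func=None, **kwargs):
        if func is None:
            return lambda f: f
        return func

x1 = np.linspace(0, 1, 100)
x2 = np.linspace(0, 1, 100)

@njit
def dosomething(grid, x):
    # grid is a tuple of arrays received as an argument, not a global,
    # so there is no global tuple for numba to freeze
    return x + grid[0].sum()

print(dosomething((x1, x2), 0.1))  # 50.1
```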
natashawatkins commented 5 years ago

Another issue I see is

class Test:

    def __init__(self,
                 x1=np.linspace(0, 1, 500)**2,
                 x2=np.linspace(0, 1, 500)):

        self.x1, self.x2 = x1, x2

tc = Test()

@njit
def interp(tc, x, y):
    x1, x2 = tc.x1, tc.x2
    grid = (x1, x2)
    return mlinterp(grid, Z, (x, y))

interp(tc, 1, 1)

Returns

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-40-944aaada9307> in <module>()
     15     return mlinterp(grid, Z, (x, y))
     16 
---> 17 interp(tc, 1, 1)

/anaconda3/lib/python3.6/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    347                 e.patch_message(msg)
    348 
--> 349             error_rewrite(e, 'typing')
    350         except errors.UnsupportedError as e:
    351             # Something unsupported is present in the user code, add help info

/anaconda3/lib/python3.6/site-packages/numba/dispatcher.py in error_rewrite(e, issue_type)
    314                 raise e
    315             else:
--> 316                 reraise(type(e), e, None)
    317 
    318         argtypes = []

/anaconda3/lib/python3.6/site-packages/numba/six.py in reraise(tp, value, tb)
    656             value = tp()
    657         if value.__traceback__ is not tb:
--> 658             raise value.with_traceback(tb)
    659         raise value
    660 

TypingError: Failed at nopython (nopython frontend)
Internal error at <numba.typeinfer.ArgConstraint object at 0x609e450f0>:
--%<----------------------------------------------------------------------------
Traceback (most recent call last):
  File "/anaconda3/lib/python3.6/site-packages/numba/errors.py", line 577, in new_error_context
    yield
  File "/anaconda3/lib/python3.6/site-packages/numba/typeinfer.py", line 199, in __call__
    assert ty.is_precise()
AssertionError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/anaconda3/lib/python3.6/site-packages/numba/typeinfer.py", line 142, in propagate
    constraint(typeinfer)
  File "/anaconda3/lib/python3.6/site-packages/numba/typeinfer.py", line 200, in __call__
    typeinfer.add_type(self.dst, ty, loc=self.loc)
  File "/anaconda3/lib/python3.6/contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "/anaconda3/lib/python3.6/site-packages/numba/errors.py", line 585, in new_error_context
    six.reraise(type(newerr), newerr, tb)
  File "/anaconda3/lib/python3.6/site-packages/numba/six.py", line 659, in reraise
    raise value
numba.errors.InternalError: 
[1] During: typing of argument at <ipython-input-40-944aaada9307> (13)
--%<----------------------------------------------------------------------------

File "<ipython-input-40-944aaada9307>", line 13:
def interp(tc, x, y):
    x1, x2 = tc.x1, tc.x2
    ^

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.Test'>

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new

It's confusing because we have some code that does 2d interpolation and works fine with the class, but it doesn't use tuples. Maybe this is a result of the more generalised API.

albop commented 5 years ago

As for the initial bug I opened an issue here: https://github.com/numba/numba/issues/3395

albop commented 5 years ago

As for the other issue: the problem is that you cannot pass regular Python objects into nopython mode. Here, if you define tc = (x1, x2) and pass that tuple to the function as you do, it should work. Or you can define a jitclass as in the following:


import numpy as np
from interpolation import mlinterp
from numba import jitclass, njit
from numba.types import float64

@jitclass({'x1': float64[:], 'x2': float64[:]})
class Test:

    def __init__(self, x1, x2):

        self.x1, self.x2 = x1, x2

x1 = np.linspace(0,1,100)
x2 = np.linspace(0,1,100)**2
tc = Test(x1,x2)

Z = np.random.random((100,100))

@njit
def interp(tc, x, y):
    x1, x2 = tc.x1, tc.x2
    grid = (x1, x2)
    return mlinterp(grid, Z, (x, y))

interp(tc, 0.2, 0.8)

But there isn't much benefit compared to passing the tuple.

natashawatkins commented 5 years ago

Ah yep, sorry, I meant to put it in this form, which works:

class Test:

    def __init__(self,
                 x1=np.linspace(0, 1, 500)**2,
                 x2=np.linspace(0, 1, 500)):

        self.x1, self.x2 = x1, x2

def return_interp(tc):

    x1, x2 = tc.x1, tc.x2
    X, Y = np.meshgrid(x1, x2)
    Z = y_func(X, Y)

    @njit
    def interp(x, y):
        grid = (x1, x2)
        return mlinterp(grid, Z, (x, y))

    return interp

tc = Test()
interp_func = return_interp(tc)
interp_func(1, 1)
natashawatkins commented 5 years ago

On upgrading to numba 0.40, I've been able to parallelise a version of 2d linear interpolation that @jstac wrote, but I can't use mlinterp, I think due to the tuple being a Python object

albop commented 5 years ago

I have no idea what you mean here.

natashawatkins commented 5 years ago

I tried to use prange and parallelise a loop in my code and received this error:

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
scalar type tuple(readonly array(float64, 1d, C) x 2) given for non scalar argument #2

File "<ipython-input-25-323f926fcd3a>", line 34:
    def T(v):
        <source elided>

        for i in prange(len(w_grid)):
        ^

[1] During: lowering "id=39[LoopNest(index_variable = parfor_index.7692, range = (0, $34.3_size0.7675, 1))]{48: <ir.Block at <ipython-input-25-323f926fcd3a> (34)>}Var(parfor_index.7692, <ipython-input-25-323f926fcd3a> (34))" at <ipython-input-25-323f926fcd3a> (34)

I was able to parallelise the same code by replacing mlinterp with a jitted version of scipy's 2d linear interpolation function that @jstac provided me
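For context, a hand-written bilinear interpolation on a regular grid can be sketched as below. This is a generic illustration, not the actual code @jstac provided, and it assumes uniformly spaced, increasing grids:

```python
import numpy as np

def lininterp_2d(x_grid, y_grid, values, point):
    """Bilinear interpolation of values on a regular (x_grid, y_grid) grid."""
    x, y = point
    nx, ny = len(x_grid), len(y_grid)

    # locate the cell containing (x, y), clamping to the grid interior
    i = min(max(int((x - x_grid[0]) / (x_grid[1] - x_grid[0])), 0), nx - 2)
    j = min(max(int((y - y_grid[0]) / (y_grid[1] - y_grid[0])), 0), ny - 2)

    # local coordinates inside the cell, in [0, 1]
    tx = (x - x_grid[i]) / (x_grid[i + 1] - x_grid[i])
    ty = (y - y_grid[j]) / (y_grid[j + 1] - y_grid[j])

    # weighted average of the four corner values
    return ((1 - tx) * (1 - ty) * values[i, j]
            + tx * (1 - ty) * values[i + 1, j]
            + (1 - tx) * ty * values[i, j + 1]
            + tx * ty * values[i + 1, j + 1])

x_grid = np.linspace(0.0, 1.0, 50)
y_grid = np.linspace(0.0, 1.0, 50)
Z = x_grid[:, None] + y_grid[None, :]   # f(x, y) = x + y is linear, so exact
print(lininterp_2d(x_grid, y_grid, Z, (0.3, 0.7)))  # 1.0
```

Because the function body is plain loops and arithmetic over array arguments, it can be decorated with @njit and called inside a prange loop without any global tuple being involved.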

albop commented 5 years ago

Ah OK. Is the tuple grid object defined in the function outside the prange loop? Or outside the function?

natashawatkins commented 5 years ago

This is the function:

def T_factory(sp):

    f, g = sp.f, sp.g
    w_f, w_g = sp.w_f, sp.w_g
    β, c = sp.β, sp.c
    mc_size = sp.mc_size
    w_grid = sp.w_grid
    π_grid = sp.π_grid

    @njit
    def κ(w, π):
        """
        Updates π using Bayes' rule and the current wage observation w.
        """
        pf, pg = π * f(w), (1 - π) * g(w)
        π_new = pf / (pf + pg)

        return π_new

    @njit(parallel=True)
    def T(v):
        """
        The Bellman operator.

        """
        v_new = np.empty_like(v)
        grid = (w_grid, π_grid)
        v_func = lambda x, y: mlinterp(grid, v, (x, y))
#         v_func = lambda x, y: lininterp_2d(w_grid, π_grid, v, (x, y))

        for i in prange(len(w_grid)):
            for j in prange(len(π_grid)):
                w = w_grid[i]
                π = π_grid[j]

                v_1 = w / (1 - β)

                integral_f, integral_g = 0.0, 0.0
                for m in range(mc_size):
                    integral_f += v_func(w_f[m], κ(w_f[m], π))
                    integral_g += v_func(w_g[m], κ(w_g[m], π))
                integral = (π * integral_f + (1 - π) * integral_g) / mc_size

                v_2 = c + β * integral
                v_new[i, j] = max(v_1, v_2)

        return v_new

    return T

I could only get the code to work when I put the tuple inside T; otherwise I got the error we previously discussed

albop commented 5 years ago

Try moving grid = (w_grid, π_grid) inside the inner prange loop then. Constructing tuples is almost costless, so that is not a problem. Do you have a set of values for the fields of sp so I can try it myself?

albop commented 5 years ago

Or try lambda x, y: mlinterp((w_grid, π_grid), v, (x, y)) instead. That should definitely work.
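The suggestion above can be reduced to a minimal sketch: the individual arrays stay global (numba freezes those fine, per the earlier example), and only the tuple is built inside the jitted code. The grids and the trivial body are stand-ins for illustration, and the njit fallback is only there so the snippet runs without numba:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # fallback: identity decorator so the example still runs without numba
    def njit(func=None, **kwargs):
        if func is None:
            return lambda f: f
        return func

w_grid = np.linspace(0.0, 1.0, 50)
pi_grid = np.linspace(1e-3, 1 - 1e-3, 50)

@njit
def use_grid(x):
    # the tuple is constructed here, inside the jitted function, so numba
    # never has to freeze a global tuple of arrays; only the individual
    # global arrays w_grid and pi_grid are frozen
    grid = (w_grid, pi_grid)
    return x + grid[0][0] + grid[1][0]

print(use_grid(0.5))  # 0.501
```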

natashawatkins commented 5 years ago

This is the class sp

class SearchProblem:
    """
    A class to store a given parameterization of the "offer distribution
    unknown" model.

    """

    def __init__(self, 
                 β=0.95,            # Discount factor
                 c=0.3,             # Unemployment compensation
                 F_a=1, 
                 F_b=1, 
                 G_a=3, 
                 G_b=1.2,
                 w_max=1,           # Maximum wage possible
                 w_grid_size=50, 
                 π_grid_size=50,
                 mc_size=100):

        self.β, self.c, self.w_max = β, c, w_max

        self.f = beta_function_factory(F_a, F_b)
        self.g = beta_function_factory(G_a, G_b)

        self.π_min, self.π_max = 1e-3, 1-1e-3    # Avoids instability
        self.w_grid = np.linspace(0, w_max, w_grid_size)
        self.π_grid = np.linspace(self.π_min, self.π_max, π_grid_size)

        self.mc_size = mc_size

        self.w_f = np.random.beta(F_a, F_b, mc_size)
        self.w_g = np.random.beta(G_a, G_b, mc_size)

I'm just starting from a matrix v filled with 12

natashawatkins commented 5 years ago

Ah yep that worked thanks!

natashawatkins commented 5 years ago

It worked but I didn't get the same speed up from parallelisation as @jstac's code

albop commented 5 years ago

That's less of a problem. How big is the difference?


natashawatkins commented 5 years ago

@jstac's went from 2.5 sec to 1sec, while mlinterp went from 3.8 sec to 2 sec

albop commented 5 years ago

Doesn't look like such a big difference to me: more or less a halving of the run time in both cases. More bothersome is the non-parallelised speed difference. It will vanish eventually (since mlinterp generates its own code, it could in principle generate @jstac's), but that will take a little bit of time.