dask / distributed

A distributed task scheduler for Dask
https://distributed.dask.org

Pickling error with numba gufunc #7929

Open rsarm opened 1 year ago

rsarm commented 1 year ago

I'm trying this example from the dask-examples repo. When computing the graph, it raises a PicklingError:

_pickle.PicklingError: Can't pickle <ufunc 'smooth'>: it's not the same object as __main__.smooth

This happens only when using distributed. I launch the Dask cluster on a Slurm cluster via ipyparallel's become_dask, something like the sketch below. Without distributed, the example works fine.
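
Roughly like this (just a sketch; the profile name is a placeholder for my actual Slurm profile, and it assumes ipyparallel's Client.become_dask API):

import ipyparallel as ipp

# Connect to the IPython engines already running under Slurm
rc = ipp.Client(profile="slurm")  # placeholder profile name

# Turn the engines into dask workers and get back a distributed.Client
client = rc.become_dask()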

These are the versions of some relevant packages I'm using:

dask-2023.6.0
distributed-2023.6.0
toolz-0.12.0
cloudpickle-2.2.1
msgpack-1.0.5
numba-0.57.0

I'm not sure if this is a bug; it's probably something in my environment. I found similar serialization errors reported, but nothing recent, only issues that were resolved before 2022.

fjetter commented 10 months ago

Sorry for the late reply. I can reproduce this. Here is another reproducer (note that jit works):

from numba import guvectorize, int64, jit

@jit(nopython=True, cache=True)
def f(x, y):
    return x + y

@guvectorize([(int64[:], int64, int64[:])], '(n),()->(n)', cache=True)
def g(x, y, res):
    for i in range(x.shape[0]):
        res[i] = x[i] + y

import numpy as np
from distributed import Client

arr = np.array([1, 2, 3], dtype=np.int64)  # input array for the gufunc

with Client() as client:
    client.submit(f, 1, 2).result()
    print("first success")
    # The following fails with a PicklingError
    client.submit(g, arr, 2).result()

fjetter commented 10 months ago

Sorry, I believe the error I am reproducing with my example is actually a little different. My example only fails if the cache kwarg is True.

The dask-examples version fails with a different exception.
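
For reference, the same reproducer with cache=True dropped from the gufunc (which, per the above, is the variant that does not fail) looks roughly like this:

import numpy as np
from numba import guvectorize, int64
from distributed import Client

@guvectorize([(int64[:], int64, int64[:])], '(n),()->(n)')  # no cache=True
def g(x, y, res):
    for i in range(x.shape[0]):
        res[i] = x[i] + y

arr = np.array([1, 2, 3], dtype=np.int64)

with Client() as client:
    # Without cache=True the gufunc pickles and this submit succeeds
    print(client.submit(g, arr, 2).result())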

milesgranger commented 10 months ago

Looking into this, I git bisect'd to https://github.com/dask/distributed/pull/7564, but I think that change only surfaced an issue that has always been there.

It worked before because we didn't try to pickle it. It will also 'fail' in plain dask without distributed if one tries to pickle the output array from the numba function, i.e.:

import numba
import dask.array as da

@numba.guvectorize(["int8,int8[:]"], "()->()")
def double(x, out):
    out[:] = x * 2

def main():
    x = da.random.randint(0, 127, size=(10, 10, 10), chunks=('1 MB', None, None), dtype='int8')
    y = double(x)

    import pickle
    pickle.dumps(double)  # Can pickle the function directly.

    # But not when the generated function from numba is part of the dask array.
    # Fails w/ PicklingError: Can't pickle <ufunc 'double'>: it's not the same object as __main__.double
    pickle.dumps(y)  

    y.max().compute()

if __name__ == '__main__':
    main()

milesgranger commented 10 months ago

Essentially, I think something, either numba or dask, is generating a bad wrapper around the function, and that only gets exposed when the function is attached to the graph.

For example, the following raises the same error: we're changing the location of the "foo" function, so pickle thinks it lives in __main__, but it has been 'moved' into a Foo instance.

class Foo:
    def __init__(self, func):
        self.func = func

    def __call__(self):
        return self.func()

def decorator(func):
    return Foo(func)

@decorator
def foo():
    pass

def main():
    import pickle
    pickle.dumps(foo)  # Fails w/ same error "...it's not the same object as __main__.foo"

if __name__ == '__main__':
    main()

milesgranger commented 10 months ago

xref https://github.com/dask/distributed/issues/3450

crusaderky commented 8 months ago

The problem is that a numba gufunc is a pure-python wrapper around a locally-compiled numpy ufunc. When you pickle the wrapper, you trigger serialization code which ships over the wrapped pure-python function and recompiles it upon unpickle: https://github.com/numba/numba/blob/54433810bfa60458209ed2c24f62e793b12aca6b/numba/np/ufunc/gufunc.py#L33-L46

However, the wrapper undoes itself on __call__: https://github.com/numba/numba/blob/54433810bfa60458209ed2c24f62e793b12aca6b/numba/np/ufunc/gufunc.py#L171-L172

By the time control reaches dask.array.Array.__array_ufunc__, it's too late - we've lost the reference to the GUFunc object as well as the wrapped pure-python object. I can't find a way to hack around it in our serialization layer.
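
Concretely (a quick sketch; reaching the inner compiled ufunc through the GUFunc's ufunc attribute relies on a numba implementation detail):

import pickle
import numba

@numba.guvectorize(["f8,f8[:]"], "()->()")
def double(x, out):
    out[:] = x * 2.0

pickle.dumps(double)  # the GUFunc wrapper pickles fine via its custom __reduce__

try:
    # the inner, locally-compiled numpy ufunc is what the dask graph ends up holding
    pickle.dumps(double.ufunc)
except pickle.PicklingError as e:
    print(e)  # Can't pickle <ufunc 'double'>: it's not the same object as __main__.double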

Workaround 1

If you just need an elementwise operation, you can use @vectorize, which is unaffected by this issue:

@numba.vectorize(["int8(int8)"])
def double(x):
    return x * 2

x = da.random.randint(...)
y = double(x)
pickle.dumps(y)  # Works

Dynamic vectorize (no explicit signatures) works as well:

@numba.vectorize()
def double(x):
    return x * 2

Workaround 2

If you do need to operate on vectors, you can hack the call as follows:

@numba.guvectorize(["f8,f8[:]"], "()->()")
def double(x, out):
    out[:] = x * 2

x = da.random.randint(...)
y = x.__array_ufunc__(double, "__call__", x)
pickle.dumps(y)  # Works

The __array_ufunc__ protocol is explained in NEP 13.
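
For illustration, a minimal sketch of the NEP 13 hook that workaround 2 invokes directly (the Wrapped class here is just an illustrative stand-in for dask.array.Array):

import numpy as np

class Wrapped:
    def __init__(self, data):
        self.data = np.asarray(data)

    # NEP 13: numpy calls this hook instead of applying the ufunc directly
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        inputs = [i.data if isinstance(i, Wrapped) else i for i in inputs]
        return Wrapped(getattr(ufunc, method)(*inputs, **kwargs))

print(np.add(Wrapped([1, 2]), 3).data)  # -> [4 5]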

Solution

This issue has been around at least since 2019. I did some archaeology and found my own ticket, never resolved, on the numba board: https://github.com/numba/numba/issues/4314

I'm suggesting a solution here: https://github.com/numba/numba/issues/4314#issuecomment-1992103833