zeehio / parmap

Easy to use map and starmap python equivalents
Apache License 2.0

Class being sliced by parmap/pool? #3

Closed: fergalm closed this issue 9 years ago

fergalm commented 9 years ago

I'm encountering a strange problem with parmap that I don't understand. I'm passing an object that subclasses numpy's ndarray to a function through parmap, and one of its attributes is getting sliced off. I've boiled my code down to what I think is the simplest code that reproduces the problem. The class Working() below does not get sliced, but the class NotWorking() does.

Subclassing numpy classes is already pretty obscure, so there might not be a way around this, but being able to use all my processors on my real task would be great.


import numpy as np
import parmap

class Working():

    def __init__(self, data):
        self.data = data
        self.lookup = dict()

    def parseKey(self, key):
        pass

    def setLookup(self, dim, values):
        self.lookup[dim] = values

class NotWorking(np.ndarray):
    def __new__(cls, input_array, nameDict=None):
        obj = np.asarray(input_array).view(cls)
        obj.lookup = nameDict
        if obj.lookup is None:
            obj.lookup = dict()
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.lookup = getattr(obj, 'lookup', None)

    def parseKey(self, key):
        pass

    def setLookup(self, dim, values):
        self.lookup[dim] = values

def main():

    data = NotWorking(np.zeros((10, 10)))
    data.setLookup(0, 'a b c d e f g h i j'.split())
    data.setLookup(1, 'a b c d e f g h i j'.split())

    row = np.arange(10)

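    # Single thread -- works
    f = lambda x: trivialFunc(x, data)
    list(map(f, row))  # list() forces evaluation; map() is lazy in Python 3

    # parmap, single thread -- works
    parmap.map(trivialFunc, row, data, parallel=False)

    # The attributes are present in the parent process...
    assertAttributesPresent(data)
    # ...but this fails: lookup is missing inside the worker processes
    parmap.map(trivialFunc, row, data, parallel=True)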

def trivialFunc(i, data):
    assertAttributesPresent(data)

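def assertAttributesPresent(data):
    assert hasattr(data, 'parseKey'), "Parse Key not present"
    assert hasattr(data, 'lookup'), "Lookup dict not present"

# The guard is required when multiprocessing uses the 'spawn' start method (e.g. on Windows)
if __name__ == '__main__':
    main()
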
zeehio commented 9 years ago

Hi, thanks for reporting.

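Parallelization in Python with multiprocessing (which parmap is built on) relies on serializing objects with pickle: each argument is pickled in the parent process and unpickled again in the worker process.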
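To see why the plain Working class survives the trip to a worker, here is a minimal sketch that round-trips a trimmed copy of it through pickle directly. It assumes that a pickle.dumps/pickle.loads pair is a fair stand-in for what multiprocessing does with each argument (a simplification: the pool also pickles the function and batches tasks); round_trip is a hypothetical helper name.

```python
import pickle


class Working:
    """Trimmed copy of the plain class from the report above."""

    def __init__(self, data):
        self.data = data
        self.lookup = dict()

    def setLookup(self, dim, values):
        self.lookup[dim] = values


def round_trip(obj):
    # Hypothetical helper approximating what happens to each argument sent
    # to a worker: serialize in the parent, deserialize in the child.
    return pickle.loads(pickle.dumps(obj))


original = Working([1, 2, 3])
original.setLookup(0, ["a", "b", "c"])
copy_in_worker = round_trip(original)

# Default pickling of a plain instance stores its __dict__, so lookup survives.
print(copy_in_worker.lookup)  # {0: ['a', 'b', 'c']}
```

Plain classes get this for free; the trouble starts when a parent class supplies its own pickling recipe.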

Your NotWorking class does not define custom __setstate__ and __getstate__ methods, so it inherits the pickling methods of its parent class np.ndarray, which do not save your custom self.lookup dictionary.
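A minimal sketch of what that means in practice, assuming a NumPy version whose np.ndarray.__reduce__ returns the usual (callable, args, state) triple: the lookup attribute sits in the instance __dict__, but the state tuple NumPy builds for pickling contains only array metadata and raw bytes. The class below is a trimmed copy of NotWorking.

```python
import numpy as np


class NotWorking(np.ndarray):
    # Trimmed copy of the subclass from the report.
    def __new__(cls, input_array, nameDict=None):
        obj = np.asarray(input_array).view(cls)
        obj.lookup = nameDict if nameDict is not None else dict()
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.lookup = getattr(obj, 'lookup', None)


data = NotWorking(np.zeros((2, 3)), nameDict={0: ['a', 'b']})

# The custom attribute lives in the instance __dict__ ...
print(data.__dict__)  # {'lookup': {0: ['a', 'b']}}

# ... but the pickling recipe inherited from np.ndarray is built from the
# array alone: state holds the version, shape, dtype, memory order and data.
reconstruct, args, state = np.ndarray.__reduce__(data)
print(state[1])  # (2, 3)
print(any(isinstance(item, dict) for item in state))  # False
```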

Here is the implementation of your parent class, to give you an idea:

The following code shows that pickle does not serialize lookup:

import numpy as np
import pickle

class NotWorking(np.ndarray):
    def __new__(cls, input_array, nameDict=None):
        obj = np.asarray(input_array).view(cls)
        obj.lookup = nameDict
        if obj.lookup is None:
            obj.lookup = dict()
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.lookup = getattr(obj, 'lookup', None)

    def parseKey(self, key):
        pass

    def setLookup(self, dim, values):
        self.lookup[dim] = values

def assertAttributesPresent(data):
    assert hasattr(data, 'parseKey'), "Parse Key not present"
    assert hasattr(data, 'lookup'), "Lookup dict not present"

def main():
    data = NotWorking(np.zeros((10, 10)))
    data.setLookup(0, 'a b c d e f g h i j'.split())
    data.setLookup(1, 'a b c d e f g h i j'.split())
    dataB = pickle.loads(pickle.dumps(data))  # round trip through pickle
    assertAttributesPresent(data)
    assertAttributesPresent(dataB)  # fails: lookup was not pickled

if __name__ == '__main__':
    main()

Once you add the appropriate __setstate__ and __getstate__ methods to your class, your objects will work with pickle, and with parmap too.

zeehio commented 9 years ago

I can't dig much further, but there is also the __reduce__ method. If you want to subclass np.ndarray, you may need to deal with it too.
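The __reduce__ protocol can be illustrated on a plain Python class. This is a minimal sketch with a hypothetical Tagged class, not NumPy's implementation: __reduce__ returns a (callable, args, state) triple, and when unpickling, pickle calls callable(*args) and then passes state to the new object's __setstate__. np.ndarray follows the same pattern (a reconstruct callable plus an ndarray.__setstate__ method), which is why overriding __reduce__ lets a subclass add its own data to the state.

```python
import pickle


class Tagged:
    """Hypothetical class used only to illustrate the __reduce__ protocol."""

    def __init__(self, value):
        self.value = value
        self.tags = {}

    def __reduce__(self):
        # (callable, args, state): pickle rebuilds the object as
        # Tagged(self.value) and then calls obj.__setstate__(self.tags).
        return (Tagged, (self.value,), self.tags)

    def __setstate__(self, state):
        # Receives exactly the third element returned by __reduce__.
        self.tags = state


obj = Tagged(42)
obj.tags["unit"] = "m"
restored = pickle.loads(pickle.dumps(obj))
print(restored.value, restored.tags)  # 42 {'unit': 'm'}
```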

fergalm commented 9 years ago

Thanks for your help. After a bit of digging, I seem to have found the solution, but I would never have looked in the right places without your explanation. NumPy's ndarray does indeed use __reduce__(), so that must be overridden. I found a solution on Stack Overflow.

Adding the following two methods to NotWorking fixes the problem.

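    def __reduce__(self):
        # Take ndarray's (callable, args, state) triple and append lookup to the state
        npState = np.ndarray.__reduce__(self)
        myState = npState[2] + (self.lookup,)
        return (npState[0], npState[1], myState)

    def __setstate__(self, state):
        # Restore lookup (the last item), then let ndarray restore the array itself
        self.lookup = state[-1]
        np.ndarray.__setstate__(self, state[0:-1])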
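For reference, here is a self-contained sketch that puts those two methods into the reporter's class (renamed Fixed, a hypothetical name) and checks the result with a plain pickle round trip, the same serialization step that parmap's parallel mode goes through via multiprocessing. The parmap call itself is left out so the sketch runs without parmap installed.

```python
import pickle

import numpy as np


class Fixed(np.ndarray):
    """NotWorking from the report plus the __reduce__/__setstate__ fix."""

    def __new__(cls, input_array, nameDict=None):
        obj = np.asarray(input_array).view(cls)
        obj.lookup = nameDict if nameDict is not None else dict()
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.lookup = getattr(obj, 'lookup', None)

    def setLookup(self, dim, values):
        self.lookup[dim] = values

    def __reduce__(self):
        # Start from ndarray's own (callable, args, state) triple ...
        reconstruct, args, state = np.ndarray.__reduce__(self)
        # ... and append the custom attribute to the end of the state tuple.
        return (reconstruct, args, state + (self.lookup,))

    def __setstate__(self, state):
        # Peel off our attribute, then hand the rest back to ndarray.
        self.lookup = state[-1]
        np.ndarray.__setstate__(self, state[:-1])


data = Fixed(np.zeros((10, 10)))
data.setLookup(0, 'a b c d e f g h i j'.split())
restored = pickle.loads(pickle.dumps(data))
print(type(restored).__name__, restored.shape, restored.lookup[0][:3])
# Fixed (10, 10) ['a', 'b', 'c']
```

Appending a single attribute keeps the sketch close to the recipe above. If the subclass grows more attributes, one possible variant (not from this thread) is to append self.__dict__ instead and restore it with self.__dict__.update(state[-1]).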
zeehio commented 9 years ago

I created a pull request, https://github.com/numpy/numpy/pull/5952, to improve NumPy's documentation on this issue.