dask / distributed

A distributed task scheduler for Dask
https://distributed.dask.org
BSD 3-Clause "New" or "Revised" License

Default "dask" (de)serialization protocol fails to preserve type #6569

Open cjao opened 2 years ago

cjao commented 2 years ago

I'm using Dask 2022.1.0 on Python 3.8.13.

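The "dask" serialization protocol does not preserve the type of the object being round-tripped. In the following snippet, which uses the PennyLane package, a `pennylane.numpy` tensor comes back as a plain `numpy.ndarray`: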

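# serialize/deserialize live in distributed.protocol ('dask.distributed' is a module, not a package)
from distributed.protocol import serialize, deserialize
import pennylane as qml
import pennylane.numpy as np

res = qml.numpy.random.uniform(0, 2 * np.pi, (2, 1), requires_grad=True)
res1 = deserialize(*serialize(res))  # default serializers, "dask" first
print(type(res1))
# <class 'numpy.ndarray'>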

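The problem seems to lie in the deserializer. The header produced by `serialize` does record the original type (`'type': 'pennylane.numpy.tensor.tensor'` plus a pickled reference in `'type-serialized'`), so the information is available at deserialization time: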

print(serialize(res))
# ({'sub-header': {'dtype': (0, '<f8'), 'shape': (2, 1), 'strides': (8, 8), 'writeable': [True]}, 'type': 'pennylane.numpy.tensor.tensor', 'type-serialized': b'\x80\x04\x95%\x00\x00\x00\x00\x00\x00\x00\x8c\x16pennylane.numpy.tensor\x94\x8c\x06tensor\x94\x93\x94.', 'serializer': 'dask'}, [<memory at 0x7f11076acd00>])
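Because the header already carries the original class in `type-serialized`, a deserializer has enough information to rebuild the subclass. The sketch below is a stdlib-only analogue, not Dask's actual code: the hypothetical `dumps`, `loads_naive` and `loads_typed` functions split a `dict` subclass (`collections.OrderedDict`) into a header plus byte frames. Rebuilding only the base type loses the subclass, while rebuilding from the recorded type keeps it.

```python
import json
import pickle
from collections import OrderedDict


def dumps(obj):
    """Hypothetical serializer: split a mapping into a header dict and byte frames."""
    header = {
        "type": f"{type(obj).__module__}.{type(obj).__qualname__}",
        # Pickled reference to the class, analogous to Dask's "type-serialized" field
        "type-serialized": pickle.dumps(type(obj)),
    }
    frames = [json.dumps(list(obj.items())).encode()]
    return header, frames


def loads_naive(header, frames):
    """Ignores the recorded type and always rebuilds the base class."""
    return dict(json.loads(frames[0].decode()))


def loads_typed(header, frames):
    """Rebuilds the exact class recorded in the header."""
    cls = pickle.loads(header["type-serialized"])
    return cls(json.loads(frames[0].decode()))


obj = OrderedDict([("a", 1), ("b", 2)])
header, frames = dumps(obj)
print(type(loads_naive(header, frames)))  # <class 'dict'>
print(type(loads_typed(header, frames)))  # <class 'collections.OrderedDict'>
```

For an ndarray subclass, the rebuild step would more likely be a `.view(cls)` on the reconstructed array than a constructor call. Any extra attributes the subclass carries, such as the `requires_grad` flag set above, would also have to be stored in the header. This sketch does not cover either of those.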

Forcing the "pickle" serializer preserves the type:

from distributed.protocol import serialize, deserialize
import pennylane as qml
import pennylane.numpy as np

res = qml.numpy.random.uniform(0, 2 * np.pi, (2, 1), requires_grad=True)
res1 = deserialize(*serialize(res, serializers=["pickle"]))  # restrict to pickle
print(type(res1))
# <class 'pennylane.numpy.tensor.tensor'>
jakirkham commented 2 years ago

Is this subclassing NumPy ndarrays or something?

If so, the NumPy serializer would probably need additional logic to handle subclasses. Alternatively, one could register custom serialization for those specific types.
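For reference, `distributed.protocol` exposes `dask_serialize` and `dask_deserialize` dispatchers whose `.register(cls)` decorator adds custom serialization for a type. The stdlib-only sketch below does not use Dask. Its hypothetical, simplified `Dispatch` class is not Dask's own. It assumes the lookup walks the class's method resolution order, as it does here, and shows why a handler registered for the specific subclass (`OrderedDict`) would help: it is found before the generic handler registered for the base class (`dict`).

```python
from collections import OrderedDict


class Dispatch:
    """Hypothetical, simplified type-dispatch registry (not Dask's implementation)."""

    def __init__(self):
        self._lookup = {}

    def register(self, cls):
        """Decorator that records func as the handler for cls."""
        def wrapper(func):
            self._lookup[cls] = func
            return func
        return wrapper

    def dispatch(self, cls):
        # Walk the MRO so the most specific registered class wins
        for base in cls.__mro__:
            if base in self._lookup:
                return self._lookup[base]
        raise TypeError(f"no handler registered for {cls!r}")


deserializers = Dispatch()


@deserializers.register(dict)
def load_dict(header, frames):
    # Generic handler for the base class: always returns a plain dict
    return dict(frames[0])


def load(cls, header, frames):
    """Look up the handler for the recorded class and rebuild the object."""
    return deserializers.dispatch(cls)(header, frames)


frames = [[("a", 1), ("b", 2)]]
print(type(load(OrderedDict, {}, frames)))  # <class 'dict'> (falls back to the dict handler)


@deserializers.register(OrderedDict)
def load_ordered_dict(header, frames):
    # Subclass-specific handler, found first in OrderedDict.__mro__
    return OrderedDict(frames[0])


print(type(load(OrderedDict, {}, frames)))  # <class 'collections.OrderedDict'>
```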