Open OliverColeman opened 4 years ago
It should be possible to do something like this:
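import mxnet
import ray

def mxnet_serializer(obj):
    # Convert the NDArray to a numpy array (this step makes a copy).
    return obj.asnumpy()
def mxnet_deserializer(value):
    # Wrap the numpy array as an NDArray without copying.
    return mxnet.ndarray.from_numpy(value, zero_copy=True)
ray.register_custom_serializer(
    mxnet.ndarray.NDArray, serializer=mxnet_serializer, deserializer=mxnet_deserializer)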
Note I haven't tried this since I don't have mxnet installed. If it works for you please let us know!
Thanks @pcmoritz, that does make things neater :). However this still requires copying potentially large arrays...
The deserialization code path is actually zero-copy: zero-copy from plasma to numpy, and then zero-copy again through the .from_numpy(..., zero_copy=True) function. The serialization code path does one additional copy, which could be avoided by being more careful. The right solution would be for obj.asnumpy() to be zero-copy, which would need to be implemented in MXNet (the developers would basically need to set the .base object of the numpy array to the original mxnet.ndarray.NDArray and point the numpy data pointer to the right memory). Do you want to open an issue on the MXNet repo and suggest that?
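To make the .base idea concrete, here is a minimal sketch using only numpy. OwnedBuffer is a hypothetical stand-in for an array type that owns its memory; it is not MXNet code, and asnumpy_copy / asnumpy_view are made-up names that only contrast a copying conversion with a zero-copy one.

```python
import numpy as np


class OwnedBuffer:
    """Hypothetical stand-in for an array type that owns its memory."""

    def __init__(self, n_floats):
        # float32 storage, zero-initialised
        self.memory = bytearray(4 * n_floats)

    def asnumpy_copy(self):
        # Copying conversion: the returned array owns fresh memory.
        return np.frombuffer(self.memory, dtype=np.float32).copy()

    def asnumpy_view(self):
        # Zero-copy conversion: the array's data pointer refers to
        # self.memory, and numpy records a base object to keep it alive.
        return np.frombuffer(self.memory, dtype=np.float32)


owner = OwnedBuffer(4)
copied = owner.asnumpy_copy()
view = owner.asnumpy_view()

# Writing through the view changes the owner's memory; the copy is unaffected.
view[0] = 1.0
```

Here view.flags.owndata is False and view.base is set, which is the property a zero-copy asnumpy() would need to have.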
Would your proposed solution work when the context of the mxnet array is a GPU? This assumes that the mxnet array is only being shared between processes on the same node, but that is my use case: several workers to load image data from disk, preprocess it, then load this into the GPU before another process handles executing GPU compute on the data (and then more workers to post-process the output of this worker).
It looks like support for custom serializers was dropped in Ray 1.x. Ray now works seamlessly with MXNet arrays, however I think it's not handling them efficiently (i.e., by using pickle protocol 5 with out-of-band data, the way numpy arrays are handled). Will MXNet arrays ever be supported natively by Ray? It seems inefficient to require converting between MXNet and numpy arrays to pass data around...
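For reference, this is roughly what out-of-band data looks like with the standard library alone. It is only a sketch: array_memory is a hypothetical stand-in for the memory behind a large array, and nothing here is Ray or MXNet code.

```python
import pickle

# Hypothetical stand-in for the memory behind a large array.
array_memory = bytearray(b"\x01" * 1_000_000)

# In-band: the bytes are copied into the pickle stream itself.
in_band = pickle.dumps(pickle.PickleBuffer(array_memory), protocol=5)

# Out-of-band: the pickler passes the buffer to the callback instead of
# copying it, so the stream only holds a small placeholder.
buffers = []
out_of_band = pickle.dumps(
    pickle.PickleBuffer(array_memory),
    protocol=5,
    buffer_callback=buffers.append,
)

# The receiver supplies the same buffers when loading.
restored = pickle.loads(out_of_band, buffers=buffers)

# The restored object still refers to the original memory.
shares_memory = memoryview(restored).obj is array_memory
```

The in-band stream is larger than the data, while the out-of-band stream is a few bytes and the buffer itself can be transported separately (for example through shared memory) without an extra copy.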
Would be great if Ray supported transfer/storage of mxnet.ndarrays without serialisation (like it does with numpy arrays). At the moment I have to convert mxnet arrays to numpy arrays and then back again to, for example, use parallelised data load-transform operations.