Open pentschev opened 11 months ago
This failure is seen consistently in DGX tests; it is not observed in CI because CI lacks UCX testing for transports other than TCP.
We can monkey-patch past this error in dask-cuda by patching `protocol.loads`
and `comm.utils.from_frames`:
# Standard library
import pickle

# Third-party
import cupy
import msgpack

import dask
import dask.array as da
import distributed.comm.utils
import distributed.protocol
from distributed import Client
from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
from distributed.protocol.core import (
    Serialized,
    decompress,
    logger,
    merge_and_deserialize,
    msgpack_decode_default,
    msgpack_opts,
)

from dask_cuda import LocalCUDACluster
from dask_cuda.initialize import initialize
async def from_frames(frames, deserialize=True, deserializers=None, allow_offload=True):
    """
    Unserialize a list of Distributed protocol frames.

    Parameters
    ----------
    frames : list
        Frames produced by the Distributed wire protocol.
    deserialize : bool
        Whether to fully deserialize payloads (as opposed to returning
        ``Serialized`` placeholders).
    deserializers : optional
        Forwarded to ``loads``.
    allow_offload : bool
        If true and the total frame size exceeds ``OFFLOAD_THRESHOLD``,
        deserialization is offloaded off the event-loop thread.
    """
    # Total payload size in bytes; only computed when offloading may apply.
    # Was ``size = False`` — an int sentinel avoids bool/int comparison abuse.
    size = 0

    def _from_frames():
        try:
            return loads(
                frames, deserialize=deserialize, deserializers=deserializers
            )
        except EOFError:
            # Aid diagnosing truncated streams without dumping huge payloads.
            datastr = "[too large to display]" if size > 1000 else frames
            logger.error("truncated data stream (%d bytes): %s", size, datastr)
            raise

    if allow_offload and deserialize and OFFLOAD_THRESHOLD:
        size = sum(map(nbytes, frames))
        # Nested check instead of re-evaluating the whole condition.
        if size > OFFLOAD_THRESHOLD:
            return await offload(_from_frames)
    return _from_frames()
def loads(frames, deserialize=True, deserializers=None):
    """Transform a list of protocol frames back into a Python value.

    Frame 0 is a msgpack message whose custom objects reference other
    frames by index: ``__Serialized__`` marks a Dask-serialized payload
    and ``__Pickled__`` a pickled one.
    """
    allow_pickle = dask.config.get("distributed.scheduler.pickle")
    try:

        def _decode_default(obj):
            # Dask-serialized sub-message: header frame followed by its
            # sub-frames.
            serialized_at = obj.get("__Serialized__", 0)
            if serialized_at > 0:
                header = msgpack.loads(
                    frames[serialized_at],
                    object_hook=msgpack_decode_default,
                    use_list=False,
                    **msgpack_opts,
                )
                payload = frames[
                    serialized_at + 1 : serialized_at + 1 + header["num-sub-frames"]
                ]
                if not deserialize:
                    # Caller asked for opaque placeholders; keep frames as-is.
                    return Serialized(header, payload)
                if "compression" in header:
                    payload = decompress(header, payload)
                return merge_and_deserialize(
                    header, payload, deserializers=deserializers
                )

            # Pickled sub-message: same layout, restored via pickle with
            # out-of-band buffers.
            pickled_at = obj.get("__Pickled__", 0)
            if pickled_at > 0:
                header = msgpack.loads(frames[pickled_at])
                payload = frames[
                    pickled_at + 1 : pickled_at + 1 + header["num-sub-frames"]
                ]
                if "compression" in header:
                    payload = decompress(header, payload)
                if not allow_pickle:
                    raise ValueError(
                        "Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`"
                    )
                return pickle.loads(header["pickled-obj"], buffers=payload)

            return msgpack_decode_default(obj)

        return msgpack.loads(
            frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
        )
    except Exception:
        logger.critical("Failed to deserialize", exc_info=True)
        raise
# Monkey-patch Distributed in-process so every code path — including the
# UCX comm layer — picks up the fixed implementations defined above.
distributed.protocol.loads = loads
distributed.protocol.core.loads = loads
distributed.comm.utils.from_frames = from_frames
if __name__ == "__main__":
    # Minimal UCX reproducer: without the patches above this raised
    # "cannot reshape array" during deserialization.
    # NOTE: the original script never imported ``LocalCUDACluster`` and
    # raised ``NameError``; it is now imported at the top of the file.
    cluster = LocalCUDACluster(
        protocol="ucx",
        rmm_pool_size="1 GiB",
    )
    client = Client(cluster)
    try:
        print(da.from_array(cupy.arange(10000), chunks=(1000,)).sum().compute())
    finally:
        # Shut down cleanly so worker processes do not linger.
        client.close()
        cluster.close()
https://github.com/rapidsai/dask-cuda/pull/1247 should fix this issue for RAPIDS 23.10 and allow us to pin Dask/Distributed 2023.9.2 as planned. The proper solution must land in Distributed via https://github.com/dask/distributed/pull/8216. Once the Distributed fix is in and we unpin 2023.9.2 in branch-23.12,
we should be able to revert https://github.com/rapidsai/dask-cuda/pull/1247.
Changes from https://github.com/rapidsai/dask-cuda/pull/1247 are being reverted in https://github.com/rapidsai/dask-cuda/pull/1256, which will not be required anymore once Dask/Distributed are unpinned for 23.12.
The following snippet currently fails in Dask-CUDA if we use
`protocol="ucx"`:

Reproducer and output:
```python In [1]: from dask_cuda import LocalCUDACluster ...: from dask_cuda.initialize import initialize ...: from distributed import Client ...: ...: import cupy ...: import dask.array as da ...: ...: ...: cluster = LocalCUDACluster( ...: protocol="ucx", ...: interface="ib0", ...: rmm_pool_size="1 GiB", ...: ) ...: client = Client(cluster) In [2]: res = da.from_array(cupy.arange(10000), chunks=(1000,)) ...: res.sum().compute() 2023-09-26 12:25:44,887 - distributed.protocol.core - CRITICAL - Failed to deserialize Traceback (most recent call last): File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/core.py", line 158, in loads return msgpack.loads( File "msgpack/_unpacker.pyx", line 194, in msgpack._cmsgpack.unpackb File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/core.py", line 150, in _decode_default return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/pickle.py", line 94, in loads return pickle.loads(x, buffers=buffers) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/numpy/core/numeric.py", line 1875, in _frombuffer return frombuffer(buf, dtype=dtype).reshape(shape, order=order) ValueError: cannot reshape array of size 5002 into shape (10000,) 2023-09-26 12:25:45,032 - distributed.core - ERROR - cannot reshape array of size 5002 into shape (10000,) Traceback (most recent call last): File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/utils.py", line 801, in wrapper return await func(*args, **kwargs) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/comm/ucx.py", line 407, in read msg = await from_frames( File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/comm/utils.py", line 100, in from_frames 
res = _from_frames() File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/comm/utils.py", line 83, in _from_frames return protocol.loads( File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/core.py", line 158, in loads return msgpack.loads( File "msgpack/_unpacker.pyx", line 194, in msgpack._cmsgpack.unpackb File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/core.py", line 150, in _decode_default return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/pickle.py", line 94, in loads return pickle.loads(x, buffers=buffers) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/numpy/core/numeric.py", line 1875, in _frombuffer return frombuffer(buf, dtype=dtype).reshape(shape, order=order) ValueError: cannot reshape array of size 5002 into shape (10000,) 2023-09-26 12:25:45,050 - distributed.protocol.core - CRITICAL - Failed to deserialize Traceback (most recent call last): File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/core.py", line 158, in loads return msgpack.loads( File "msgpack/_unpacker.pyx", line 194, in msgpack._cmsgpack.unpackb File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/core.py", line 150, in _decode_default return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/distributed/protocol/pickle.py", line 94, in loads return pickle.loads(x, buffers=buffers) File "/datasets/pentschev/miniconda3/envs/rn-230925/lib/python3.9/site-packages/numpy/core/numeric.py", line 1875, in _frombuffer return frombuffer(buf, dtype=dtype).reshape(shape, order=order) ValueError: cannot reshape array of size 5002 
into shape (10000,) 2023-09-26 12:25:45,055 - distributed.scheduler - WARNING - Received heartbeat from unregistered worker 'ucx://10.33.225.163:40865'. 2023-09-26 12:25:45,055 - distributed.worker - ERROR - Scheduler was unaware of this worker 'ucx://10.33.225.163:40865'. Shutting down. 2023-09-26 12:25:45,079 - tornado.application - ERROR - Exception in callback functools.partial(After bisecting I found https://github.com/dask/distributed/pull/8067 to be the source of this issue, it used to complete fine before that, it still does if we replace
`protocol="ucx"`
with `protocol="tcp"`,
which may suggest there's something missing in the serialization protocol for UCX. cc @rjzamora @madsbk, who both had a look at https://github.com/dask/distributed/pull/8067 and may have thoughts on what we're missing.