CNES / zcollection

Python library for manipulating data split into a collection of groups stored in Zarr format.
https://zcollection.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License

update() broken with Future argument (newer dask version) #12

Open robin-cls opened 1 month ago

robin-cls commented 1 month ago

Hi,

I recently bumped dask in my conda environment, and zcollection.Collection.update now raises a RuntimeError when the callback tries to access a Future object. The error appears with dask=2024.9.0, but not with older versions.


Below is the code to reproduce the problem. It uses a local cluster and an in-memory zcollection.

from __future__ import annotations

from typing import Iterator

import dask.distributed
import fsspec
import numpy

import zcollection
import zcollection.tests.data

cluster = dask.distributed.LocalCluster(processes=False)
client = dask.distributed.Client(cluster)

def create_dataset() -> zcollection.Dataset:
    """Create a dataset to record."""
    generator: Iterator[zcollection.Dataset] = \
        zcollection.tests.data.create_test_dataset_with_fillvalue()
    return next(generator)

zds = create_dataset()

fs = fsspec.filesystem('memory')
path = '/my_collection'

partition_handler = zcollection.partitioning.Date(('time', ), resolution='M')
collection: zcollection.Collection = zcollection.create_collection(
    'time', zds, partition_handler, path, filesystem=fs)

collection.insert(zds)
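
# Scatter an array to the workers; `scattered` is a Future.
scattered = client.scatter(numpy.ones(2), broadcast=True)


def callback(zds, arg_future):
    # Accessing the Future inside the callback raises the RuntimeError.
    arg_future.result()
    return {'var1': zds['var1'].values * 2}


collection.update(callback, scattered)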

My preliminary analysis is that the underlying wrapper for the update() function stores the *args and **kwargs arguments. This might cause trouble for dask during serialization/deserialization, because the Future contained in args is not submitted directly by the client.

Thomas-Z commented 4 days ago

Hello,

Your analysis is correct.

The problem can be reduced to the following:

import distributed as dist
import numpy

if __name__ == '__main__':
    cluster = dist.LocalCluster(processes=False)
    client = dist.Client(cluster)

    scattered = client.scatter(numpy.ones(2), broadcast=True)

    def callback(arg_future):
        print(arg_future)

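    def wrap_update_func(func, *args, **kwargs):
        # args and kwargs (including the Future) are captured in the
        # closure instead of being passed to client.submit.
        def _wrapped_function() -> None:
            func(*args, **kwargs)

        return _wrapped_function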

    def do_something(func, *args):
        client = dist.get_client()
        local_func = wrap_update_func(func, *args)
        futures = client.submit(local_func)

        client.compute(futures, sync=True)

    do_something(callback, scattered)

The wrapped function captures the parameters in its closure, and dask loses track of them. This problem appeared with the 2024.2 release, and I think it is similar to what they discussed in this issue.
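
To make the difference between the two wrapper styles visible without a cluster, here is a minimal sketch that inspects what each wrapper keeps in its closure. `FakeFuture`, `wrap_capturing`, `wrap_forwarding` and `closure_contents` are hypothetical names used only for this illustration; `FakeFuture` is a plain stand-in and not a real `distributed.Future`. It only shows where the argument ends up, which appears to be why dask no longer sees the Future in the first style.

```python
class FakeFuture:
    """Hypothetical stand-in for distributed.Future (illustration only)."""


def callback(arg):
    return arg


def wrap_capturing(func, *args, **kwargs):
    # Same shape as the failing wrapper: arguments are stored in the closure.
    def _wrapped_function():
        return func(*args, **kwargs)

    return _wrapped_function


def wrap_forwarding(func):
    # Same shape as the proposed fix: arguments are supplied at call time.
    def _wrapped_function(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapped_function


def closure_contents(function):
    """Return the objects held in the closure cells of a function."""
    return [cell.cell_contents for cell in function.__closure__ or ()]


future = FakeFuture()
capturing = wrap_capturing(callback, future)
forwarding = wrap_forwarding(callback)

# In the capturing style, the future is buried inside a captured tuple.
hidden = any(
    isinstance(item, tuple) and future in item
    for item in closure_contents(capturing))

# In the forwarding style, only the callback itself is captured.
only_func_captured = closure_contents(forwarding) == [callback]

print(hidden, only_func_captured)
```

In the forwarding style the future would have to be given explicitly when the task is submitted, which is the place where a scheduler can see it.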

I'm pushing a change that should fix it with almost no side effects.

The approach is roughly the following:

import distributed as dist
import numpy

if __name__ == '__main__':
    cluster = dist.LocalCluster(processes=False)
    client = dist.Client(cluster)

    scattered = client.scatter(numpy.ones(2), broadcast=True, hash=False)

    def callback(arg_not_a_future):
        print(arg_not_a_future)

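    def wrap_update_func(func):
        # Only func is captured; the arguments are passed through
        # client.submit, so dask resolves the Future before the call.
        def _wrapped_function(*args, **kwargs) -> None:
            func(*args, **kwargs)

        return _wrapped_function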

    def do_something(func, *args):
        client = dist.get_client()
        local_func = wrap_update_func(func)
        futures = client.submit(local_func, *args)

        client.compute(futures, sync=True)

    do_something(callback, scattered)

Regarding your usage, you'll have to adapt your callback function: it won't receive a Future anymore, so you do not have to call .result() on its parameters.
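
As a rough sketch of that adaptation, the callback from the original report could look like the following, assuming the fix passes the already-resolved array as the second argument. The `fake_zds` object is a hypothetical stand-in built with `types.SimpleNamespace` so the snippet runs without zcollection; it only mimics the `zds['var1'].values` access used in the report.

```python
import types

import numpy


def callback(zds, arg):
    # `arg` is assumed to be the scattered array itself, not a Future,
    # so there is no call to .result() here.
    return {'var1': zds['var1'].values * 2}


# Hypothetical stand-in for a zcollection.Dataset (illustration only).
fake_zds = {'var1': types.SimpleNamespace(values=numpy.arange(3))}

updated = callback(fake_zds, numpy.ones(2))
print(updated['var1'])
```

Under the same assumption, the call site from the report, collection.update(callback, scattered), would stay as it is.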