ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Core] util.multiprocessing.pool: imap and imap_unordered blocking on ray.wait even though processes are complete #29466

Open minerharry opened 2 years ago


What happened + What you expected to happen

In testing for #29453, I found that when a long-running task is put into the imap or imap_unordered queue, it blocks other tasks from being submitted until it finishes (see the linked issue for the script I used to find this). Specifically, if pool actor A is given a long task, actor B will execute a short task but will not start its next task until actor A finishes its task and receives its next one. This can be seen in the following outputs from the script in the reproduction section:

With [15,1,1,1,1,1] (2 processes):

(PoolActor pid=31760) waiting 15 seconds... @0.0036280155181884766
(PoolActor pid=19372) waiting 1 seconds... @0.0036280155181884766
(PoolActor pid=19372) done! @1.0085012912750244
----TIME SKIP----
(PoolActor pid=31760) done! @15.010581493377686
(PoolActor pid=31760) waiting 1 seconds... @15.528305053710938
(PoolActor pid=19372) waiting 1 seconds... @15.528305053710938
(PoolActor pid=31760) done! @16.533926963806152
(PoolActor pid=19372) done! @16.533926963806152
(PoolActor pid=31760) waiting 1 seconds... @17.04942226409912
(PoolActor pid=19372) waiting 1 seconds... @17.05042338371277
(PoolActor pid=31760) done! @18.055502891540527
(PoolActor pid=19372) done! @18.055502891540527

With [1,15,1,1,1,1,1] (2 processes):

(PoolActor pid=29348) waiting 1 seconds... @0.003527402877807617
(PoolActor pid=11912) waiting 15 seconds... @0.003527402877807617
(PoolActor pid=29348) done! @1.0083823204040527
(PoolActor pid=29348) waiting 1 seconds... @1.526676893234253
(PoolActor pid=29348) done! @2.530019760131836
----TIME SKIP----
(PoolActor pid=11912) done! @15.016525268554688
(PoolActor pid=29348) waiting 1 seconds... @15.531500339508057
(PoolActor pid=11912) waiting 1 seconds... @15.531500339508057
(PoolActor pid=29348) done! @16.543936729431152
(PoolActor pid=11912) done! @16.543736219406128
(PoolActor pid=29348) waiting 1 seconds... @17.063328742980957
(PoolActor pid=11912) waiting 1 seconds... @17.06241226196289
(PoolActor pid=29348) done! @18.069923400878906
(PoolActor pid=11912) done! @18.07009243965149

With [1,1,15,1,1,1,1] (2 processes):

(PoolActor pid=8740) waiting 1 seconds... @0.0030863285064697266
(PoolActor pid=25608) waiting 1 seconds... @0.0020079612731933594
(PoolActor pid=8740) done! @1.0180027484893799
(PoolActor pid=25608) done! @1.0175530910491943
(PoolActor pid=8740) waiting 1 seconds... @1.5367352962493896
(PoolActor pid=25608) waiting 15 seconds... @1.5356879234313965
(PoolActor pid=8740) done! @2.5418169498443604
----TIME SKIP----
(PoolActor pid=25608) done! @16.54096007347107
(PoolActor pid=8740) waiting 1 seconds... @17.060018062591553
(PoolActor pid=25608) waiting 1 seconds... @17.05897283554077
(PoolActor pid=8740) done! @18.06603217124939
(PoolActor pid=25608) done! @18.066240549087524
(PoolActor pid=25608) waiting 1 seconds... @18.581287622451782
(PoolActor pid=25608) done! @19.585636138916016

With 3 processes ([1,15,1,1,1,1,1,1]):

(PoolActor pid=34548) waiting 15 seconds... @0.0033195018768310547
(PoolActor pid=12260) waiting 1 seconds... @0.0033195018768310547
(PoolActor pid=21256) waiting 1 seconds... @0.004317760467529297
(PoolActor pid=21256) done! @1.0061466693878174
(PoolActor pid=12260) done! @1.0061466693878174
(PoolActor pid=12260) waiting 1 seconds... @1.5130040645599365
(PoolActor pid=12260) done! @2.5215137004852295
----TIME SKIP----
(PoolActor pid=34548) done! @15.003873825073242
(PoolActor pid=21256) waiting 1 seconds... @15.520460605621338
(PoolActor pid=34548) waiting 1 seconds... @15.517465829849243
(PoolActor pid=12260) waiting 1 seconds... @15.521462440490723
(PoolActor pid=21256) done! @16.53088879585266
(PoolActor pid=34548) done! @16.53088879585266
(PoolActor pid=12260) done! @16.53088879585266
(PoolActor pid=34548) waiting 1 seconds... @17.045456886291504
(PoolActor pid=34548) done! @18.053963661193848
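The timelines above can be approximated with a small event-driven model. This is a sketch under stated assumptions, not Ray code: it assumes chunk k is sent to actor k % n (as in the `len(self._submitted_chunks) % len(self._pool._actor_pool)` line of the pool file), that one chunk per actor is submitted up front and one more each time any chunk finishes, and that each actor runs its queued chunks one at a time. It ignores the extra 0.5 s sleep in the test script. The names `simulate`, `round_robin` and `earliest_free` are hypothetical.

```python
import heapq
from typing import Callable, List, Sequence, Tuple

# A picker maps (chunk number, per-actor "free at" times) to an actor index.
Picker = Callable[[int, List[float]], int]


def round_robin(k: int, free_at: List[float]) -> int:
    # Chunk k goes to actor k % n, regardless of which actor is idle.
    return k % len(free_at)


def earliest_free(k: int, free_at: List[float]) -> int:
    # Chunk goes to the actor whose queued work ends first (lowest index on ties).
    return min(range(len(free_at)), key=lambda a: free_at[a])


def simulate(lengths: Sequence[float], n: int, pick: Picker) -> Tuple[List[float], float]:
    """Model lazy imap submission: n chunks up front, then one more per completion.

    Each actor runs its chunks one at a time in submission order. Returns the
    start time of every chunk and the time the last chunk finishes.
    """
    free_at = [0.0] * n  # time at which each actor's queued work ends
    starts: List[float] = []
    pending: List[float] = []  # min-heap of completion times not yet observed

    def submit(now: float) -> None:
        k = len(starts)
        actor = pick(k, free_at)
        start = max(now, free_at[actor])  # wait behind the actor's earlier chunks
        free_at[actor] = start + lengths[k]
        starts.append(start)
        heapq.heappush(pending, free_at[actor])

    for _ in range(min(n, len(lengths))):
        submit(0.0)
    end = 0.0
    while pending:
        end = heapq.heappop(pending)  # next chunk to finish
        if len(starts) < len(lengths):
            submit(end)  # a completion triggers exactly one new submission
    return starts, end


if __name__ == "__main__":
    lengths = [15, 1, 1, 1, 1, 1]
    print(simulate(lengths, 2, round_robin))    # ([0.0, 0.0, 15.0, 15.0, 16.0, 16.0], 17.0)
    print(simulate(lengths, 2, earliest_free))  # ([0.0, 0.0, 1.0, 2.0, 3.0, 4.0], 15.0)
```

With round-robin assignment, the next chunk after the first short one is sent to the actor still running the 15 s task, and since no other chunk finishes in the meantime, nothing new is submitted and the other actor idles until about t = 15, as in the logs. Assigning each chunk to the actor that frees up first keeps the short tasks flowing.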

After adding print statements to pool.py (the modified file is included below) and rerunning the previous scenario (3 processes, [1,15,1,1,1,1,1,1]), I get the following output:

--popping new
submitting chunk to actor 0
adding object ref
submitting chunk to actor 1
--awaiting 1 unready
adding object ref
--unready refs to check: [ObjectRef(f91b78d7db9a6593da872ff1fbd78b680483be140100000001000000)]
submitting chunk to actor 2
adding object ref
waiting for next complete...
(PoolActor pid=13776) waiting 1 seconds... @0.014054298400878906
(PoolActor pid=25496) waiting 1 seconds... @0.0070416927337646484
(PoolActor pid=33216) waiting 15 seconds... @0.01007080078125
(PoolActor pid=13776) done! @1.0150954723358154
(PoolActor pid=25496) done! @1.015376091003418
--result ObjectRef(f91b78d7db9a6593da872ff1fbd78b680483be140100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
submitting chunk to actor 0
--awaiting 2 unready
adding object ref
chunk submitted
--unready refs to check:  [ObjectRef(82891771158d68c1ee3880d9677d61d7b0df983e0100000001000000), ObjectRef(8849b62d89cb30f9944eaeee92ba0f0b889de49f0100000001000000)]
waiting for next complete...
--result ObjectRef(8849b62d89cb30f9944eaeee92ba0f0b889de49f0100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
--awaiting 2 unready
submitting chunk to actor 1
adding object ref
chunk submitted
waiting for next complete...
--unready refs to check: [ObjectRef(82891771158d68c1ee3880d9677d61d7b0df983e0100000001000000), ObjectRef(80e22aed7718a125da872ff1fbd78b680483be140100000001000000)]      
(PoolActor pid=25496) waiting 1 seconds... @1.5388407707214355
(PoolActor pid=25496) done! @2.5542988777160645
--result ObjectRef(80e22aed7718a125da872ff1fbd78b680483be140100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
submitting chunk to actor 2
--awaiting 2 unready
adding object ref
chunk submitted
waiting for next complete...
--unready refs to check: [ObjectRef(82891771158d68c1ee3880d9677d61d7b0df983e0100000001000000), ObjectRef(359ec6ce30d3ca2dee3880d9677d61d7b0df983e0100000001000000)]      
(PoolActor pid=13776) waiting 1 seconds... @3.071631908416748
(PoolActor pid=13776) done! @4.074235916137695
----TIME SKIP----
(PoolActor pid=33216) done! @15.023739576339722
--result ObjectRef(82891771158d68c1ee3880d9677d61d7b0df983e0100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
--awaiting 2 unready
submitting chunk to actor 0
adding object ref
--unready refs to check: [ObjectRef(359ec6ce30d3ca2dee3880d9677d61d7b0df983e0100000001000000), ObjectRef(1e8ff6d236132784944eaeee92ba0f0b889de49f0100000001000000)]
chunk submitted
waiting for next complete...
--result ObjectRef(1e8ff6d236132784944eaeee92ba0f0b889de49f0100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
submitting chunk to actor 1
--awaiting 2 unready
adding object ref
chunk submitted
waiting for next complete...
--unready refs to check: [ObjectRef(359ec6ce30d3ca2dee3880d9677d61d7b0df983e0100000001000000), ObjectRef(85748392bcd969ccda872ff1fbd78b680483be140100000001000000)]       
(PoolActor pid=25496) waiting 1 seconds... @15.542959213256836
(PoolActor pid=33216) waiting 1 seconds... @15.533987522125244
(PoolActor pid=25496) done! @16.554728031158447
(PoolActor pid=33216) done! @16.53400707244873
--result ObjectRef(359ec6ce30d3ca2dee3880d9677d61d7b0df983e0100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
--awaiting 2 unready
submitting chunk to actor 2 
chunk submitted
waiting for next complete...
--unready refs to check: [ObjectRef(85748392bcd969ccda872ff1fbd78b680483be140100000001000000), ObjectRef(d695f922effe6d99ee3880d9677d61d7b0df983e0100000001000000)]
--result ObjectRef(85748392bcd969ccda872ff1fbd78b680483be140100000001000000) returned, awaiting ready get
--ready got
--popping new
next complete, submitting chunk...
(PoolActor pid=33216) waiting 1 seconds... @17.046282052993774
chunk submitted 
waiting for next complete...
--awaiting 1 unready
--unready refs to check: [ObjectRef(d695f922effe6d99ee3880d9677d61d7b0df983e0100000001000000)]
(PoolActor pid=33216) done! @18.054298877716064
--result ObjectRef(d695f922effe6d99ee3880d9677d61d7b0df983e0100000001000000) returned, awaiting ready get
--ready got
next complete, submitting chunk...
chunk submitted

Note: lines printed from ResultThread.run, which runs in a separate thread, are prefixed with '--'. The raw output was messy because several threads were printing to stdout at once, so I tidied up the lines while preserving their order as best I could.

Most importantly, the main thread appears to be blocked at "waiting for next complete...", where it asks the result thread for the index of the next completed chunk. It is waiting for a new item to be added to _ready_index_queue, which only happens in the result thread's run loop. During the time skip, that thread is between "--awaiting 2 unready" and "--result ... returned", which means it is blocked on the ray.wait line in run, waiting for one of the two refs in its unready list to finish. However, there are three actors in the pool, so where is the third task?

Following the order of the print statements, the result thread starts its next loop iteration as soon as it puts an index in _ready_index_queue, while the main thread only submits the next chunk after reading that index ("submitting chunk to actor 2"). As a result, the result thread starts waiting ("--awaiting 2 unready") before the new object ref has actually been handed to it ("adding object ref"). It seems there is a race condition over whether the most recently submitted ref, whose submission is triggered by an index being put in _ready_index_queue, reaches the result thread's object refs before the result thread drains its queue and blocks in ray.wait. One simple fix would be to add a timeout to the ray.wait call and periodically check for newly submitted refs.
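The race and the suggested timeout fix can be illustrated without Ray. The sketch below is a pure-Python model, assuming `concurrent.futures.wait` as a stand-in for `ray.wait` (both return a done/not-done pair and accept a `timeout`); `task`, `collect` and `run_demo` are hypothetical names. With `poll=None` the collector blocks indefinitely, like `ResultThread.run`, so a future added after it starts waiting is not noticed until the long task finishes; with a finite `poll` it re-drains the submission queue between waits.

```python
import concurrent.futures as cf
import queue
import threading
import time
from typing import List, Optional


def task(name: str, seconds: float) -> str:
    time.sleep(seconds)
    return name


def collect(new_futures: "queue.Queue[cf.Future]", total: int,
            order: List[str], poll: Optional[float]) -> None:
    """Simplified stand-in for ResultThread.run.

    Drains newly submitted futures, then waits for one to finish. With
    poll=None the wait blocks indefinitely (like ray.wait without a timeout);
    with a number it returns after `poll` seconds so the queue is re-checked.
    """
    unready = set()
    finished = 0
    while finished < total:
        # Take every future submitted so far; block only if nothing is pending.
        while True:
            try:
                unready.add(new_futures.get(block=not unready))
            except queue.Empty:
                break
        done, unready = cf.wait(unready, timeout=poll, return_when=cf.FIRST_COMPLETED)
        for fut in done:
            order.append(fut.result())
            finished += 1


def run_demo(poll: Optional[float]) -> List[str]:
    new_futures: "queue.Queue[cf.Future]" = queue.Queue()
    order: List[str] = []
    with cf.ThreadPoolExecutor(max_workers=2) as pool:
        collector = threading.Thread(target=collect, args=(new_futures, 2, order, poll))
        collector.start()
        new_futures.put(pool.submit(task, "long", 1.0))
        time.sleep(0.2)  # the collector is now waiting on the long task alone
        new_futures.put(pool.submit(task, "short", 0.0))
        collector.join()
    return order


if __name__ == "__main__":
    print(run_demo(poll=None))  # ['long', 'short']: the short result waits behind the long one
    print(run_demo(poll=0.05))  # ['short', 'long']
```

In ResultThread.run, the analogous change would presumably be passing `timeout=` to `ray.wait(unready, num_returns=1, ...)` and continuing the loop when the returned ready list is empty, instead of unpacking `[ready_id]` directly.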

However, I think there's a bigger issue here than that. The only long task in the list is the 15 s one (object reference 8289..., since that is what completes after the time skip), which means the other reference ray.wait is waiting on (359e...) must already be complete (unless something very weird is happening). In other words, ray.wait is not returning even though one of its references is complete! I have been unable to reproduce this behavior in my own separate testing, so I think something more complicated is going on here that I'm not detecting.

As mentioned in the previous issue (#29453), this behavior is (mostly?) bypassed by more intelligent scheduling. I used a slightly improved version of _get_next_actor_index from the previous issue, and all short tasks completed before the long task, although I haven't tested this rigorously.

Updated _get_next_free_actor_index:

    def _get_next_free_actor_index(self):
        # Map one ping per actor to (task count, actor index); a ping queues
        # behind any running task, so the first to return is from a free actor.
        ref_map = {
            actor.ping.remote(): (count, k)
            for k, (actor, count) in enumerate(self._pool._actor_pool)
        }
        ready, _ = ray.wait(list(ref_map.keys()))  # num_returns defaults to 1
        ready_actors = [ref_map[ref] for ref in ready]
        # Among ready actors, pick the one with the fewest tasks.
        return min(ready_actors, key=lambda r: r[0])[1]
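The idea behind `_get_next_free_actor_index` can be tried out locally by using single-threaded executors as stand-ins for actors: like an actor method call, a no-op "ping" submitted to a busy executor queues behind its current task, so the first ping to finish identifies a free worker. This is an analogy, not Ray code; `next_free_worker` and its `counts` argument (mirroring the per-actor task count) are hypothetical.

```python
import concurrent.futures as cf
import time
from typing import Sequence


def next_free_worker(workers: Sequence[cf.Executor], counts: Sequence[int]) -> int:
    """Return the index of a worker whose ping comes back first.

    Each worker is a single-threaded executor, so a no-op ping queues behind
    whatever that worker is already running, much like an actor method call.
    If several pings are already done, prefer the worker with the lowest count.
    """
    pings = {w.submit(lambda: None): i for i, w in enumerate(workers)}
    done, _ = cf.wait(list(pings), return_when=cf.FIRST_COMPLETED)
    return min((pings[f] for f in done), key=lambda i: counts[i])


if __name__ == "__main__":
    workers = [cf.ThreadPoolExecutor(max_workers=1) for _ in range(3)]
    workers[0].submit(time.sleep, 1.0)  # worker 0 is busy with a long task
    workers[2].submit(time.sleep, 1.0)  # so is worker 2
    print(next_free_worker(workers, counts=[1, 0, 1]))  # 1
    for w in workers:
        w.shutdown()
```

Unlike round-robin assignment, this never hands a chunk to a worker that is still busy while another one is idle, at the cost of one extra round trip per submission.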

Versions / Dependencies

Ray version: 3.0.0.dev0 (nightly build accessed on 10/17)
Python version: 3.10
OS: Windows 10

Reproduction script

test script:

import time

import ray
from ray.util.multiprocessing import Pool

ray.init()
num_procs = 2  # 3 for the last scenario above
p = Pool(num_procs)

lengths = [15, 1, 1, 1, 1, 1]  # replace with any of the task-length lists above

start = time.time()

def wait_time(t):
    print(f"waiting {t} seconds...", f"@{time.time() - start}")
    time.sleep(t)
    print("done!", f"@{time.time() - start}")
    time.sleep(0.5)

res = p.imap_unordered(wait_time, lengths, chunksize=1)
list(res)  # consume the iterator so that every chunk is submitted and run

Modified pool.py with debug print statements (it also contains _get_next_free_actor_index, but that function is unused here):

import collections
import copy
import gc
import itertools
import logging
import os
import queue
import random
import sys
import threading
import time
from multiprocessing import TimeoutError
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Tuple

import ray
from ray.util import log_once

try:
    from joblib._parallel_backends import SafeFunction
    from joblib.parallel import BatchedCalls, parallel_backend
except ImportError:
    BatchedCalls = None
    parallel_backend = None
    SafeFunction = None

logger = logging.getLogger(__name__)

RAY_ADDRESS_ENV = "RAY_ADDRESS"

def _put_in_dict_registry(
    obj: Any, registry_hashable: Dict[Hashable, ray.ObjectRef]
) -> ray.ObjectRef:
    if obj not in registry_hashable:
        ret = ray.put(obj)
        registry_hashable[obj] = ret
    else:
        ret = registry_hashable[obj]
    return ret

def _put_in_list_registry(
    obj: Any, registry: List[Tuple[Any, ray.ObjectRef]]
) -> ray.ObjectRef:
    try:
        ret = next((ref for o, ref in registry if o is obj))
    except StopIteration:
        ret = ray.put(obj)
        registry.append((obj, ret))
    return ret

def ray_put_if_needed(
    obj: Any,
    registry: Optional[List[Tuple[Any, ray.ObjectRef]]] = None,
    registry_hashable: Optional[Dict[Hashable, ray.ObjectRef]] = None,
) -> ray.ObjectRef:
    """ray.put obj in object store if it's not an ObjRef and bigger than 100 bytes,
    with support for list and dict registries"""
    if isinstance(obj, ray.ObjectRef) or sys.getsizeof(obj) < 100:
        return obj
    ret = obj
    if registry_hashable is not None:
        try:
            ret = _put_in_dict_registry(obj, registry_hashable)
        except TypeError:
            if registry is not None:
                ret = _put_in_list_registry(obj, registry)
    elif registry is not None:
        ret = _put_in_list_registry(obj, registry)
    return ret

def ray_get_if_needed(obj: Any) -> Any:
    """If obj is an ObjectRef, do ray.get, otherwise return obj"""
    if isinstance(obj, ray.ObjectRef):
        return ray.get(obj)
    return obj

if BatchedCalls is not None:

    class RayBatchedCalls(BatchedCalls):
        """Joblib's BatchedCalls with basic Ray object store management

        This functionality is provided through the put_items_in_object_store,
        which uses external registries (list and dict) containing objects
        and their ObjectRefs."""

        def put_items_in_object_store(
            self,
            registry: Optional[List[Tuple[Any, ray.ObjectRef]]] = None,
            registry_hashable: Optional[Dict[Hashable, ray.ObjectRef]] = None,
        ):
            """Puts all applicable (kw)args in self.items in object store

            Takes two registries - list for unhashable objects and dict
            for hashable objects. The registries are a part of a Pool object.
            The method iterates through all entries in items list (usually,
            there will be only one, but the number depends on joblib Parallel
            settings) and puts all of the args and kwargs into the object
            store, updating the registries.
            If an arg or kwarg is already in a registry, it will not be
            put again, and instead, the cached object ref will be used."""
            new_items = []
            for func, args, kwargs in self.items:
                args = [
                    ray_put_if_needed(arg, registry, registry_hashable) for arg in args
                ]
                kwargs = {
                    k: ray_put_if_needed(v, registry, registry_hashable)
                    for k, v in kwargs.items()
                }
                new_items.append((func, args, kwargs))
            self.items = new_items

        def __call__(self):
            # Exactly the same as in BatchedCalls, with the
            # difference being that it gets args and kwargs from
            # object store (which have been put in there by
            # put_items_in_object_store)

            # Set the default nested backend to self._backend but do
            # not change the default number of processes to -1
            with parallel_backend(self._backend, n_jobs=self._n_jobs):
                return [
                    func(
                        *[ray_get_if_needed(arg) for arg in args],
                        **{k: ray_get_if_needed(v) for k, v in kwargs.items()},
                    )
                    for func, args, kwargs in self.items
                ]

        def __reduce__(self):
            # Exactly the same as in BatchedCalls, with the
            # difference being that it returns RayBatchedCalls
            # instead
            if self._reducer_callback is not None:
                self._reducer_callback()
            # no need to pickle the callback.
            return (
                RayBatchedCalls,
                (self.items, (self._backend, self._n_jobs), None, self._pickle_cache),
            )

else:
    RayBatchedCalls = None

# Helper function to divide a by b and round the result up.
def div_round_up(a, b):
    return -(-a // b)

class PoolTaskError(Exception):
    def __init__(self, underlying):
        self.underlying = underlying

class ResultThread(threading.Thread):
    """Thread that collects results from distributed actors.

    It winds down when either:
        - A pre-specified number of objects has been processed
        - When the END_SENTINEL (submitted through self.add_object_ref())
            has been received and all objects received before that have been
            processed.

    Initialize the thread with total_object_refs = float('inf') to wait for the
    END_SENTINEL.

    Args:
        object_refs (List[RayActorObjectRefs]): ObjectRefs to Ray Actor calls.
            Thread tracks whether they are ready. More ObjectRefs may be added
            with add_object_ref (or _add_object_ref internally) until the object
            count reaches total_object_refs.
        single_result: Should be True if the thread is managing function
            with a single result (like apply_async). False if the thread is managing
            a function with a List of results.
        callback: called only once at the end of the thread
            if no results were errors. If single_result=True, and result is
            not an error, callback is invoked with the result as the only
            argument. If single_result=False, callback is invoked with
            a list of all the results as the only argument.
        error_callback: called only once on the first result
            that errors. Should take an Exception as the only argument.
            If no result errors, this callback is not called.
        total_object_refs: Number of ObjectRefs that this thread
            expects to be ready. May be more than len(object_refs) since
            more ObjectRefs can be submitted after the thread starts.
            If None, defaults to len(object_refs). If float("inf"), thread runs
            until END_SENTINEL (submitted through self.add_object_ref())
            has been received and all objects received before that have
            been processed.
    """

    END_SENTINEL = None

    def __init__(
        self,
        object_refs: list,
        single_result: bool = False,
        callback: callable = None,
        error_callback: callable = None,
        total_object_refs: Optional[int] = None,
    ):
        threading.Thread.__init__(self, daemon=True)
        self._got_error = False
        self._object_refs = []
        self._num_ready = 0
        self._results = []
        self._ready_index_queue = queue.Queue()
        self._single_result = single_result
        self._callback = callback
        self._error_callback = error_callback
        self._total_object_refs = total_object_refs or len(object_refs)
        self._indices = {}
        # Thread-safe queue used to add ObjectRefs to fetch after creating
        # this thread (used to lazily submit for imap and imap_unordered).
        self._new_object_refs = queue.Queue()
        for object_ref in object_refs:
            self._add_object_ref(object_ref)

    def _add_object_ref(self, object_ref):
        self._indices[object_ref] = len(self._object_refs)
        self._object_refs.append(object_ref)
        self._results.append(None)

    def add_object_ref(self, object_ref):
        self._new_object_refs.put(object_ref)

    def run(self):
        unready = copy.copy(self._object_refs)
        aggregated_batch_results = []

        # Run for a specific number of objects if self._total_object_refs is finite.
        # Otherwise, process all objects received prior to the stop signal, given by
        # self.add_object_ref(END_SENTINEL).
        while self._num_ready < self._total_object_refs:
            # Get as many new IDs from the queue as possible without blocking,
            # unless we have no IDs to wait on, in which case we block.
            print("--popping new")
            while True:
                try:
                    block = len(unready) == 0
                    new_object_ref = self._new_object_refs.get(block=block)
                    if new_object_ref is self.END_SENTINEL:
                        # Receiving the END_SENTINEL object is the signal to stop.
                        # Store the total number of objects.
                        self._total_object_refs = len(self._object_refs)
                    else:
                        self._add_object_ref(new_object_ref)
                        unready.append(new_object_ref)
                except queue.Empty:
                    # queue.Empty means no result was retrieved if block=False.
                    break
            print("--awaiting", len(unready), "unready")
            print("--unready refs to check:", unready)
            [ready_id], unready = ray.wait(unready, num_returns=1)
            print(f"--result {ready_id} returned, awaiting ready get")
            try:
                batch = ray.get(ready_id)
            except ray.exceptions.RayError as e:
                batch = [e]
            print("--ready got")

            # The exception callback is called only once on the first result
            # that errors. If no result errors, it is never called.
            if not self._got_error:
                for result in batch:
                    if isinstance(result, Exception):
                        self._got_error = True
                        if self._error_callback is not None:
                            self._error_callback(result)
                        break
                    else:
                        aggregated_batch_results.append(result)

            self._num_ready += 1
            self._results[self._indices[ready_id]] = batch
            self._ready_index_queue.put(self._indices[ready_id])

        # The regular callback is called only once on the entire List of
        # results as long as none of the results were errors. If any results
        # were errors, the regular callback is never called; instead, the
        # exception callback is called on the first erroring result.
        #
        # This callback is called outside the while loop to ensure that it's
        # called on the entire list of results, not just a single batch.
        if not self._got_error and self._callback is not None:
            if not self._single_result:
                self._callback(aggregated_batch_results)
            else:
                # On a thread handling a function with a single result
                # (e.g. apply_async), we call the callback on just that result
                # instead of on a list encapsulating that result
                self._callback(aggregated_batch_results[0])

    def got_error(self):
        # Should only be called after the thread finishes.
        return self._got_error

    def result(self, index):
        # Should only be called on results that are ready.
        return self._results[index]

    def results(self):
        # Should only be called after the thread finishes.
        return self._results

    def next_ready_index(self, timeout=None):
        try:
            return self._ready_index_queue.get(timeout=timeout)
        except queue.Empty:
            # queue.Queue signals a timeout by raising queue.Empty.
            raise TimeoutError

class AsyncResult:
    """An asynchronous interface to task results.

    This should not be constructed directly.
    """

    def __init__(
        self, chunk_object_refs, callback=None, error_callback=None, single_result=False
    ):
        self._single_result = single_result
        self._result_thread = ResultThread(
            chunk_object_refs, single_result, callback, error_callback
        )
        self._result_thread.start()

    def wait(self, timeout=None):
        """
        Returns once the result is ready or the timeout expires (does not
        raise TimeoutError).

        Args:
            timeout: timeout in seconds.
        """

        self._result_thread.join(timeout)

    def get(self, timeout=None):
        self.wait(timeout)
        if self._result_thread.is_alive():
            raise TimeoutError

        results = []
        for batch in self._result_thread.results():
            for result in batch:
                if isinstance(result, PoolTaskError):
                    raise result.underlying
                elif isinstance(result, Exception):
                    raise result
            results.extend(batch)

        if self._single_result:
            return results[0]

        return results

    def ready(self):
        """
        Returns true if the result is ready, else false if the tasks are still
        running.
        """

        return not self._result_thread.is_alive()

    def successful(self):
        """
        Returns true if none of the submitted tasks errored, else false. Should
        only be called once the result is ready (can be checked using `ready`).
        """

        if not self.ready():
            raise ValueError(f"{self!r} not ready")
        return not self._result_thread.got_error()

class IMapIterator:
    """Base class for OrderedIMapIterator and UnorderedIMapIterator."""

    def __init__(self, pool, func, iterable, chunksize=None):
        self._pool = pool
        self._func = func
        self._next_chunk_index = 0
        self._finished_iterating = False
        # List of bools indicating if the given chunk is ready or not for all
        # submitted chunks. Ordering mirrors that in the ResultThread.
        self._submitted_chunks = []
        self._ready_objects = collections.deque()
        try:
            self._iterator = iter(iterable)
        except TypeError:
            # for compatibility with prior releases, encapsulate non-iterable in a list
            iterable = [iterable]
            self._iterator = iter(iterable)
        if isinstance(iterable, collections.abc.Iterator):
            # Got iterator (which has no len() function).
            # Make default chunksize 1 instead of using _calculate_chunksize().
            # Indicate unknown queue length, requiring explicit stopping.
            self._chunksize = chunksize or 1
            result_list_size = float("inf")
        else:
            self._chunksize = chunksize or pool._calculate_chunksize(iterable)
            # Use self._chunksize: the chunksize argument may be None here.
            result_list_size = div_round_up(len(iterable), self._chunksize)

        self._result_thread = ResultThread([], total_object_refs=result_list_size)
        self._result_thread.start()

        for _ in range(len(self._pool._actor_pool)):
            self._submit_next_chunk()

    def _submit_next_chunk(self):
        # The full iterable has already been submitted, so no-op.
        if self._finished_iterating:
            return

        actor_index = len(self._submitted_chunks) % len(self._pool._actor_pool)
        # actor_index = self._get_next_free_actor_index()
        print(f"submitting chunk to actor {actor_index}")
        chunk_iterator = itertools.islice(self._iterator, self._chunksize)

        # Check whether we have run out of samples.
        # This consumes the original iterator, so we convert to a list and back
        chunk_list = list(chunk_iterator)
        if len(chunk_list) < self._chunksize:
            # Reached end of self._iterator
            self._finished_iterating = True
            if len(chunk_list) == 0:
                # Nothing to do, return.
                return
        chunk_iterator = iter(chunk_list)

        new_chunk_id = self._pool._submit_chunk(
            self._func, chunk_iterator, self._chunksize, actor_index
        )
        self._submitted_chunks.append(False)
        # Wait for the result
        print("adding object ref")
        self._result_thread.add_object_ref(new_chunk_id)
        # If we submitted the final chunk, notify the result thread
        if self._finished_iterating:
            self._result_thread.add_object_ref(ResultThread.END_SENTINEL)

    def _get_next_free_actor_index(self):
        # Map one ping per actor to (task count, actor index); a ping queues
        # behind any running task, so the first to return is from a free actor.
        ref_map = {
            actor.ping.remote(): (count, k)
            for k, (actor, count) in enumerate(self._pool._actor_pool)
        }
        ready, _ = ray.wait(list(ref_map.keys()))  # num_returns defaults to 1
        ready_actors = [ref_map[ref] for ref in ready]
        # Among ready actors, pick the one with the fewest tasks.
        return min(ready_actors, key=lambda r: r[0])[1]

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        # Should be implemented by subclasses.
        raise NotImplementedError

class OrderedIMapIterator(IMapIterator):
    """Iterator to the results of tasks submitted using `imap`.

    The results are returned in the same order that they were submitted, even
    if they don't finish in that order. Only one batch of tasks per actor
    process is submitted at a time - the rest are submitted as results come in.

    Should not be constructed directly.
    """

    def next(self, timeout=None):
        if len(self._ready_objects) == 0:
            if self._finished_iterating and (
                self._next_chunk_index == len(self._submitted_chunks)
            ):
                # Finish when all chunks have been dispatched and processed
                # Notify the calling process that the work is done.
                raise StopIteration

            # This loop will break when the next index in order is ready or
            # self._result_thread.next_ready_index() raises a timeout.
            index = -1
            while index != self._next_chunk_index:
                start = time.time()
                index = self._result_thread.next_ready_index(timeout=timeout)
                self._submit_next_chunk()
                self._submitted_chunks[index] = True
                if timeout is not None:
                    timeout = max(0, timeout - (time.time() - start))

            while (
                self._next_chunk_index < len(self._submitted_chunks)
                and self._submitted_chunks[self._next_chunk_index]
            ):
                for result in self._result_thread.result(self._next_chunk_index):
                    self._ready_objects.append(result)
                self._next_chunk_index += 1

        return self._ready_objects.popleft()

class UnorderedIMapIterator(IMapIterator):
    """Iterator to the results of tasks submitted using `imap_unordered`.

    The results are returned in the order that they finish. Only one batch of
    tasks per actor process is submitted at a time - the rest are submitted as
    results come in.

    Should not be constructed directly.
    """

    def next(self, timeout=None):
        if len(self._ready_objects) == 0:
            if self._finished_iterating and (
                self._next_chunk_index == len(self._submitted_chunks)
            ):
                # Finish when all chunks have been dispatched and processed
                # Notify the calling process that the work is done.
                raise StopIteration
            print("waiting for next complete...")
            index = self._result_thread.next_ready_index(timeout=timeout)
            print("next complete, submitting chunk...")
            self._submit_next_chunk()
            print("chunk submitted")

            for result in self._result_thread.result(index):
                self._ready_objects.append(result)
            self._next_chunk_index += 1

        return self._ready_objects.popleft()

@ray.remote(num_cpus=0)
class PoolActor:
    """Actor used to process tasks submitted to a Pool."""

    def __init__(self, initializer=None, initargs=None):
        if initializer:
            initargs = initargs or ()
            initializer(*initargs)

    def ping(self):
        # Used to wait for this actor to be initialized.
        pass

    def run_batch(self, func, batch):
        results = []
        for args, kwargs in batch:
            args = args or ()
            kwargs = kwargs or {}
            try:
                results.append(func(*args, **kwargs))
            except Exception as e:
                results.append(PoolTaskError(e))
        return results

# https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool
class Pool:
    """A pool of actor processes that is used to process tasks in parallel.

    Args:
        processes: number of actor processes to start in the pool. Defaults to
            the number of cores in the Ray cluster if one is already running,
            otherwise the number of cores on this machine.
        initializer: function to be run in each actor when it starts up.
        initargs: iterable of arguments to the initializer function.
        maxtasksperchild: maximum number of tasks to run in each actor process.
            After a process has executed this many tasks, it will be killed and
            replaced with a new one.
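        context: not supported; accepted for compatibility with
            `multiprocessing.Pool`. A warning is logged if it is passed.
        ray_address: address of the Ray cluster to run on. If None, a new local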
            Ray cluster will be started on this machine. Otherwise, this will
            be passed to `ray.init()` to connect to a running cluster. This may
            also be specified using the `RAY_ADDRESS` environment variable.
        ray_remote_args: arguments used to configure the Ray Actors making up
            the pool.
    """

    def __init__(
        self,
        processes: Optional[int] = None,
        initializer: Optional[Callable] = None,
        initargs: Optional[Iterable] = None,
        maxtasksperchild: Optional[int] = None,
        context: Any = None,
        ray_address: Optional[str] = None,
        ray_remote_args: Optional[Dict[str, Any]] = None,
    ):
        ray._private.usage.usage_lib.record_library_usage("util.multiprocessing.Pool")

        self._closed = False
        self._initializer = initializer
        self._initargs = initargs
        self._maxtasksperchild = maxtasksperchild or -1
        self._actor_deletion_ids = []
        self._registry: List[Tuple[Any, ray.ObjectRef]] = []
        self._registry_hashable: Dict[Hashable, ray.ObjectRef] = {}
        self._current_index = 0
        self._ray_remote_args = ray_remote_args or {}
        self._pool_actor = None

        if context and log_once("context_argument_warning"):
            logger.warning(
                "The 'context' argument is not supported using "
                "ray. Please refer to the documentation for how "
                "to control ray initialization."
            )

        processes = self._init_ray(processes, ray_address)
        self._start_actor_pool(processes)

    def _init_ray(self, processes=None, ray_address=None):
        # Initialize ray. If ray is already initialized, we do nothing.
        # Else, the priority is:
        # ray_address argument > RAY_ADDRESS > start new local cluster.
        if not ray.is_initialized():
            # Cluster mode.
            if ray_address is None and (
                RAY_ADDRESS_ENV in os.environ
                or ray._private.utils.read_ray_address() is not None
            ):
                ray.init()
            elif ray_address is not None:
                init_kwargs = {}
                if ray_address == "local":
                    init_kwargs["num_cpus"] = processes
                ray.init(address=ray_address, **init_kwargs)
            # Local mode.
            else:
                ray.init(num_cpus=processes)

        ray_cpus = int(ray._private.state.cluster_resources()["CPU"])
        if processes is None:
            processes = ray_cpus
        if processes <= 0:
            raise ValueError("Processes in the pool must be >0.")
        if ray_cpus < processes:
            raise ValueError(
                "Tried to start a pool with {} processes on an "
                "existing ray cluster, but there are only {} "
                "CPUs in the ray cluster.".format(processes, ray_cpus)
            )

        return processes

    def _start_actor_pool(self, processes):
        self._pool_actor = None
        self._actor_pool = [self._new_actor_entry() for _ in range(processes)]
        ray.get([actor.ping.remote() for actor, _ in self._actor_pool])

    def _wait_for_stopping_actors(self, timeout=None):
        if len(self._actor_deletion_ids) == 0:
            return
        if timeout is not None:
            timeout = float(timeout)

        _, deleting = ray.wait(
            self._actor_deletion_ids,
            num_returns=len(self._actor_deletion_ids),
            timeout=timeout,
        )
        self._actor_deletion_ids = deleting

    def _stop_actor(self, actor):
        # Check and clean up any outstanding IDs corresponding to deletions.
        self._wait_for_stopping_actors(timeout=0.0)
        # The deletion task will block until the actor has finished executing
        # all pending tasks.
        self._actor_deletion_ids.append(actor.__ray_terminate__.remote())

    def _new_actor_entry(self):
        # NOTE(edoakes): The initializer function can't currently be used to
        # modify the global namespace (e.g., import packages or set globals)
        # due to a limitation in cloudpickle.
        # Cache the PoolActor with options
        if not self._pool_actor:
            self._pool_actor = PoolActor.options(**self._ray_remote_args)
        return (self._pool_actor.remote(self._initializer, self._initargs), 0)

    def _next_actor_index(self):
        if self._current_index == len(self._actor_pool) - 1:
            self._current_index = 0
        else:
            self._current_index += 1
        return self._current_index

    # Batch should be a list of tuples: (args, kwargs).
    def _run_batch(self, actor_index, func, batch):
        actor, count = self._actor_pool[actor_index]
        object_ref = actor.run_batch.remote(func, batch)
        count += 1
        assert self._maxtasksperchild == -1 or count <= self._maxtasksperchild
        if count == self._maxtasksperchild:
            self._stop_actor(actor)
            actor, count = self._new_actor_entry()
        self._actor_pool[actor_index] = (actor, count)
        return object_ref

    def apply(
        self,
        func: Callable,
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict] = None,
    ):
        """Run the given function on the next actor process (chosen
        round-robin) and return the result synchronously.

        Args:
            func: function to run.
            args: optional arguments to the function.
            kwargs: optional keyword arguments to the function.

        Returns:
            The result.
        """

        return self.apply_async(func, args, kwargs).get()

    def apply_async(
        self,
        func: Callable,
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict] = None,
        callback: Callable[[Any], None] = None,
        error_callback: Callable[[Exception], None] = None,
    ):
        """Run the given function on the next actor process (chosen
        round-robin) and return an asynchronous interface to the result.

        Args:
            func: function to run.
            args: optional arguments to the function.
            kwargs: optional keyword arguments to the function.
            callback: callback to be executed on the result once it is finished
                only if it succeeds.
            error_callback: callback to be executed on the result once it is
                finished only if the task errors. The exception raised by the
                task will be passed as the only argument to the callback.

        Returns:
            AsyncResult containing the result.
        """

        self._check_running()
        func = self._convert_to_ray_batched_calls_if_needed(func)
        object_ref = self._run_batch(self._next_actor_index(), func, [(args, kwargs)])
        return AsyncResult([object_ref], callback, error_callback, single_result=True)

    def _convert_to_ray_batched_calls_if_needed(self, func: Callable) -> Callable:
        """Convert joblib's BatchedCalls to RayBatchedCalls for ObjectRef caching.

        This converts joblib's BatchedCalls callable, which is a collection of
        functions with their args and kwargs to be run sequentially in an
        Actor, to a RayBatchedCalls callable, which provides identical
        functionality in addition to a method which ensures that common
        args and kwargs are put into the object store just once, saving time
        and memory. That method is then run.

        If func is not a BatchedCalls instance, it is returned without changes.

        The ObjectRefs are cached inside two registries (_registry and
        _registry_hashable), which are common for the entire Pool and are
        cleaned on close."""
        if RayBatchedCalls is None:
            return func
        original_func = func
        # SafeFunction is a Python 2 leftover and can be
        # safely removed.
        if isinstance(func, SafeFunction):
            func = func.func
        if isinstance(func, BatchedCalls):
            func = RayBatchedCalls(
                func.items,
                (func._backend, func._n_jobs),
                func._reducer_callback,
                func._pickle_cache,
            )
            # go through all the items and replace args and kwargs with
            # ObjectRefs, caching them in registries
            func.put_items_in_object_store(self._registry, self._registry_hashable)
        else:
            func = original_func
        return func

    def _calculate_chunksize(self, iterable):
        chunksize, extra = divmod(len(iterable), len(self._actor_pool) * 4)
        if extra:
            chunksize += 1
        return chunksize

    def _submit_chunk(self, func, iterator, chunksize, actor_index, unpack_args=False):
        chunk = []
        while len(chunk) < chunksize:
            try:
                args = next(iterator)
                if not unpack_args:
                    args = (args,)
                chunk.append((args, {}))
            except StopIteration:
                break

        # Nothing to submit. The caller should prevent this.
        assert len(chunk) > 0

        return self._run_batch(actor_index, func, chunk)

    def _chunk_and_run(self, func, iterable, chunksize=None, unpack_args=False):
        if not hasattr(iterable, "__len__"):
            iterable = list(iterable)

        if chunksize is None:
            chunksize = self._calculate_chunksize(iterable)

        iterator = iter(iterable)
        chunk_object_refs = []
        while len(chunk_object_refs) * chunksize < len(iterable):
            actor_index = len(chunk_object_refs) % len(self._actor_pool)
            chunk_object_refs.append(
                self._submit_chunk(
                    func, iterator, chunksize, actor_index, unpack_args=unpack_args
                )
            )

        return chunk_object_refs

    def _map_async(
        self,
        func,
        iterable,
        chunksize=None,
        unpack_args=False,
        callback=None,
        error_callback=None,
    ):
        self._check_running()
        object_refs = self._chunk_and_run(
            func, iterable, chunksize=chunksize, unpack_args=unpack_args
        )
        return AsyncResult(object_refs, callback, error_callback)

    def map(self, func: Callable, iterable: Iterable, chunksize: Optional[int] = None):
        """Run the given function on each element in the iterable round-robin
        on the actor processes and return the results synchronously.

        Args:
            func: function to run.
            iterable: iterable of objects to be passed as the sole argument to
                func.
            chunksize: number of tasks to submit as a batch to each actor
                process. If unspecified, a suitable chunksize will be chosen.

        Returns:
            A list of results.
        """

        return self._map_async(
            func, iterable, chunksize=chunksize, unpack_args=False
        ).get()

    def map_async(
        self,
        func: Callable,
        iterable: Iterable,
        chunksize: Optional[int] = None,
        callback: Callable[[List], None] = None,
        error_callback: Callable[[Exception], None] = None,
    ):
        """Run the given function on each element in the iterable round-robin
        on the actor processes and return an asynchronous interface to the
        results.

        Args:
            func: function to run.
            iterable: iterable of objects to be passed as the only argument to
                func.
            chunksize: number of tasks to submit as a batch to each actor
                process. If unspecified, a suitable chunksize will be chosen.
            callback: Will only be called if none of the results were errors,
                and will only be called once after all results are finished.
                A Python List of all the finished results will be passed as the
                only argument to the callback.
            error_callback: callback executed on the first errored result.
                The Exception raised by the task will be passed as the only
                argument to the callback.

        Returns:
            AsyncResult
        """
        return self._map_async(
            func,
            iterable,
            chunksize=chunksize,
            unpack_args=False,
            callback=callback,
            error_callback=error_callback,
        )

    def starmap(self, func, iterable, chunksize=None):
        """Same as `map`, but unpacks each element of the iterable as the
        arguments to func like: [func(*args) for args in iterable].
        """

        return self._map_async(
            func, iterable, chunksize=chunksize, unpack_args=True
        ).get()

    def starmap_async(
        self,
        func: Callable,
        iterable: Iterable,
        callback: Callable[[List], None] = None,
        error_callback: Callable[[Exception], None] = None,
    ):
        """Same as `map_async`, but unpacks each element of the iterable as the
        arguments to func like: [func(*args) for args in iterable].
        """

        return self._map_async(
            func,
            iterable,
            unpack_args=True,
            callback=callback,
            error_callback=error_callback,
        )

    def imap(self, func: Callable, iterable: Iterable, chunksize: Optional[int] = 1):
        """Same as `map`, but only submits one batch of tasks to each actor
        process at a time.

        This can be useful if the iterable of arguments is very large or each
        task's arguments consume a large amount of resources.

        The results are returned in the order corresponding to their arguments
        in the iterable.

        Returns:
            OrderedIMapIterator
        """

        self._check_running()
        return OrderedIMapIterator(self, func, iterable, chunksize=chunksize)

    def imap_unordered(
        self, func: Callable, iterable: Iterable, chunksize: Optional[int] = 1
    ):
        """Same as `map`, but only submits one batch of tasks to each actor
        process at a time.

        This can be useful if the iterable of arguments is very large or each
        task's arguments consume a large amount of resources.

        The results are returned in the order that they finish.

        Returns:
            UnorderedIMapIterator
        """

        self._check_running()
        return UnorderedIMapIterator(self, func, iterable, chunksize=chunksize)

    def _check_running(self):
        if self._closed:
            raise ValueError("Pool not running")

    def __enter__(self):
        self._check_running()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.terminate()

    def close(self):
        """Close the pool.

        Prevents any more tasks from being submitted on the pool but allows
        outstanding work to finish.
        """

        self._registry.clear()
        self._registry_hashable.clear()
        for actor, _ in self._actor_pool:
            self._stop_actor(actor)
        self._closed = True
        gc.collect()

    def terminate(self):
        """Terminate the pool.

        Prevents any more tasks from being submitted on the pool and stops
        outstanding work.
        """

        if not self._closed:
            self.close()
        for actor, _ in self._actor_pool:
            ray.kill(actor)

    def join(self):
        """Wait for the actors in a closed pool to exit.

        If the pool was closed using `close`, this will return once all
        outstanding work is completed.

        If the pool was closed using `terminate`, this will return quickly.
        """

        if not self._closed:
            raise ValueError("Pool is still running")
        self._wait_for_stopping_actors()

Issue Severity

Medium: It is a significant difficulty but I can work around it.

hora-anyscale commented 1 year ago

Per Triage Sync: @edoakes can you review?