Closed: ggservice007 closed this issue 3 years ago
Hey @ggservice007, thanks for bringing this up! We’ve had discussions about this in the past and completely agree it would be awesome to add a RayExecutor for Prefect.
Are you planning to take a stab at the implementation? If so, we’d be more than happy to provide any guidance/help from the Ray side :)
cc @richardliaw @clarkzinzow
Let's build it together.
I have created a branch in my fork of the Prefect project:
https://github.com/ggservice007/prefect/tree/ray-executor
@amogkam
Base executor:
https://github.com/ggservice007/prefect/blob/ray-executor/src/prefect/executors/base.py
```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator

from prefect.utilities.logging import get_logger


class Executor:
    """
    Base Executor class that all other executors inherit from.
    """

    def __init__(self) -> None:
        self.logger = get_logger(type(self).__name__)

    def __repr__(self) -> str:
        return "<Executor: {}>".format(type(self).__name__)

    @contextmanager
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Any initialization this executor needs to perform should be done in this
        context manager, and torn down after yielding.
        """
        yield

    def submit(
        self, fn: Callable, *args: Any, extra_context: dict = None, **kwargs: Any
    ) -> Any:
        """
        Submit a function to the executor for execution. Returns a future-like object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra information
                about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Any: a future-like object
        """
        raise NotImplementedError()

    def wait(self, futures: Any) -> Any:
        """
        Resolves futures to their values. Blocks until the future is complete.

        Args:
            - futures (Any): iterable of futures to compute

        Returns:
            - Any: an iterable of resolved futures
        """
        raise NotImplementedError()
```
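To make the contract concrete, here is a trivial synchronous executor satisfying this interface. The class name `SyncExecutor` is hypothetical, and it is shown standalone (without subclassing the `prefect` base class) so the sketch runs in isolation: `submit` executes eagerly and `wait` is a passthrough.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


class SyncExecutor:
    """Hypothetical minimal executor: runs every submitted function eagerly.

    Mirrors the `Executor` contract above, but stands alone so it can run
    without prefect installed.
    """

    @contextmanager
    def start(self) -> Iterator[None]:
        # No resources to set up or tear down.
        yield

    def submit(
        self, fn: Callable, *args: Any, extra_context: dict = None, **kwargs: Any
    ) -> Any:
        # Run immediately; the "future-like object" is just the result.
        return fn(*args, **kwargs)

    def wait(self, futures: Any) -> Any:
        # Results were already computed in submit, so this is a passthrough.
        return futures
```

Any real executor (Dask below, or a future Ray one) replaces `submit` with an asynchronous submission and `wait` with a blocking gather.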
https://github.com/ggservice007/prefect/blob/ray-executor/src/prefect/executors/dask.py
````python
import asyncio
import logging
import uuid
import sys
import weakref
from contextlib import contextmanager
from typing import Any, Callable, Iterator, TYPE_CHECKING, Union, Optional

from prefect import context
from prefect.executors.base import Executor
from prefect.utilities.importtools import import_object

if TYPE_CHECKING:
    import dask
    from distributed import Future, Event
    import multiprocessing.pool
    import concurrent.futures

__all__ = ("DaskExecutor", "LocalDaskExecutor")


def _make_task_key(
    task_name: str = "", task_index: int = None, **kwargs: Any
) -> Optional[str]:
    """A helper for generating a dask task key from fields set in `extra_context`"""
    if task_name:
        suffix = uuid.uuid4().hex
        if task_index is not None:
            return f"{task_name}-{task_index}-{suffix}"
        return f"{task_name}-{suffix}"
    return None


def _maybe_run(event_name: str, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Check if the task should run against a `distributed.Event` before
    starting the task. This offers stronger guarantees than distributed's
    current cancellation mechanism, which only cancels pending tasks."""
    import dask
    from distributed import Event, get_client

    try:
        # Explicitly pass in the timeout from dask's config. Some versions of
        # distributed hardcode this rather than using the value from the
        # config. Can be removed once we bump our min requirements for
        # distributed to >= 2.31.0.
        timeout = dask.config.get("distributed.comm.timeouts.connect")
        event = Event(event_name, client=get_client(timeout=timeout))
        should_run = event.is_set()
    except Exception:
        # Failure to create an event is usually due to connection errors. These
        # are either due to flaky behavior in distributed's comms under high
        # loads, or due to the scheduler shutting down. Either way, the safest
        # course here is to assume we *should* run the task still. If we guess
        # wrong, we're either doing a bit of unnecessary work, or the cluster
        # is shutting down and the task will be cancelled anyway.
        should_run = True

    if should_run:
        return fn(*args, **kwargs)


class DaskExecutor(Executor):
    """
    An executor that runs all functions using the `dask.distributed` scheduler.

    By default a temporary `distributed.LocalCluster` is created (and
    subsequently torn down) within the `start()` contextmanager. To use a
    different cluster class (e.g.
    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
    specify `cluster_class`/`cluster_kwargs`.

    Alternatively, if you already have a dask cluster running, you can provide
    the address of the scheduler via the `address` kwarg.

    Note that if you have tasks with tags of the form `"dask-resource:KEY=NUM"`
    they will be parsed and passed as
    [Worker Resources](https://distributed.dask.org/en/latest/resources.html)
    of the form `{"KEY": float(NUM)}` to the Dask Scheduler.

    Args:
        - address (string, optional): address of a currently running dask
            scheduler; if one is not provided, a temporary cluster will be
            created in `executor.start()`. Defaults to `None`.
        - cluster_class (string or callable, optional): the cluster class to use
            when creating a temporary dask cluster. Can be either the full
            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
        - cluster_kwargs (dict, optional): additional kwargs to pass to the
            `cluster_class` when creating a temporary dask cluster.
        - adapt_kwargs (dict, optional): additional kwargs to pass to `cluster.adapt`
            when creating a temporary dask cluster. Note that adaptive scaling
            is only enabled if `adapt_kwargs` are provided.
        - client_kwargs (dict, optional): additional kwargs to use when creating a
            [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
        - debug (bool, optional): When running with a local cluster, setting
            `debug=True` will increase dask's logging level, providing
            potentially useful debug info. Defaults to the `debug` value in
            your Prefect configuration.

    Using a temporary local dask cluster:

    ```python
    executor = DaskExecutor()
    ```

    Using a temporary cluster running elsewhere. Any Dask cluster class should
    work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):

    ```python
    executor = DaskExecutor(
        cluster_class="dask_cloudprovider.FargateCluster",
        cluster_kwargs={
            "image": "prefecthq/prefect:latest",
            "n_workers": 5,
            ...
        },
    )
    ```

    Connecting to an existing dask cluster:

    ```python
    executor = DaskExecutor(address="192.0.2.255:8786")
    ```
    """

    def __init__(
        self,
        address: str = None,
        cluster_class: Union[str, Callable] = None,
        cluster_kwargs: dict = None,
        adapt_kwargs: dict = None,
        client_kwargs: dict = None,
        debug: bool = None,
    ):
        if address is None:
            address = context.config.engine.executor.dask.address or None
        if address is not None:
            if cluster_class is not None or cluster_kwargs is not None:
                raise ValueError(
                    "Cannot specify `address` and `cluster_class`/`cluster_kwargs`"
                )
        else:
            if cluster_class is None:
                cluster_class = context.config.engine.executor.dask.cluster_class
            if isinstance(cluster_class, str):
                cluster_class = import_object(cluster_class)
            if cluster_kwargs is None:
                cluster_kwargs = {}
            else:
                cluster_kwargs = cluster_kwargs.copy()

            from distributed.deploy.local import LocalCluster

            if cluster_class == LocalCluster:
                if debug is None:
                    debug = context.config.debug
                cluster_kwargs.setdefault(
                    "silence_logs", logging.CRITICAL if not debug else logging.WARNING
                )

        if adapt_kwargs is None:
            adapt_kwargs = {}
        if client_kwargs is None:
            client_kwargs = {}
        else:
            client_kwargs = client_kwargs.copy()
        client_kwargs.setdefault("set_as_default", False)

        self.address = address
        self.cluster_class = cluster_class
        self.cluster_kwargs = cluster_kwargs
        self.adapt_kwargs = adapt_kwargs
        self.client_kwargs = client_kwargs
        # Runtime attributes
        self.client = None
        # These are coupled - they're either both None, or both non-None.
        # They're used in the case we can't forcibly kill all the dask workers,
        # and need to wait for all the dask tasks to cleanup before exiting.
        self._futures = None  # type: Optional[weakref.WeakSet[Future]]
        self._should_run_event = None  # type: Optional[Event]
        # A ref to a background task subscribing to dask cluster events
        self._watch_dask_events_task = None  # type: Optional[concurrent.futures.Future]

        super().__init__()

    @contextmanager
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Creates a `dask.distributed.Client` and yields it.
        """
        if sys.platform != "win32":
            # Fix for https://github.com/dask/distributed/issues/4168
            import multiprocessing.popen_spawn_posix  # noqa

        from distributed import Client

        try:
            if self.address is not None:
                with Client(self.address, **self.client_kwargs) as client:
                    self.client = client
                    try:
                        self._pre_start_yield()
                        yield
                    finally:
                        self._post_start_yield()
            else:
                with self.cluster_class(**self.cluster_kwargs) as cluster:  # type: ignore
                    if self.adapt_kwargs:
                        cluster.adapt(**self.adapt_kwargs)
                    with Client(cluster, **self.client_kwargs) as client:
                        self.client = client
                        try:
                            self._pre_start_yield()
                            yield
                        finally:
                            self._post_start_yield()
        finally:
            self.client = None

    async def _watch_dask_events(self) -> None:
        scheduler_comm = None
        comm = None
        from distributed.core import rpc

        try:
            scheduler_comm = rpc(
                self.client.scheduler.address,  # type: ignore
                connection_args=self.client.security.get_connection_args("client"),  # type: ignore
            )
            # due to a bug in distributed's inproc comms, letting cancellation
            # bubble up here will kill the listener. wrap with a shield to
            # prevent that.
            comm = await asyncio.shield(scheduler_comm.live_comm())
            await comm.write({"op": "subscribe_worker_status"})
            _ = await comm.read()
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                for op, msg in msgs:
                    if op == "add":
                        for worker in msg.get("workers", ()):
                            self.logger.debug("Worker %s added", worker)
                    elif op == "remove":
                        self.logger.debug("Worker %s removed", msg)
        except asyncio.CancelledError:
            pass
        except Exception:
            self.logger.debug(
                "Failure while watching dask worker events", exc_info=True
            )
        finally:
            if comm is not None:
                try:
                    await comm.close()
                except Exception:
                    pass
            if scheduler_comm is not None:
                scheduler_comm.close_rpc()

    def _pre_start_yield(self) -> None:
        from distributed import Event

        is_inproc = self.client.scheduler.address.startswith("inproc")  # type: ignore
        if self.address is not None or is_inproc:
            self._futures = weakref.WeakSet()
            self._should_run_event = Event(
                f"prefect-{uuid.uuid4().hex}", client=self.client
            )
            self._should_run_event.set()

        self._watch_dask_events_task = asyncio.run_coroutine_threadsafe(
            self._watch_dask_events(), self.client.loop.asyncio_loop  # type: ignore
        )

    def _post_start_yield(self) -> None:
        from distributed import wait

        if self._watch_dask_events_task is not None:
            try:
                self._watch_dask_events_task.cancel()
            except Exception:
                pass
            self._watch_dask_events_task = None
        if self._should_run_event is not None:
            # Multipart cleanup, ignoring exceptions in each stage
            # 1.) Stop pending tasks from starting
            try:
                self._should_run_event.clear()
            except Exception:
                pass
            # 2.) Wait for all running tasks to complete
            try:
                futures = [f for f in list(self._futures) if not f.done()]  # type: ignore
                if futures:
                    self.logger.info(
                        "Stopping executor, waiting for %d active tasks to complete",
                        len(futures),
                    )
                    wait(futures)
            except Exception:
                pass
        self._should_run_event = None
        self._futures = None

    def _prep_dask_kwargs(self, extra_context: dict = None) -> dict:
        if extra_context is None:
            extra_context = {}

        dask_kwargs = {"pure": False}  # type: dict

        # set a key for the dask scheduler UI
        key = _make_task_key(**extra_context)
        if key is not None:
            dask_kwargs["key"] = key

        # infer from context if dask resources are being utilized
        task_tags = extra_context.get("task_tags", [])
        dask_resource_tags = [
            tag for tag in task_tags if tag.lower().startswith("dask-resource")
        ]
        if dask_resource_tags:
            resources = {}
            for tag in dask_resource_tags:
                prefix, val = tag.split("=")
                resources.update({prefix.split(":")[1]: float(val)})
            dask_kwargs.update(resources=resources)

        return dask_kwargs

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state.update(
            {
                k: None
                for k in [
                    "client",
                    "_futures",
                    "_should_run_event",
                    "_watch_dask_events_task",
                ]
            }
        )
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def submit(
        self, fn: Callable, *args: Any, extra_context: dict = None, **kwargs: Any
    ) -> "Future":
        """
        Submit a function to the executor for execution. Returns a Future object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra information
                about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Future: a Future-like object that represents the computation of `fn(*args, **kwargs)`
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")

        kwargs.update(self._prep_dask_kwargs(extra_context))
        if self._should_run_event is None:
            fut = self.client.submit(fn, *args, **kwargs)
        else:
            fut = self.client.submit(
                _maybe_run, self._should_run_event.name, fn, *args, **kwargs
            )
            self._futures.add(fut)
        return fut

    def wait(self, futures: Any) -> Any:
        """
        Resolves the Future objects to their values. Blocks until the computation is complete.

        Args:
            - futures (Any): single or iterable of future-like objects to compute

        Returns:
            - Any: an iterable of resolved futures with similar shape to the input
        """
        if self.client is None:
            raise ValueError("This executor has not been started.")
        return self.client.gather(futures)


class LocalDaskExecutor(Executor):
    """
    An executor that runs all functions locally using `dask` and a
    configurable dask scheduler.

    Args:
        - scheduler (str): The local dask scheduler to use; common options are
            "threads", "processes", and "synchronous". Defaults to "threads".
        - **kwargs (Any): Additional keyword arguments to pass to dask config
    """

    def __init__(self, scheduler: str = "threads", **kwargs: Any):
        self.scheduler = self._normalize_scheduler(scheduler)
        self.dask_config = kwargs
        self._pool = None  # type: Optional[multiprocessing.pool.Pool]
        super().__init__()

    @staticmethod
    def _normalize_scheduler(scheduler: str) -> str:
        scheduler = scheduler.lower()
        if scheduler in ("threads", "threading"):
            return "threads"
        elif scheduler in ("processes", "multiprocessing"):
            return "processes"
        elif scheduler in ("sync", "synchronous", "single-threaded"):
            return "synchronous"
        else:
            raise ValueError(f"Unknown scheduler {scheduler!r}")

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state["_pool"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def _interrupt_pool(self) -> None:
        """Interrupt all tasks in the backing `pool`, if any."""
        if self.scheduler == "threads" and self._pool is not None:
            # `ThreadPool.terminate()` doesn't stop running tasks, only
            # prevents new tasks from running. In CPython we can attempt to
            # raise an exception in all threads. This exception will be raised
            # the next time the task does something with the Python api.
            # However, if the task is currently blocked in a c extension, it
            # will not immediately be interrupted. There isn't a good way
            # around this unfortunately.
            import platform

            if platform.python_implementation() != "CPython":
                self.logger.warning(
                    "Interrupting a running threadpool is only supported in CPython, "
                    "all currently running tasks will continue to completion"
                )
                return

            self.logger.info("Attempting to interrupt and cancel all running tasks...")

            import sys
            import ctypes

            # signature of this method changed in python 3.7
            if sys.version_info >= (3, 7):
                id_type = ctypes.c_ulong
            else:
                id_type = ctypes.c_long

            for t in self._pool._pool:  # type: ignore
                ctypes.pythonapi.PyThreadState_SetAsyncExc(
                    id_type(t.ident), ctypes.py_object(KeyboardInterrupt)
                )

    @contextmanager
    def start(self) -> Iterator:
        """Context manager for initializing execution."""
        # import dask here to reduce prefect import times
        import dask.config
        from dask.callbacks import Callback
        from dask.system import CPU_COUNT

        class PrefectCallback(Callback):
            def __init__(self):  # type: ignore
                self.cache = {}

            def _start(self, dsk):  # type: ignore
                overlap = set(dsk) & set(self.cache)
                for key in overlap:
                    dsk[key] = self.cache[key]

            def _posttask(self, key, value, dsk, state, id):  # type: ignore
                self.cache[key] = value

        with PrefectCallback(), dask.config.set(**self.dask_config):
            if self.scheduler == "synchronous":
                self._pool = None
            else:
                num_workers = dask.config.get("num_workers", None) or CPU_COUNT
                if self.scheduler == "threads":
                    from multiprocessing.pool import ThreadPool

                    self._pool = ThreadPool(num_workers)
                else:
                    from dask.multiprocessing import get_context

                    context = get_context()
                    self._pool = context.Pool(num_workers)
            try:
                exiting_early = False
                yield
            except BaseException:
                exiting_early = True
                raise
            finally:
                if self._pool is not None:
                    self._pool.terminate()
                    if exiting_early:
                        self._interrupt_pool()
                    self._pool.join()
                self._pool = None

    def submit(
        self, fn: Callable, *args: Any, extra_context: dict = None, **kwargs: Any
    ) -> "dask.delayed":
        """
        Submit a function to the executor for execution. Returns a `dask.delayed` object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - extra_context (dict, optional): an optional dictionary with extra
                information about the submitted task
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - dask.delayed: a `dask.delayed` object that represents the
                computation of `fn(*args, **kwargs)`
        """
        # import dask here to reduce prefect import times
        import dask

        extra_kwargs = {}
        key = _make_task_key(**(extra_context or {}))
        if key is not None:
            extra_kwargs["dask_key_name"] = key
        return dask.delayed(fn, pure=False)(*args, **kwargs, **extra_kwargs)

    def wait(self, futures: Any) -> Any:
        """
        Resolves a (potentially nested) collection of `dask.delayed` objects to
        their values. Blocks until the computation is complete.

        Args:
            - futures (Any): iterable of `dask.delayed` objects to compute

        Returns:
            - Any: an iterable of resolved futures
        """
        # import dask here to reduce prefect import times
        import dask

        # dask's multiprocessing scheduler hardcodes task fusion in a way
        # that's not exposed via a `compute` kwarg. Until that's fixed, we
        # disable fusion globally for the multiprocessing scheduler only.
        # Since multiprocessing tasks execute in a remote process, this
        # shouldn't affect user code.
        if self.scheduler == "processes":
            config = {"optimization.fuse.active": False}
        else:
            config = {}
        with dask.config.set(config):
            return dask.compute(
                futures, scheduler=self.scheduler, pool=self._pool, optimize_graph=False
            )[0]
````
https://github.com/ggservice007/prefect/blob/ray-executor/src/prefect/executors/ray.py
```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator

from prefect.executors.base import Executor


class RayExecutor(Executor):
    """Minimal first sketch of an executor backed by Ray: `submit` wraps the
    function with `ray.remote` and returns an object ref; `wait` resolves
    refs via `ray.get`."""

    def __init__(self, address: str = None, init_kwargs: dict = None):
        self.address = address
        self.init_kwargs = init_kwargs.copy() if init_kwargs else {}
        super().__init__()

    @contextmanager
    def start(self) -> Iterator[None]:
        # Connect to an existing cluster if an address is given, otherwise
        # start (and later tear down) a local Ray instance.
        import ray

        ray.init(address=self.address, **self.init_kwargs)
        try:
            yield
        finally:
            ray.shutdown()

    def submit(
        self, fn: Callable, *args: Any, extra_context: dict = None, **kwargs: Any
    ) -> Any:
        import ray

        return ray.remote(fn).remote(*args, **kwargs)

    def wait(self, futures: Any) -> Any:
        import ray

        return ray.get(futures)
```
@amogkam @richardliaw Any ideas or guidelines on how to proceed from here?
This looks great.
This would be very useful! Any status updates?
I am developing the my-happy-flow project, which only supports Ray.
Hi, I'm a bot from the Ray team :)
To help human contributors to focus on more relevant issues, I will automatically add the stale label to issues that have had no activity for more than 4 months.
If there is no further activity within the next 14 days, the issue will be closed!
You can always ask for help on our discussion forum or Ray's public slack channel.
Hi again! This issue will be closed because there has been no further activity in the 14 days since the last message.
Please feel free to reopen or open a new issue if you'd still like it to be addressed.
Again, you can always ask for help on our discussion forum or Ray's public slack channel.
Thanks again for opening the issue!
Can we reopen this please? It's still relevant.
I agree!
Ray is great. Prefect is also great. Combining them is double awesome.
Prefect
The easiest way to automate your data.
Prefect is the new standard in dataflow automation, trusted to build, run, and monitor millions of data workflows and pipelines.
Prefect workflows are designed to be run at any time, for any reason, with any concurrency. Running a workflow on a schedule is a special case.
https://www.prefect.io/
https://github.com/PrefectHQ/prefect
But Prefect is built on Dask executors. Any ideas on how to swap the Dask executor for a Ray one?
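For anyone attempting the port, the `distributed` primitives that `DaskExecutor` relies on seem to have fairly direct Ray counterparts. A rough correspondence, as a planning aid only (the Ray-side names should be double-checked against the Ray API docs):

```python
# Hypothetical mapping from the dask/distributed calls used in dask.py to
# the Ray calls that appear to play the same role. Not working code.
dask_to_ray = {
    "distributed.Client(address)": "ray.init(address=...)",
    "client.submit(fn, *args, **kwargs)": "ray.remote(fn).remote(*args, **kwargs)",
    "distributed.Future": "ray.ObjectRef",
    "client.gather(futures)": "ray.get(futures)",
    "client.close()": "ray.shutdown()",
}
```

The harder parts to translate are the Dask-specific extras: the `distributed.Event`-based cancellation in `_maybe_run`, the scheduler-UI task keys, and the `dask-resource` tag handling, which would each need a Ray-native equivalent or could be dropped in a first version.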
https://github.com/PrefectHQ/prefect/issues/3929