crusaderky opened this issue 2 years ago
Overall, I like this proposal, and I think tasks-within-tasks is certainly in need of a redesign.
However, one downside of the "futures returned from a task are always awaited" rule is that it eliminates some of the flexibility which might be your reason for using tasks-within-tasks in the first place.
Specifically, there may be other ways in which you want to wait for the futures to be done:
- `as_completed` (you can start processing them as soon as one/a few are done, not all)

For example, you couldn't do something like:
import distributed

def get_all_pages(x: Collection) -> list[PageId]: ...
def process_page(id: PageId) -> Page: ...
def combine_pages(pages: list[Page]) -> Summary: ...

def summarize(x: Collection, group_size: int = 4) -> list[distributed.Future[Summary]]:
    pages = get_all_pages(x)
    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    distributed.secede()
    done_pages = []
    summary_futures = []
    for f in distributed.as_completed(page_futures):
        done_pages.append(f)
        if len(done_pages) == group_size:
            summary_futures.append(client.submit(combine_pages, done_pages))
            done_pages = []  # rebind rather than clear: the submitted list must not be mutated
    return summary_futures

client = distributed.Client(...)
summarize_future = client.submit(summarize, x)
summary_futures = summarize_future.result()

# Note that `summarize` returns a list of Futures, not the actual Summaries,
# so that we can stream them to our Real Time Business Intelligence System
for f, summary in distributed.as_completed(summary_futures, with_results=True):
    display_summary_on_dashboard(summary)
I do like the simplicity of your proposal, though. And these extra-complex use cases may be worth giving up.
I do think Ray is worth looking at for prior art here, as a system that handles tasks-within-tasks as a core use-case, instead of an edge case like it is for distributed. The API is extremely simple, but I think it belies some careful thinking about the rules needed to make this work well in a distributed context. (For example, you can submit and wait for futures without managing any of the `secede`/`rejoin` logic yourself—still trying to find a good reference for how that works.)
@gjoseph92 your example currently does not work.
When `summarize` returns, it will (likely) close its client. The summary futures are then likely to be forgotten by the scheduler, because the only client holding a reference to them has been shut down. It becomes a race between the garbage collection of the worker-side client and the return value travelling all the way back to the user's client. If the worker-side client is kept alive by a circular reference, everything may appear to work because it takes a while before the next gc run - until you increase the load on the worker and suddenly gc runs faster than your return value. Alternatively, on a multithreaded worker the same Client instance may be shared by multiple threads; as long as 2+ threads hold a reference to the Client at all times it will work, but as soon as you're down to a single thread the Client will be garbage collected. For an inexperienced user, this is a nightmare to debug.
Also, in your example the dashboard will not display any updates until all pages to be summarized are complete. You could rely on the (not always true) assumption that futures are completed more or less in FIFO order and just immediately schedule a summary every 4 page futures - which is what all of the dask recursive aggregations do.
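For illustration, a minimal sketch of that simpler pattern (reusing the stub functions from the example above): every consecutive group of `group_size` page futures is immediately submitted to `combine_pages`, with no `as_completed` loop and no `secede`:

```python
def summarize(x: Collection, group_size: int = 4) -> list[distributed.Future]:
    pages = get_all_pages(x)
    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    # Rely on the (approximate) FIFO completion order: group the page futures
    # positionally and let the scheduler start each combine_pages task as soon
    # as its own group of dependencies is done.
    return [
        client.submit(combine_pages, page_futures[i : i + group_size])
        for i in range(0, len(page_futures), group_size)
    ]
```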
To me, you're highlighting two different issues:

- `as_completed` is insufficient on its own. I personally would like to reopen the discussion on `distributed.Queue`. A new feature like "spawn a `combine_pages` task whenever there are 4 or more elements in the queue" would nicely solve your use case.

@gjoseph92 I tried rewriting your use case in a way that works today, and what I got is very complicated and brittle.
import asyncio

import distributed

def get_all_pages(x: Collection) -> list[PageId]: ...
def process_page(id: PageId) -> Page: ...
def combine_pages(pages: list[Page]) -> Summary: ...

def summarize(x: Collection, q: distributed.Queue, group_size: int = 4) -> None:
    pages = get_all_pages(x)
    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    del pages
    distributed.secede()
    done_pages = []
    for f in distributed.as_completed(page_futures):
        done_pages.append(f)
        if len(done_pages) == group_size:
            q.put(client.submit(combine_pages, done_pages))
            done_pages = []
    q.put(None)

async def summarize_collection(x: Collection):
    client = await distributed.Client(..., asynchronous=True)
    q = await distributed.Queue()
    summarize_future = client.submit(summarize, x, q)
    # Convert distributed.Future to asyncio.Future
    summarize_future = asyncio.create_task(summarize_future.result())

    queue_closed = False
    pending = {summarize_future}
    queue_get_future = None
    while True:
        if queue_get_future is None and not queue_closed:
            queue_get_future = asyncio.create_task(q.get())
            pending.add(queue_get_future)
        if not pending:
            break
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for f in done:
            if f is queue_get_future:
                queue_output = await f
                if queue_output is None:
                    queue_closed = True
                else:
                    assert isinstance(queue_output, distributed.Future)
                    # Convert distributed.Future to asyncio.Future
                    pending.add(asyncio.create_task(queue_output.result()))
                queue_get_future = None
            elif f is summarize_future:
                # If summarize crashed remotely, bomb out instead of getting stuck
                # forever on a queue that will never get closed
                await f
            else:
                summary = await f
                display_summary_on_dashboard(summary)
I cooked a mock-up that requires a new class, `distributed.FutureStream`:
from functools import partial

import distributed

def get_all_pages(x: Collection) -> list[PageId]: ...
def process_page(id: PageId) -> Page: ...
def combine_pages(pages: list[Page]) -> Summary: ...

def process_pages(
    x: Collection,
    pages_stream: distributed.FutureStream[Page],
) -> None:
    pages = get_all_pages(x)
    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    for fut in page_futures:
        pages_stream.put(fut)
    pages_stream.close()

def summarize(
    futures: list[distributed.Future],
    pages: list[Page],
    summaries_stream: distributed.FutureStream[Summary],
) -> None:
    # TODO handle exceptions from futures
    if futures:
        client = distributed.get_client()
        summaries_stream.put(client.submit(combine_pages, pages))
    else:
        summaries_stream.close()

def summarize_collection(x: Collection):
    client = distributed.Client(...)
    pages_stream = distributed.FutureStream(order=False)
    summaries_stream = distributed.FutureStream(order=False)
    pages_stream.add_done_callback(
        partial(summarize, summaries_stream=summaries_stream),
        with_result=True,
        remote=True,
        chunk_size=4,
    )
    client.submit(process_pages, x, pages_stream).result()
    try:
        while True:
            summary = summaries_stream.get()
            display_summary_on_dashboard(summary)
    except distributed.StreamClosed:
        pass
This... works and is reasonably clean, but it specifically serves the use case of a multi-stage `as_completed` stream. I wonder how much real-life user appetite there is for it?
To chip in a bit, another application that would be covered by the original proposal, but not by a future stream is iterative optimization. With the approach of returning futures it would go as:
def optimization_step(current):
    if converged:
        return current
    next = guess_next(current)
    return client.compute(optimization_step(next))

As far as I understand, this doesn't seem addressed by a `FutureStream`.
> your example currently does not work
Apologies, I should have been clear: my example code wasn't meant to work, just to illustrate a use-case. I was imagining some semantic hybrid between the current distributed API and the Ray API here.
The point I was trying to make was just that always awaiting returned futures restricts some use-cases. In particular, this sort of result streaming is either really tricky, or requires a special structure like `FutureStream`.
My hypothesis was just that if we could make it so that my code worked (correctly handling the reference-counting/ownership of futures-pointing-to-futures), then I think tasks-within-tasks would still be easy to use, but strictly more powerful than in your auto-await proposal, since you could control how the awaiting happened to easily implement things like streaming yourself, without any additional structures.
Whether it's worth the effort to make this reference-counting work correctly, I really can't say.
But since we're talking about ways to do this, here's my original code, converted to Ray, and actually working. You can see it's still quite simple and readable:
import ray

# process_page and combine_pages are assumed to be @ray.remote versions
# of the functions from the earlier example.

@ray.remote
def summarize(x: Collection, group_size: int = 4) -> list[ray.ObjectRef[Summary]]:
    pages = get_all_pages(x)
    page_oids = [process_page.remote(p) for p in pages]
    summary_oids: list[ray.ObjectRef[Summary]] = []
    while page_oids:
        done, page_oids = ray.wait(
            page_oids, num_returns=min(group_size, len(page_oids)), fetch_local=False
        )
        summary_oids.append(combine_pages.remote(*done))
        del done
    return summary_oids

summarize_future = summarize.remote(Collection("abcdefghijklmnopqrstuvwxyz"))
summary_futures = ray.get(summarize_future)

# Note that `summarize` returns a list of ObjectIDs, not the actual Summaries,
# so that we can stream them to our Real Time Business Intelligence System
while summary_futures:
    [summary], summary_futures = ray.wait(summary_futures)
    display_summary_on_dashboard(ray.get(summary))
I think the reason this works in Ray but doesn't in distributed is because in

    summarize_future = summarize.remote(Collection("abcdefghijklmnopqrstuvwxyz"))
    summary_futures = ray.get(summarize_future)

Ray is able to track that `summarize_future` is a pointer to a pointer / future to a future, and therefore the inner futures are "owned"/kept alive by the outer future. As you mentioned @crusaderky, distributed can't track this—it's just a race condition whether the worker client releases the inner futures before the user client can dereference the outer futures and pick up references to the inner futures.
I don't think we have a way to represent this dependency structure right now in dask? Currently, if a task is complete, all its dependencies must be complete. This would be a new situation, where a task (`summarize`) is complete, yet in the process of running, it added more "dependencies" (not dependencies—we'd need a different term) to itself, which are not yet complete. So because your client wants `summarize_future`, and `summarize_future` "depends on" all of the `summary_futures`, the `summary_futures` are kept alive even though no client wants them directly.
Maybe we could add `wants_what` (and `tasks_who_want`) to `scheduler.TaskState`, as a corollary to `ClientState.wants_what`? So tasks could pin references to other keys, in the same way that Clients can? That might be sufficient to track this nested ownership. Any keys returned by a task (including traversing lists/tuples/dicts) would be added to `wants_what`. Keys that were wanted—either by a client, or another task—would not be released. Therefore, by holding a reference to the outer future on your client, you'd transitively keep any tree of keys that it points to alive.
Though we might then want checks for ownership reference cycles, which would be interesting.
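To make the bookkeeping concrete, a rough sketch of the idea (purely illustrative - `TaskState.wants_what` and `tasks_who_want` do not exist in the scheduler today):

```python
# Illustrative pseudo-scheduler bookkeeping, not real distributed code.
class TaskState:
    def __init__(self, key: str):
        self.key = key
        self.wants_what: set["TaskState"] = set()       # keys this task's result points to
        self.tasks_who_want: set["TaskState"] = set()   # reverse index
        self.clients_who_want: set[str] = set()


def register_returned_futures(ts: TaskState, inner_tasks: list[TaskState]) -> None:
    """When a finished task's result is found to contain Futures, pin the inner keys."""
    for inner in inner_tasks:
        ts.wants_what.add(inner)
        inner.tasks_who_want.add(ts)


def can_release(ts: TaskState) -> bool:
    """A key may only be forgotten when neither a client nor another task's result wants it."""
    return not ts.clients_who_want and not ts.tasks_who_want
```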
With some cleverness, you could even implement some sort of chain-fusion/de-aliasing on the scheduler, where you collapse linear chains of futures down to just the first and last element. I'm thinking about this because, in the example of recursion @akhmerov posted, the recursion will potentially produce a huge number of futures, which serve no purpose other than to point to the next future. Both tracking this on the scheduler, and traversing the many Future objects to get the final result, could be expensive. But this is a classic example of where you'd want tail-call optimization. And I think you might be able to implement it on the scheduler with a chain-collapsing rule.
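A toy illustration of such a chain-collapsing rule (plain path compression over an alias table; the names are made up and this is not scheduler code):

```python
def collapse_chains(alias_of: dict[str, str]) -> dict[str, str]:
    """alias_of maps a task key to the key of the future it merely points to.

    Collapse linear chains so every key points directly at the final, concrete
    key, so neither the scheduler nor the client has to walk the whole chain.
    """
    resolved: dict[str, str] = {}

    def resolve(key: str) -> str:
        if key not in alias_of:  # not an alias: already a concrete key
            return key
        if key not in resolved:
            resolved[key] = resolve(alias_of[key])
        return resolved[key]

    return {key: resolve(key) for key in alias_of}


# e.g. {"a": "b", "b": "c"} -> {"a": "c", "b": "c"}
```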
> here's my original code, converted to Ray, and actually working

You're still not executing any calls to `display_summary_on_dashboard` until all calls to `process_page` have been completed, though.
I confess my ignorance with ray - do I understand correctly that

    done, pending = ray.wait(...)

is the same as

    distributed.secede()
    done, pending = distributed.wait(...)
    distributed.rejoin()

? Or is there some python interpreter magic going on where what looks like a synchronous function is actually asynchronous?
> yet in the process of running, it added more "dependencies" (not dependencies—we'd need a different term) to itself

consequences?
Yes, `ray.wait` / `ray.get` basically handles the `secede`/`rejoin` for you automatically. I haven't read their implementation, but the end result is that as a user, you can just use it without worrying about deadlocks.
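For comparison, a helper with similar ergonomics could presumably be layered over the existing distributed primitives - a hypothetical sketch, not an existing API:

```python
import distributed

def wait_in_task(futures, return_when="ALL_COMPLETED"):
    """Wait for futures from inside a running task without holding a worker thread.

    Hypothetical convenience wrapper: secede from the worker's thread pool while
    blocked, then rejoin, so the user never calls secede/rejoin directly.
    """
    distributed.secede()
    try:
        return distributed.wait(futures, return_when=return_when)
    finally:
        distributed.rejoin()
```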
> consequences?

Not sure what you're asking here. The consequences of a task's result (a Future) depending on other tasks are exactly what you talked about:

> It ends in a race condition on what's fastest, the garbage collection of the worker client or the return value going all the way back to the user client.
If we codify the fact that a task's result can depend on other tasks (adding `wants_what` to `TaskState` as I'm suggesting), then this is all tracked on the scheduler, there's no more race condition, and returning Futures from tasks has first-class support.
> consequences?
>
> Not sure what you're asking here.

I'm asking if you like the idea of calling these spawned tasks "consequences", since we agree they are not dependencies.
Ah I see. Yes, that makes sense as a user-facing term maybe. I think I'd prefer `wants_what` on the `TaskState` though, since it matches with `ClientState.wants_what`, and is effectively the same thing.
Hello, 😄
I'm a newer user of Dask, currently working on a specific workflow where I need to process a huge amount of documents, and I stumbled on this open issue.
The basic idea is to open a huge number of documents. Each document is a collection of an arbitrary number of pages. We then apply OCR to each page. The time to process each page is arbitrary and we have a lot more pages than documents (20M pages for 400,000 documents).
I tried using Dask for this pipeline, tried different ways of writing it, and settled on the tasks-in-tasks design pattern like @gjoseph92 showed in the mock example:
```python
import random
from itertools import chain
from time import sleep

from distributed import LocalCluster, Client
from distributed import Queue
from distributed import worker_client


def ocr_image(page):
    timeDelay = random.randrange(1, 10)
    sleep(timeDelay)  # simulate actual ocr work
    return "this is ocr"


def load_pages(doc):
    # simulate open file
    sleep(0.5)
    futures = []
    n = random.randint(1, 5)
    with worker_client() as client:
        for page in range(n):
            future_ocr = client.submit(ocr_image, page, pure=False)
            futures.append(future_ocr)
    return futures


def main():
    # Load and submit tasks
    loaders = [client.submit(load_pages, doc, pure=False) for doc in filenames]
    res_loaders = client.gather(loaders)
    res_ocr = client.gather(list(chain.from_iterable(res_loaders)))
    return res_ocr
```
The issue with this approach is having to schedule a LOT of small tasks, so I thought about batching, but the problem there is the arbitrary number of pages in a document (from 1 page to 40,000!).
The correct approach, in my humble opinion, would be to have a distributed producer/consumer architecture with a distributed queue of pages that we can consume.
I tried the `distributed.Queue` class with a wait for the first completed loader, but it has some major issues if you don't know the exact number of tasks spawned within tasks:
```python
def batch_ocr_image():
    # You can't have batch size and timeout
    # pages = [q.get(timeout='1s') for _ in range(batch_size)]
    pages = q.get(batch_size)
    for _ in range(batch_size):
        timeDelay = random.randrange(1, 10)
        sleep(timeDelay)  # simulate actual ocr work
    return ["this is ocr"] * batch_size


def ocr_image():
    page = q.get(timeout='1s')
    timeDelay = random.randrange(1, 10)
    sleep(timeDelay)
    return "this is ocr"


def load_pages(doc):
    # simulate open file
    sleep(0.5)
    futures = []
    n = random.randint(1, 5)
    n = 10
    for page in range(n):
        q.put(page)
    return n


def main():
    ## Load pages in queue
    loaders = [client.submit(load_pages, doc, pure=False) for doc in filenames]

    # Sync 1 : Gather loaders
    # approach 1 : wait for all loaders to finish
    res_loaders = client.gather(loaders)
    # approach 2 : wait for the first and then submit
    loaders = wait(loaders, return_when='FIRST_COMPLETED')

    ## Batching
    # Batching is very hard : q.qsize() will fail here
    consumers = [
        client.submit(batch_ocr_image, pure=False, retries=4)
        for _ in range(q.qsize() // batch_size)
    ]

    # Sync 2 : to consume queue
    res_consumer = client.gather(consumers)
    return loaders, res_consumer
```
I might be missing something about how to correctly implement the producer/consumer pattern using distributed.
The `distributed.FutureStream` proposed by @crusaderky could solve the lack of an easy-to-implement producer/consumer pattern in dask.distributed.
I have just submitted a feature request with the methods that, in my opinion, a proper distributed queue needs, based on the existing class:

- `del q` doesn't actually free up the distributed memory
- `q.join()` working in a distributed manner, to block until all the queue items have been processed
- being able to get fewer than `batch_size` items if the queue is empty

I don't know the internal design of distributed, but submitting futures might hit the dereferencing issue when the root function returns; the future should continue to live until `q.get()` is called, and we could then block the consumer until the result is done.
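For what it's worth, a hedged sketch of how a consumer could look with such an extended queue API (`get_batch` and `task_done` are hypothetical methods that do not exist on `distributed.Queue` today; `ocr_image(page)` is the variant from the first snippet):

```python
def batch_ocr_consumer(q, batch_size: int = 32):
    # Hypothetical consumer: drain the queue in batches until the producer closes it.
    results = []
    while True:
        pages = q.get_batch(batch_size)  # hypothetical: up to batch_size items, fewer if the queue is short
        if not pages:
            break
        results.extend(ocr_image(page) for page in pages)
        q.task_done(len(pages))          # hypothetical: lets producers q.join() once everything is processed
    return results
```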
@AmineDiro if you just had one task per document (400k tasks) and then processed each page sequentially inside the same task, would you incur substantially suboptimal behaviour? I assume here that the number of workers you use is a tiny fraction of 400k.
Even if you do need to spawn futures for the pages, wouldn't it be covered by my design in the opening post?
> (20M pages for 400 000 documents)

I'm afraid that, as of today, if you have 20M tasks you will likely hit 100% CPU load on the scheduler - and consequently experience a wealth of random timeouts. I would advise performing some clustering.
Thanks for the response @crusaderky! Sequentially processing the pages of each document will be bottlenecked by the workers handling documents of 40,000 pages... I am using an HPC cluster with 100 workers of 24 cores each. You are right, I did see 100% CPU load on the scheduler. I would like to batch pages, but I need an object that stores opened pages in a buffer before submitting them; a queue would do the job...
@AmineDiro your problem can be solved by the design in the initial post; no need for queues:
import distributed
from dask import delayed

CHUNK_SIZE = 1000  # pages processed by a single task

def parse_document(path: str) -> list[Image]:
    # Load document from disk and return one raw image per page
    ...

def ocr_page(page: Image) -> OCRExitStatus:
    # Run a single page through OCR, dump output to disk, and return metadata / useful info
    ...

@delayed
def ocr_pages(pages: list[Image]) -> list[OCRExitStatus]:
    return [ocr_page(page) for page in pages]

@delayed
def aggregate_ocr_results(*chunk_results: list[OCRExitStatus]) -> list[OCRExitStatus]:
    return [r for chunk in chunk_results for r in chunk]

@delayed
def ocr_document(doc_path: str):
    raw_pages = parse_document(doc_path)
    client = distributed.get_client()
    chunks = client.scatter(
        [
            raw_pages[i: i + CHUNK_SIZE]
            for i in range(0, len(raw_pages), CHUNK_SIZE)
        ]
    )
    return aggregate_ocr_results(*(ocr_pages(chunk) for chunk in chunks))

client = distributed.Client()
all_results = aggregate_ocr_results(*(ocr_document(path) for path in paths))
all_results.compute()  # returns list[OCRExitStatus]
Also note that, WITHOUT this change, you can achieve today what you want with a slightly less efficient two-stage approach:
import distributed
from dask import delayed

CHUNK_SIZE = 1000  # pages processed by a single task

def count_pages(path: str) -> int:
    # Open document, peek at the header, and return number of pages contained within
    ...

@delayed
def parse_document(path: str) -> list[Image]:
    # Fully load document from disk and return one raw image per page
    ...

def ocr_page(page: Image) -> OCRExitStatus:
    # Run a single page through OCR, dump output to disk, and return metadata / useful info
    ...

@delayed
def ocr_pages(pages: list[Image]) -> list[OCRExitStatus]:
    return [ocr_page(page) for page in pages]

@delayed
def aggregate_ocr_results(*chunk_results: list[OCRExitStatus]) -> list[OCRExitStatus]:
    return [r for chunk in chunk_results for r in chunk]

client = distributed.Client()
npages = client.gather(client.map(count_pages, paths))

ocr_chunk_delayeds = []
for path, npages_i in zip(paths, npages):
    raw_pages_delayed = parse_document(path)
    ocr_chunk_delayeds += [
        ocr_pages(raw_pages_delayed[i: i + CHUNK_SIZE])
        for i in range(0, npages_i, CHUNK_SIZE)
    ]

all_results = aggregate_ocr_results(*ocr_chunk_delayeds)
all_results.compute()  # returns list[OCRExitStatus]
Amazing! Thanks a lot @crusaderky for taking the time to write up this code.

I have had some issues with resolving the delayed objects from the `aggregate_ocr_results`. The `.compute()` never computes the actual OCR and returns a `List[Delayed]`? Is there something I'm missing?

I didn't think about using this approach because I thought that I needed to chunk pages across documents and not within a single doc, but I see how this could work!
The discussion for a producer/consumer is a bit more general, and gives a "cleaner" way to solve these kinds of problems.
> I have had some issues with resolving the delayed objects from the `aggregate_ocr_results`. The `.compute()` never computes the actual OCR and returns a `List[Delayed]`? Is there something I'm missing?
The first block of code in my previous post does not work today; it requires the change described in the op.
@crusaderky Ok got it! The first one needs the change you mentioned above.
I think that the second one also needs that change to work.
client = distributed.Client()
npages = client.gather(client.map(count_pages, paths))

ocr_chunk_delayeds = []
for path, npages_i in zip(paths, npages):
    raw_pages_delayed = parse_document(path)
    ocr_chunk_delayeds += [
        ocr_pages(raw_pages_delayed[i: i + CHUNK_SIZE])
        for i in range(0, npages_i, CHUNK_SIZE)
    ]  ### Added the delayed call here to ocr_page

all_results = aggregate_ocr_results(*ocr_chunk_delayeds)
all_results.compute()  # returns list[OCRExitStatus] ==> **Returns a List[Delayed]**
The issue is that I still need to wait on another compute to get the pages:

    result = client.gather(client.compute(all_results.compute()))

Thanks for your help
@AmineDiro there was 1 line wrong:

      ocr_chunk_delayeds += [
    -     raw_pages_delayed[i: i + CHUNK_SIZE]
    +     ocr_pages(raw_pages_delayed[i: i + CHUNK_SIZE])
          for i in range(0, npages_i, CHUNK_SIZE)
      ]
> I think that the second one also needs that change to work.

It doesn't. I just tested with mocked functions that it does work today as intended.

> result = client.gather(client.compute(all_results.compute()))

This is not correct. The `compute()` method already returns the final output.
For a single Delayed object such as this, these are all equivalent, as long as you're using a synchronous Client:

    all_results.compute()
    client.compute(all_results).result()
    client.gather(client.compute(all_results))
Haven't read through the full issue, but just want to say I am happy to see this discussion 😄

In particular this stuck out to me...

> The same should be implemented in dask/dask, so that it works on the threading/multiprocessing schedulers too.

Couldn't agree more. There are actual use cases we could solve with improved nested task submission (especially if all schedulers can support it, as it is easier to add that code into Dask itself).
Dask performs best when the whole graph of tasks is defined ahead of time from the client and submitted at once. This is however not always possible.
Use case and current best practice
A typical example is a top-down discovery + bottom-up aggregation of a tree where the discovery of the parent-child relationships is an operation too expensive to be performed on the client.
Use case in pure Python:
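For concreteness, a minimal sketch of the pure-Python recursion being described (`get_children` is the expensive discovery step; `reduce_node` and `root` are illustrative names):

```python
def get_children(node) -> list:
    # Expensive discovery of a node's children
    ...

def reduce_node(node, child_results: list):
    # Bottom-up aggregation of a node, given its children's aggregated results
    ...

def aggregate(node):
    children = get_children(node)
    return reduce_node(node, [aggregate(child) for child in children])

result = aggregate(root)  # root is the tree's root node
```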
The first way to solve this problem today with Dask is to have the client invoke `client.submit(get_children, node)` for every node and wait for results. This can be very network and CPU intensive for the client.

The second approach is to use secede/rejoin, as described in https://distributed.dask.org/en/latest/task-launch.html:
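A minimal sketch of that pattern applied to the same use case (again using the illustrative `get_children`/`reduce_node`/`root` stubs from above):

```python
import distributed

def aggregate(node):
    client = distributed.get_client()
    children = get_children(node)
    child_futures = client.map(aggregate, children)
    # Give this worker thread back to the pool while we block on the children
    distributed.secede()
    child_results = client.gather(child_futures)
    distributed.rejoin()
    return reduce_node(node, child_results)

client = distributed.Client()
result = client.submit(aggregate, root).result()
```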
The above is problematic, because:
and last but not least,
A situationally slightly better variant is as follows:
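A hedged sketch of what such a variant might look like - the per-node reduction is submitted as its own task, so the children's outputs flow into it as managed memory instead of being gathered onto this task's stack:

```python
import distributed

def aggregate(node):
    client = distributed.get_client()
    children = get_children(node)
    child_futures = client.map(aggregate, children)
    # Two futures per node: the aggregate task itself plus the reduction task
    reduce_future = client.submit(reduce_node, node, child_futures)
    distributed.secede()
    result = reduce_future.result()
    distributed.rejoin()
    return result
```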
The difference is subtle - as the subgraph of each child gets resolved, its (potentially large) output is not stored on the stack of the aggregate task (which is unmanaged memory); it goes into managed memory instead, with all the benefits that entails. On the flip side, the scheduler is now burdened with two futures per node instead of one. Regardless, all of the problems listed above remain.
Proposed redesign
I would like to suggest deprecating secede()/rejoin(). In its place, I would like to introduce the following rule:
If a task returns a Future, then the scheduler will wait for it and return its result instead. This may be nested (the result of the Future may itself be a Future).
The use case code becomes as follows:
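A hedged sketch of the same use case under the proposed rule (illustrative stubs as above; note there is no secede/rejoin and no blocking gather inside the task):

```python
import distributed

def aggregate(node):
    client = distributed.get_client()
    children = get_children(node)
    child_futures = client.map(aggregate, children)
    # Return a Future: under the proposed rule the scheduler awaits it
    # (recursively) and hands the concrete result to whoever asked for this task.
    return client.submit(reduce_node, node, child_futures)

client = distributed.Client()
result = client.submit(aggregate, root).result()
```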
No extra threads are ever created. Everything is managed by the scheduler - as it should. The network and CPU load on the user's client (e.g. a jupyter notebook) remain trivial.
Nested resolution of futures aside, the above code currently does not work because, after you return a future, as soon as the future is serialised and removed from Worker.data the future destructor kicks in, which in turn releases the refcount on the scheduler, so by the time the future is rebuilt on the opposite side the data it references may have been lost. Same if the future is spilled to disk.
Challenges
Publish/unpublish
It is currently possible to work around the scheduler forgetting the future upon return by publishing it temporarily. This is generally a bad idea because, short of implementing a user-defined garbage collector, you may end up with cluster-wide memory leaks of managed memory (datasets that are published and then forgotten, because the task that was supposed to unpublish them crashed or never started). Nonetheless, automatically resolving returned futures will break this pattern.
Workaround
Users can still use this hack but return the name of the temporary dataset instead of the Future.
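For example, a hedged sketch of that hack (the dataset name is arbitrary; cleanup is left to the caller, which is exactly the garbage-collection problem described above):

```python
import uuid
import distributed

def aggregate(node):
    client = distributed.get_client()
    children = get_children(node)
    child_futures = client.map(aggregate, children)
    reduce_future = client.submit(reduce_node, node, child_futures)
    # Pin the future on the scheduler under a temporary name and return the
    # name (a plain string) instead of the Future itself.
    name = f"aggregate-{uuid.uuid4().hex}"
    client.publish_dataset(reduce_future, name=name)
    return name

client = distributed.Client()
name = client.submit(aggregate, root).result()
result = client.get_dataset(name).result()
client.unpublish_dataset(name)
```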
Additions and nice-to-haves
Client-side tracking
It would be nice to see `distributed.diagnostics.progressbar.progress` display the increasing tasks in real time. This is not something that's happening with the current secede/rejoin design either.
Collections
Returned dask collections could be treated specially like Futures. For example, the below would halve the number of worker->scheduler comms and (personal preference) would also look nicer:
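A hedged sketch of what that could look like with the same illustrative stubs - a lazy Delayed collection returned straight from the task:

```python
import distributed
from dask import delayed

def aggregate(node):
    children = get_children(node)
    # Return a lazy dask collection directly; the scheduler would compute it
    # and hand back the concrete result, saving the extra submit round-trip.
    child_results = [delayed(aggregate)(child) for child in children]
    return delayed(reduce_node)(node, child_results)

client = distributed.Client()
result = client.submit(aggregate, root).result()
```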
Under the hood, all that's happening is a two-liner that converts the collection into a future to revert to the base use case:
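Presumably something along these lines, wherever a task's return value is unwrapped (a sketch, not actual distributed code):

```python
import dask
import distributed

result = ...  # whatever the task returned
if dask.is_dask_collection(result):
    # Convert the collection into a single Future, reverting to the base
    # "task returned a Future" use case.
    result = distributed.get_client().compute(result)
```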
The same should be implemented in dask/dask, so that it works on the threading/multiprocessing schedulers too.