dask / distributed

A distributed task scheduler for Dask
https://distributed.dask.org
BSD 3-Clause "New" or "Revised" License

Task co-assignment logic is worst-case for binary operations like `a + b` #6597

Open gjoseph92 opened 2 years ago

gjoseph92 commented 2 years ago

The root task co-assignment logic does the exact opposite of what it should for operations combining two different datasets, like a + b.

```
 x     x     x     x
 /\    /\    /\    /\
a  b  a  b  a  b  a  b
1  2  3  4  5  6  7  8  <-- priority
```

It assigns all the `a`s to one worker and all the `b`s to another. Each `x` then requires transferring either its `a` or its `b`, so 50% of the data gets transferred. This could have been 0% if we had co-assigned properly.
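To make the 50% vs 0% arithmetic concrete, here is a minimal sketch that counts transfers for the two placements. It is a toy model, not scheduler code: `transfer_fraction` is a hypothetical helper, and it assumes each `x_i` runs on a worker holding one of its two inputs, so it needs exactly one transfer whenever `a_i` and `b_i` live on different workers.

```python
from typing import Dict


def transfer_fraction(assignment: Dict[str, str], n_pairs: int) -> float:
    """Fraction of root outputs that must move to compute x_i = a_i + b_i.

    Hypothetical model: x_i runs next to one of its inputs, so it needs
    one transfer whenever a_i and b_i are on different workers.
    """
    transfers = sum(
        assignment[f"a{i}"] != assignment[f"b{i}"] for i in range(1, n_pairs + 1)
    )
    # Each x_i consumes two inputs, so there are 2 * n_pairs inputs in total
    return transfers / (2 * n_pairs)


n = 8
# The placement described above: every a on one worker, every b on the other
split = {**{f"a{i}": "w1" for i in range(1, n + 1)},
         **{f"b{i}": "w2" for i in range(1, n + 1)}}
# Proper co-assignment: each a_i sits on the same worker as its b_i
paired = {**{f"a{i}": "w1" if i <= n // 2 else "w2" for i in range(1, n + 1)},
          **{f"b{i}": "w1" if i <= n // 2 else "w2" for i in range(1, n + 1)}}

print(transfer_fraction(split, n))   # 0.5
print(transfer_fraction(paired, n))  # 0.0
```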

The reason for this is that co-assignment selects a worker to re-use per task group. So it goes something like this (recall that we're iterating through root tasks in priority order):

  1. Assign a1. Its task group has no `last_worker` set, so pick the least busy worker: w1.
  2. Assign b1. Its task group has no `last_worker` set, so pick the least busy worker. w1 already has a task assigned to it (a1), so we pick w2.
  3. Assign the next 3 `a`s to w1 and the next 3 `b`s to w2 (they come through interleaved, since they're interleaved in priority order).
  4. Time to pick a new worker for a5. Both workers are equally busy; say we pick w2.
  5. Time to pick a new worker for b5. We just made w2 slightly busier than w1, so pick w1.
  6. The pattern continues. Each time we flip-flop, sending the tasks to opposite workers.
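The walkthrough can be reproduced with a small simulation. This is a simplified model of per-task-group co-assignment, not the scheduler's `decide_worker`: it assumes one thread per worker, takes the group name from the first character of the key, sets each group's quota to `floor(len(group) / total_nthreads)`, and breaks ties between equally busy workers by name. With that tie-break, a5 goes back to w1 rather than w2 as in step 4, but either choice ends the same way: no `a_i` shares a worker with its `b_i`.

```python
import math
from typing import Dict, List, Optional


def assign_per_group(tasks: List[str], group_sizes: Dict[str, int],
                     workers: List[str], total_nthreads: int) -> Dict[str, str]:
    """Toy model: each task group keeps its OWN last-used worker and quota."""
    busy = {w: 0 for w in workers}              # tasks assigned so far per worker
    last_worker: Dict[str, Optional[str]] = {}  # per-group last-used worker
    tasks_left: Dict[str, int] = {}             # per-group remaining quota
    out: Dict[str, str] = {}
    for key in tasks:                           # tasks arrive in priority order
        group = key[0]                          # hypothetical: "a3" -> group "a"
        ws = last_worker.get(group)
        if ws is None or tasks_left[group] == 0:
            # Least busy worker, ties broken by worker name
            ws = min(workers, key=lambda w: (busy[w], w))
            last_worker[group] = ws
            tasks_left[group] = math.floor(group_sizes[group] / total_nthreads)
        tasks_left[group] -= 1
        busy[ws] += 1
        out[key] = ws
    return out


n = 8
order = [f"{g}{i}" for i in range(1, n + 1) for g in "ab"]  # a1, b1, a2, b2, ...
result = assign_per_group(order, {"a": n, "b": n}, ["w1", "w2"], total_nthreads=2)
# Number of pairs whose inputs ended up on different workers
print(sum(result[f"a{i}"] != result[f"b{i}"] for i in range(1, n + 1)))  # 8
```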

The last-used worker should be global state (well, global to a particular sequence of transitions caused by `update_graph`). Each subsequent task in priority order should re-use this worker until it's filled up, regardless of which task group the task belongs to.
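Here is a minimal sketch of that proposal using the same toy model: a single last-used worker shared by all task groups, whose quota is computed from the total number of root tasks across groups. `assign_global` and its `n_root_tasks` parameter are hypothetical, and the sketch simply assumes that total is known up front, which is exactly the hard part discussed next.

```python
import math
from typing import Dict, List, Optional


def assign_global(tasks: List[str], n_root_tasks: int,
                  workers: List[str], total_nthreads: int) -> Dict[str, str]:
    """Toy model: ONE last-used worker shared by every task group."""
    busy = {w: 0 for w in workers}
    last_worker: Optional[str] = None
    tasks_left = 0
    out: Dict[str, str] = {}
    for key in tasks:  # priority order; the task group is ignored entirely
        if last_worker is None or tasks_left == 0:
            # Current worker is "filled up": move to the least busy one
            last_worker = min(workers, key=lambda w: (busy[w], w))
            tasks_left = math.floor(n_root_tasks / total_nthreads)
        tasks_left -= 1
        busy[last_worker] += 1
        out[key] = last_worker
    return out


n = 8
order = [f"{g}{i}" for i in range(1, n + 1) for g in "ab"]  # a1, b1, a2, b2, ...
result = assign_global(order, n_root_tasks=2 * n, workers=["w1", "w2"],
                       total_nthreads=2)
# Number of pairs whose inputs ended up on different workers
print(sum(result[f"a{i}"] != result[f"b{i}"] for i in range(1, n + 1)))  # 0
```

Because siblings are adjacent in priority, filling one worker before moving on keeps every `a_i`/`b_i` pair together: w1 gets a1 through b4, and w2 gets a5 through b8.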

The tricky part is calculating what "filled up" means. We currently use the size of the task group as the total number of root tasks, which we then divide by nthreads to decide how many to assign per worker. But of course, that's not actually the total number of root tasks: in the example above, each group holds only half of them. I'm not sure yet how to figure out the total number of root tasks in constant time within `decide_worker`.
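For the two-group example, the difference between the two "filled up" estimates is easy to compute by hand. In this sketch, `group_sizes` is a hypothetical stand-in for the scheduler's task groups. Summing it gives the true total, but that costs time proportional to the number of root groups rather than constant time.

```python
import math

total_nthreads = 2
group_sizes = {"a": 8, "b": 8}  # the two root task groups from the example

# Current behaviour per the description above: one group's size stands in
# for the total number of root tasks
per_worker_from_group = math.floor(group_sizes["a"] / total_nthreads)

# What "filled up" should be based on: all root tasks across groups.
# This sum is O(number of root groups), not constant time.
n_root_tasks = sum(group_sizes.values())
per_worker_true = math.floor(n_root_tasks / total_nthreads)

print(per_worker_from_group, per_worker_true)  # 4 8
```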

Broadly speaking, this stateful and kinda hacky co-assignment logic is a bit of a pain to integrate into https://github.com/dask/distributed/issues/6560. I've been able to do it, but maintaining good assignment while rebalancing tasks as workers are added and removed is difficult. Our co-assignment logic relies too heavily on state and on iterating through all the tasks at once in priority order, so we can't actually re-co-assign things when workers change. If we had a data structure or mechanism to efficiently identify "which tasks are siblings of this one", or maybe even "which worker holds the task nearest in priority to this one", it might make solving both problems easier.
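As one illustration of the second idea, here is a hypothetical index that answers "which worker holds the task nearest in priority to this one?" with a binary search over a sorted list. Nothing like `NearestPriorityIndex` exists in distributed; it is only a sketch. Priorities are plain ints here, whereas the scheduler's real priorities are tuples. Lookups are O(log n), but `bisect.insort` is O(n) per insert because the list shifts; a third-party sorted container such as `sortedcontainers.SortedList` could make inserts cheaper.

```python
import bisect
from typing import List, Optional, Tuple


class NearestPriorityIndex:
    """Hypothetical index from task priority to the worker holding that task."""

    def __init__(self) -> None:
        self._entries: List[Tuple[int, str]] = []  # sorted (priority, worker)

    def add(self, priority: int, worker: str) -> None:
        bisect.insort(self._entries, (priority, worker))

    def nearest_worker(self, priority: int) -> Optional[str]:
        if not self._entries:
            return None
        # Index of the first entry whose priority is >= the query
        i = bisect.bisect_left(self._entries, (priority,))
        # Only the neighbours on either side can be the nearest
        candidates = self._entries[max(i - 1, 0):i + 1]
        # Smallest priority distance wins; ties go to the lower priority
        best = min(candidates, key=lambda e: (abs(e[0] - priority), e[0]))
        return best[1]


idx = NearestPriorityIndex()
idx.add(1, "w1")  # a1 (priority 1) already placed on w1
idx.add(8, "w2")  # b4 (priority 8) already placed on w2
print(idx.nearest_worker(2))  # w1: b1 would land next to its sibling a1
print(idx.nearest_worker(7))  # w2
```

Because such a lookup depends only on what is already placed, not on having walked the tasks in one pass, a query like this could in principle be re-run after workers join or leave.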


Here's a simple test that fails on main (each worker ends up transferring 4 keys):

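```python
from dask import delayed
from distributed.utils_test import gen_cluster


@gen_cluster(
    client=True,
    nthreads=[("", 1), ("", 1)],
)
async def test_decide_worker_coschedule_order_binary_op(c, s, a, b):
    # Two groups of 8 independent root tasks: x-* and y-*
    xs = [delayed(i, name=f"x-{i}") for i in range(8)]
    ys = [delayed(i, name=f"y-{i}") for i in range(8)]
    # One binary op per pair; ideally x-i and y-i land on the same worker
    zs = [x + y for x, y in zip(xs, ys)]

    await c.gather(c.compute(zs))

    # With proper co-assignment, neither worker should have fetched any data
    assert not a.transfer_incoming_log, [l["keys"] for l in a.transfer_incoming_log]
    assert not b.transfer_incoming_log, [l["keys"] for l in b.transfer_incoming_log]
```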

Note that this case occurs in @TomNicholas's example workload: https://github.com/dask/distributed/issues/6571

cc @fjetter @mrocklin

fjetter commented 2 years ago

https://github.com/dask/distributed/pull/6985