dask / dask

Parallel computing with task scheduling
https://dask.org
BSD 3-Clause "New" or "Revised" License

Forget history and task overhead investigation #4630

Open rabernat opened 5 years ago

rabernat commented 5 years ago

Another issue about dask performance and optimization, slightly related to #107.

I frequently end up creating dask graphs with 1M+ tasks. Graphs this big cause the scheduler to choke. One idea I have had to mitigate this is to call .persist() on some intermediate results. I would essentially like to save some results in the memory of my dask cluster, and then do further computations on this data.

However, the problem is that .persist() doesn't seem to reduce the number of tasks in the scheduler's memory. Even though the data I need are all in memory on the worker nodes, I can't erase the expensive task history.

Specifically, I would like to do something like this:

raw_data = load_data_from_storage(...)
intermediate_result = big_function_that_creates_lots_of_tasks(raw_data)
intermediate_result.persist(forget=True)  # proposed keyword; persist() has no `forget` argument
# intermediate result now only has 1 task per chunk
final_result = downstream_function(intermediate_result)

An alternative way to phrase this is that I would like to use the dask cluster as an in-memory distributed storage object, so perhaps there is a different way to achieve the same result.

mrocklin commented 5 years ago

Two questions:

  1. Have high level expression graphs had an impact on your work? This came about in https://github.com/dask/dask/issues/4038 as a result of issues that you were having. It would be good to hear if that had any effect.
  2. What would you want to happen if a worker goes down holding the last copy of some data? Currently we need the task graph in order to provide resilience. Would we want to duplicate all data a few times to different workers to ensure resilience?

Another approach would be to store your data onto some persistent store like GCS, perhaps using the da.store(..., return_stored=True) function (though this lacks the ephemeral nature that you're probably looking for).
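A small sketch of the `da.store(..., return_stored=True)` pattern. It uses an in-memory NumPy array as the target only so that it runs anywhere. In practice the target would be a persistent store such as a Zarr array on GCS, which is an assumption and is not shown here. As far as I understand, with the default `compute=True` the data is written immediately. The returned array then reads back from the target instead of depending on the original pipeline.

```python
import numpy as np
import dask.array as da

# Stand-in for an expensive pipeline with many tasks (arbitrary sizes).
x = da.random.random((1000, 1000), chunks=(100, 100))
y = 5 * x**2 + 1

# In-memory target for illustration only; a real workflow would use a
# persistent store (e.g. Zarr on GCS), which is an assumption here.
target = np.empty(y.shape, dtype=y.dtype)

# compute=True (the default) writes the data now; return_stored=True returns
# a tuple of dask arrays that read the chunks back from `target`.
(stored,) = da.store([y], [target], return_stored=True)

# Downstream work starts from the stored data.
result = stored.mean().compute()
print(result)
```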

rabernat commented 5 years ago

@mrocklin - thanks for your quick response.

  1. Have high level expression graphs had an impact on your work?

I have not seen any measurable effect from this. Here is the simplest benchmark I can think of with real data.

import intake
import xarray as xr
catalog_url = 'https://github.com/pangeo-data/pangeo/raw/master/gce/catalog.yaml'
ds = intake.Catalog(catalog_url).LLC4320_SSU.to_dask()

# simple dummy calculation I thought would be able to target high-level graphs
uf = (5*ds.U**2 + 1).mean()
# I assume this is the low level graph? It's the same in all dask versions
assert len(uf.data.dask) == 630981
# time how long it takes to send the graph to the scheduler
%time uf.persist()

Dask 1.0.0 (before #4092):

CPU times: user 36.5 s, sys: 1.57 s, total: 38.1 s
Wall time: 37.3 s

Dask 1.1.4:

CPU times: user 46.1 s, sys: 2.05 s, total: 48.2 s
Wall time: 46.9 s

Am I doing something wrong? Bottom line is that we are still waiting a long time for large graphs, even with very simple calculations.

2. What would you want to happen if a worker goes down holding the last copy of some data?

Very good point. I had not thought about that. As long as it were explicit, I would be willing to live with the risk of losing my data and having to start over.

mrocklin commented 5 years ago
assert len(uf.data.dask) == 630981

You would need to do something like

import dask
(uf2,) = dask.optimize(uf)  # optimized copy of the xarray object
len(uf2.data.dask)

I tried the minimal example, but ran into not having the zarr plugin. I'm not sure how best to install plugins in intake. Moving on this morning. I'll try to get back to this in a while (no promises though).
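`len(collection.dask)` counts the graph as it was built, before any optimization runs. The following minimal sketch compares the two counts on a small synthetic array, so no intake plugin is needed. The shape and chunking are arbitrary choices for illustration.

```python
import dask
import dask.array as da

# Small synthetic array with the same chunking pattern (one chunk per 2-D slice).
x = da.random.random((20, 13, 50, 50), chunks=(1, 1, 50, 50))
y = (5 * x**2 + 1).mean()

# dask.optimize returns a tuple of optimized collections (culling, fusion, ...).
(yy,) = dask.optimize(y)

n_raw = len(y.dask)         # tasks in the graph as constructed
n_optimized = len(yy.dask)  # tasks after optimization
print(n_raw, n_optimized)
```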

mrocklin commented 5 years ago

Very good point. I had not thought about that. As long as it were explicit, I would be willing to live with the risk of losing my data and having to start over.

That seems in scope then. If someone wanted to try this out they would probably look at how scattered futures are handled in terms of dependencies and such, and then figure out how to get to the same state with other futures.
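For reference, here is a minimal sketch of what scattered data looks like from the user side. `client.scatter` places NumPy blocks directly into worker memory, so the scheduler knows each block only as a key with no tasks behind it. For the same reason, there is nothing to recompute a block from if its worker is lost. Wrapping the futures with `da.from_delayed` is one way to use them as an array. The in-process cluster settings are arbitrary, and this is not a proposed implementation of the feature.

```python
import numpy as np
import dask.array as da
from dask.distributed import Client

# In-process cluster for illustration; the settings are arbitrary.
client = Client(processes=False, n_workers=1, threads_per_worker=2)

# Four 10x10 blocks filled with 0, 1, 2, 3.
blocks = [np.full((10, 10), i, dtype=float) for i in range(4)]
futures = client.scatter(blocks)  # one future per block, no task history

# from_delayed also accepts futures; each becomes one chunk of the array.
arrays = [da.from_delayed(f, shape=(10, 10), dtype=float) for f in futures]
x = da.concatenate(arrays, axis=0)

total = x.sum().compute()
print(total)
client.close()
```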

mrocklin commented 5 years ago

This would require diving into Scheduler state a bit. It's highly unlikely that I'll get to this this week or next, but perhaps the future will be easier.

rabernat commented 5 years ago

@mrocklin - I really appreciate your response on this.

Also, I should have learned by now to never try to use real data in an example!

The data in question can be recreated at the dask level as:

import dask.array as dsa
x = dsa.random.random((9000, 13, 4320, 4320), chunks=(1, 1, 4320, 4320))

mrocklin commented 5 years ago

So the number of tasks definitely shrinks:

In [1]: import dask

In [2]: import dask.array as da
   ...: x = da.random.random((9000, 13, 4320, 4320), chunks=(1, 1, 4320, 4320))

In [3]: y = (5*x**2 + 1).mean()

In [4]: yy, = dask.optimize(y)

In [5]: len(y.dask)
Out[5]: 628881

In [6]: len(yy.dask)
Out[6]: 277875

But this doesn't address the fact that you're seeing increased compute times. It would be useful to dive into the difference here and profile things. I would probably start with profiling with the single-threaded scheduler, and see what is taking up the most time.
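Below is a minimal sketch of that first profiling step. It uses the standard-library `cProfile` with dask's synchronous (single-threaded) scheduler, so all of the work happens in one thread where the profiler can see it. The array is scaled down arbitrarily. Dumping the stats to a file makes them viewable with snakeviz.

```python
import cProfile
import io
import pstats

import dask.array as da

# Scaled-down version of the example array (sizes are arbitrary).
x = da.random.random((100, 13, 50, 50), chunks=(1, 1, 50, 50))
y = (5 * x**2 + 1).mean()

profiler = cProfile.Profile()
profiler.enable()
# scheduler="sync" runs every task in the calling thread, so cProfile sees it all
result = y.compute(scheduler="sync")
profiler.disable()

profiler.dump_stats("compute.prof")  # can be opened with `snakeviz compute.prof`

# Print the 15 most expensive entries by cumulative time.
buf = io.StringIO()
pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(15)
print(buf.getvalue())
```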

martindurant commented 5 years ago

ran into not having the zarr plugin

Package: intake-xarray. In the current master, you get a message when the plugin isn't found, pointing you to the page where all the plugins of all the packages are listed. Can you think of a better way? Plugins can be referred to by simple string name or by Python import, but in neither case would Intake know where to actually get that thing from.

mrocklin commented 5 years ago

No particular thoughts. I as a user didn't know how to find the right package other than you telling it to me now, or doing a web search (which is fine). Not a huge deal for this issue though. @rabernat was kind enough to provide an example that was simpler.


martindurant commented 5 years ago

snakeviz rendering of the persist line:

[Screenshot: snakeviz profile of the persist call, 2019-04-17]

rabernat commented 5 years ago

@martindurant - am I correct in interpreting your benchmark that most of the time is spent on communication, i.e. serialization and transmission of the graph?

martindurant commented 5 years ago
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  2196                                               def _graph_to_futures(self, dsk, keys, restrictions=None,
  2197                                                                     loose_restrictions=None, priority=None,
  2198                                                                     user_priority=0, resources=None, retries=None,
  2199                                                                     fifo_timeout=0, actors=None):
  2200         1        462.0    462.0      0.0          with self._refcount_lock:
  2201         1         16.0     16.0      0.0              if resources:
  2202                                                           resources = self._expand_resources(resources,
  2203                                                                                              all_keys=itertools.chain(dsk, keys))
  2204
  2205         1          4.0      4.0      0.0              if retries:
  2206                                                           retries = self._expand_retries(retries,
  2207                                                                                          all_keys=itertools.chain(dsk, keys))
  2208
  2209         1          4.0      4.0      0.0              if actors is not None and actors is not True and actors is not False:
  2210                                                           actors = list(self._expand_key(actors))
  2211
  2212         1          6.0      6.0      0.0              keyset = set(keys)
  2213         1         20.0     20.0      0.0              flatkeys = list(map(tokey, keys))
  2214         1         64.0     64.0      0.0              futures = {key: Future(key, self, inform=False) for key in keyset}
  2215
  2216         1      70769.0  70769.0      0.1              values = {k for k, v in dsk.items() if isinstance(v, Future)
  2217                                                                 and k not in keyset}
  2218         1          3.0      3.0      0.0              if values:
  2219                                                           dsk = dask.optimization.inline(dsk, keys=values)
  2220
  2221         1   46136777.0 46136777.0     58.3              d = {k: unpack_remotedata(v, byte_keys=True) for k, v in dsk.items()}
  2222         1      62375.0  62375.0      0.1              extra_futures = set.union(*[v[1] for v in d.values()]) if d else set()
  2223         1          6.0      6.0      0.0              extra_keys = {tokey(future.key) for future in extra_futures}
  2224         1    5802807.0 5802807.0      7.3              dsk2 = str_graph({k: v[0] for k, v in d.items()}, extra_keys)
  2225         1     107090.0 107090.0      0.1              dsk3 = {k: v for k, v in dsk2.items() if k is not v}
  2226         1          3.0      3.0      0.0              for future in extra_futures:
  2227                                                           if future.client is not self:
  2228                                                               msg = ("Inputs contain futures that were created by "
  2229                                                                      "another client.")
  2230                                                               raise ValueError(msg)
  2231
  2232         1          2.0      2.0      0.0              if restrictions:
  2233                                                           restrictions = keymap(tokey, restrictions)
  2234                                                           restrictions = valmap(list, restrictions)
  2235
  2236         1          2.0      2.0      0.0              if loose_restrictions is not None:
  2237         1          6.0      6.0      0.0                  loose_restrictions = list(map(tokey, loose_restrictions))
  2238
  2239         1     918094.0 918094.0      1.2              future_dependencies = {tokey(k): {tokey(f.key) for f in v[1]} for k, v in d.items()}
  2240
  2241    278805     557658.0      2.0      0.7              for s in future_dependencies.values():
  2242    278804     589462.0      2.1      0.7                  for v in s:
  2243                                                               if v not in self.futures:
  2244                                                                   raise CancelledError(v)
  2245
  2246         1    2917822.0 2917822.0      3.7              dependencies = {k: get_dependencies(dsk, k) for k in dsk}
  2247
  2248         1          3.0      3.0      0.0              if priority is None:
  2249         1   12015141.0 12015141.0     15.2                  priority = dask.order.order(dsk, dependencies=dependencies)
  2250         1     526902.0 526902.0      0.7                  priority = keymap(tokey, priority)
  2251
  2252         1          4.0      4.0      0.0              dependencies = {tokey(k): [tokey(dep) for dep in deps]
  2253         1    1889848.0 1889848.0      2.4                              for k, deps in dependencies.items()}
  2254    278805     549853.0      2.0      0.7              for k, deps in future_dependencies.items():
  2255    278804     534366.0      1.9      0.7                  if deps:
  2256                                                               dependencies[k] = list(set(dependencies.get(k, ())) | deps)
  2257
  2258         1         21.0     21.0      0.0              if isinstance(retries, Number) and retries > 0:
  2259                                                           retries = {k: retries for k in dsk3}
  2260
  2261         1          4.0      4.0      0.0              self._send_to_scheduler({'op': 'update-graph',
  2262         1    6404836.0 6404836.0      8.1                                       'tasks': valmap(dumps_task, dsk3),
  2263         1          2.0      2.0      0.0                                       'dependencies': dependencies,
  2264         1          3.0      3.0      0.0                                       'keys': list(flatkeys),
  2265         1          2.0      2.0      0.0                                       'restrictions': restrictions or {},
  2266         1          2.0      2.0      0.0                                       'loose_restrictions': loose_restrictions,
  2267         1          2.0      2.0      0.0                                       'priority': priority,
  2268         1          2.0      2.0      0.0                                       'user_priority': user_priority,
  2269         1          2.0      2.0      0.0                                       'resources': resources,
  2270         1          4.0      4.0      0.0                                       'submitting_task': getattr(thread_state, 'key', None),
  2271         1          2.0      2.0      0.0                                       'retries': retries,
  2272         1          2.0      2.0      0.0                                       'fifo_timeout': fifo_timeout,
  2273         1         87.0     87.0      0.0                                       'actors': actors})
  2274         1          3.0      3.0      0.0              return futures

unpack_remotedata is the major culprit (58.3% of the time). The cost is not communication, at least not with a local scheduler, and I don't understand what unpack_remotedata does. The method also iterates over all tasks in the graph several times, so there may be a chance to fuse those loops.
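To make this concrete, here is a toy sketch of the kind of traversal `unpack_remotedata` performs, based on how its result is used in the profiled code. It replaces futures embedded in each task with their keys and collects them as dependencies. Here a single loop over the graph produces both outputs. This is not the distributed implementation: `FakeFuture`, `unpack`, and `unpack_graph` are hypothetical names, and the real function handles more cases.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeFuture:
    """Hypothetical stand-in for a distributed Future: it only carries a key."""
    key: str


def unpack(obj, found):
    """Recursively replace FakeFuture objects inside one task with their keys.

    Every tuple, list and dict nested in the task is visited, so the cost grows
    with the total size of all tasks, not just with the number of futures.
    """
    if isinstance(obj, FakeFuture):
        found.add(obj.key)
        return obj.key
    if isinstance(obj, tuple):
        return tuple(unpack(o, found) for o in obj)
    if isinstance(obj, list):
        return [unpack(o, found) for o in obj]
    if isinstance(obj, dict):
        return {k: unpack(v, found) for k, v in obj.items()}
    return obj


def unpack_graph(dsk):
    """One pass over the graph yielding both the rewritten tasks and each
    task's future dependencies, instead of several comprehensions over `dsk`."""
    tasks, future_deps = {}, {}
    for key, task in dsk.items():
        found = set()
        tasks[key] = unpack(task, found)
        future_deps[key] = found
    return tasks, future_deps


dsk = {"a": (sum, [FakeFuture("x"), 1]), "b": (len, "a")}
tasks, deps = unpack_graph(dsk)
print(deps)  # {'a': {'x'}, 'b': set()}
```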

martindurant commented 5 years ago

To answer you specifically, @rabernat: the time is spent mostly in the one method above, but I don't really know what the specific lines do. @mrocklin probably does; I'm not sure who else. Since this is all about looping over a very large number of Python objects multiple times, plus a little logic, it may be an excellent case for Cythonization, or perhaps there are simpler things that could be done.

mrocklin commented 5 years ago

A lot of this is just general dask optimization and graph manipulation. My knowledge of the distributed scheduler isn't strictly required here. Someone like @jcrist or @eriknw would have good experience here.

In general, it would be good for someone to take another look at our overhead. We've made some biggish changes over the last several months and haven't revisited this topic.

TomAugspurger commented 5 years ago

For my understanding, there are two potential issues here:

  1. High-level graph fusion (maybe?) not helping as much as expected. Something like https://github.com/dask/dask/issues/4630#issuecomment-476686913.
  2. An API for persisting just data, without the task graph. Could this be a keyword argument to persist()?

I can start to poke at these issues, and will take them up when I'm back full time (~2 weeks away).

mrocklin commented 5 years ago

Agreed with the separation. I think that both fixes are good generally.

I recommend prioritizing looking at task overhead, and high level graph fusion in particular. I expect it to be more wide-reaching, and probably more straightforward to investigate.

TomAugspurger commented 5 years ago

@rabernat when you have time, can you double-check some timings for me?

Using an example like https://github.com/dask/dask/issues/4630#issuecomment-476686913, I don't see a slowdown from dask 1.0 -> dask master.

from time import perf_counter as clock

import dask
import dask.array as da
from distributed import Client

A = 200  # 9_000
B = 13
C = 1000  # 4320
D = 1000  # 4320

def main():
    client = Client(n_workers=8, threads_per_worker=1)  # noqa
    x = da.random.random((A, B, C, D), chunks=(1, 1, C, D))
    y = (5*x**2 + 1).mean()
    t0 = clock()
    y.compute()
    t1 = clock()

    print('{}: {:0.2f}'.format(dask.__version__, t1 - t0))

if __name__ == '__main__':
    main()

I get the following times:

Dask version    Distributed version    Time (s)
1.0.0           1.27.0                 25.93
1.2.0+21.g20    1.27.0+0.g8c07c878     21.71

So the post-HLG version is slightly faster. I had to scale the problem down for my laptop, so I wouldn't expect a large improvement. Hopefully the smaller scale didn't fundamentally change the performance characteristics.

I'm trying out the intake-based example now.
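One way to check whether scaling the problem down changes the picture is to time only the client-side graph handling at a few sizes, without computing anything. The rough sketch below assumes a toy array shaped like the one above. `graph_overhead` is a hypothetical helper, and it covers only optimization and `dask.order.order` (which appears in the line profile earlier in the thread), not serialization or the scheduler.

```python
from time import perf_counter

import dask
import dask.array as da
from dask.order import order


def graph_overhead(n_chunks):
    """Time client-side graph handling for a toy array with `n_chunks` chunks.

    Nothing is computed; only optimization/materialization and static
    ordering are measured.
    """
    x = da.random.random((n_chunks, 100), chunks=(1, 100))
    y = (5 * x**2 + 1).mean()

    t0 = perf_counter()
    (yy,) = dask.optimize(y)
    dsk = dict(yy.__dask_graph__())  # materialize the low-level task graph
    t1 = perf_counter()
    order(dsk)  # static task ordering via dask.order.order
    t2 = perf_counter()
    return len(dsk), t1 - t0, t2 - t1


for n in (1_000, 10_000):
    n_tasks, t_opt, t_order = graph_overhead(n)
    print(f"{n_tasks:>7} tasks  optimize: {t_opt:.3f}s  order: {t_order:.3f}s")
```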

rabernat commented 5 years ago

Thanks for looking into this @TomAugspurger! (Shouldn't you be changing diapers, not hacking dask! 😉)

I will try to find time later today to run this test again with your sample code.

TomAugspurger commented 5 years ago

Shouldn't you be changing diapers

The sleep schedules occasionally align :)

I will try to find time later today to run this test again with your sample code.

Thanks, no rush though since I likely won't be able to put together a fix for a week or two.

TomAugspurger commented 5 years ago

I may be missing something, but isn't the main issue here

However, the problem is that .persist() doesn't seem to reduce the number of tasks in the scheduler's memory.

already how persist works with distributed?

import dask
import dask.array as da
import numpy as np

N = 1_000
K = 250

X = da.random.random((N, K), chunks=(K, K))

Y = np.sin((X + 1) ** 2).dot(X.T)
Y.visualize(filename="before", rankdir='LR')

from distributed import Client
client = Client()

Y2 = Y.persist()
Y2.visualize(filename="after", rankdir="LR")

[Rendered task graphs: "before" and "after"]

>>> len(Y.dask), len(Y2.dask)
(68, 16)

So does changing the line

intermediate_result.persist(forget=True)

to

intermediate_result = intermediate_result.persist()

solve this?
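The point about rebinding the name can be checked without a cluster. This minimal sketch assumes the default local (threaded) scheduler, where persisted chunks become concrete NumPy arrays rather than futures. `(x + 1) ** 2` is an arbitrary stand-in for the hypothetical `big_function_that_creates_lots_of_tasks`. The key detail is that `persist()` returns a new collection and leaves the one it is called on unchanged.

```python
import dask.array as da

x = da.random.random((1000, 1000), chunks=(250, 250))
# Arbitrary stand-in for the hypothetical big_function_that_creates_lots_of_tasks
intermediate = (x + 1) ** 2

n_before = len(intermediate.dask)

intermediate.persist()  # returns a new collection; the result is discarded here
n_discarded = len(intermediate.dask)  # unchanged: the full task history is still attached

intermediate = intermediate.persist()  # rebind the name to the persisted collection
n_after = len(intermediate.dask)  # one entry per chunk

print(n_before, n_discarded, n_after)
```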

TomAugspurger commented 5 years ago

@rabernat when you get a chance, can you check on https://github.com/dask/dask/issues/4630#issuecomment-487719304? Specifically, in your pseudocode, does changing

intermediate_result.persist(forget=True)
...

to

intermediate_result = intermediate_result.persist()

fix things? https://github.com/dask/dask/issues/4630#issuecomment-487719304 shows that after the persist we end up with one task per chunk.

jakirkham commented 4 years ago

@rabernat, did you have a chance to try Tom's suggestion?