Open crusaderky opened 1 year ago
@crusaderky this is great. Maybe a good blogpost?
@crusaderky this is an excellent and very valuable writeup. Thanks for taking the time to do this.
Being forced to think ahead with column slicing is another interesting discussion. It could be avoided if each column were a separate dask key.
We've also talked about doing column pruning automatically. This is one of the optimizations we hope would be enabled by high level expressions / revising high level graphs https://github.com/dask/dask/issues/7933 cc @rjzamora. Might be worth adding that to the list.
Updated post.
We have mentioned this before, but dask-awkward has a nice example of following the metadata through many layers to be able to prune column loads. It's enabled by the rich awkward "typetracer" (a no-data array) that they implemented specifically for this purpose and just works through many operations where fake data or zero-length array might not.
I'm curious about a couple human / API design aspects of this that I think are also worth looking into.
Two things we see a lot of, that don't need to be there: recomputation and `persist`.

There are 4 unnecessary recomputations for `print` statements, and 3 recomputations for `repartition` operations.

The `print` statements I find really interesting. I'm going to assume the author didn't realize how much time these added. (As in, they weren't so important for the workflow that they were worth leaving in.) Just removing those, without any of the other changes or fixes you've suggested, would still speed up the code a ton!
So how could we make it easier to recognize how expensive this is?
We could document best practices (avoid `persist`, don't call `compute` multiple times, etc.). I'm skeptical how useful/effective this would be.

In general, it's hard to map what dask is doing back to user code. When dask is running a task, you don't know what line number in your code that task came from. This can make debugging dask performance feel pretty impenetrable. Now obviously here, running any Python profiler on the script would highlight these `compute` calls.

But perhaps because it's so impenetrable, people don't think to even try to understand or profile performance? (Not to mention that many dask users might not be familiar with profiling in the first place.) So the easier we make it to understand performance, the more users will feel empowered to think about it and tweak things themselves. I imagine that right now, a lot of users call `compute` and hope for the best. How the time is spent and why during that `compute` feels like a black box (or a bunch of colors flying around on a dashboard that are hard to understand).
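The cost of repeated `compute` calls is easy to demonstrate with a toy graph (a sketch; `load` here is a made-up stand-in for the expensive s3 read):

```python
import dask

calls = {"n": 0}  # count how many times the expensive load runs

@dask.delayed
def load():
    calls["n"] += 1
    return list(range(10))

data = load()
total = dask.delayed(sum)(data)
count = dask.delayed(len)(data)

total.compute()  # runs load()
count.compute()  # runs load() again: separate compute calls share nothing
assert calls["n"] == 2

calls["n"] = 0
dask.compute(total, count)  # one compute call: load() runs once for both results
assert calls["n"] == 1
```

The same thing happens, at much larger scale, every time the notebook calls `len(...)` or `print(...)` on a dataframe that hasn't been persisted.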
I wonder if we could make something (on the dashboard?) that gives a profile of your own code and how much dask work each line does. I'm imagining maybe line-profiler style, like the scalene GUI? We wouldn't actually run a profiler on the client - we'd basically have a symbol table mapping from dask keys to lines in user code, then show runtime / memory? / task count / transfers / spill / etc. aggregated over tasks for that line. We'd also have a way of splitting it up by `compute` call, which would highlight repeated computations. (Note that a natural place to put this symbol table would also be https://github.com/dask/dask/issues/7933.)
This is a big thing that would be very generally useful far beyond unnecessary `print` statements. But if we had it, it would probably help you find the unnecessary `print` statements.
Another thought: `print` statements are useful. Even in a world where the author recognized these repeated computes were adding a lot of time, they still might want to print out `Q3` and `original_rowcount` just to see them. Maybe we could have a `dask.log` or `dask.print` function to make it easier to log delayed values without calling `compute`?
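A helper like this can already be sketched on top of `dask.delayed` (hypothetical: `dask_print` is a made-up name, not an existing dask API):

```python
import dask

def dask_print(label, value):
    """Hypothetical helper: attach a print as a side effect in the graph,
    so the value is logged whenever it is eventually computed, instead of
    forcing an extra compute() at the call site."""
    @dask.delayed
    def _print_and_passthrough(v):
        print(f"{label}: {v}")
        return v
    return _print_and_passthrough(value)

# Usage sketch: q3 = dask_print("Q3", q3) - the print fires when q3 is
# computed as part of the larger graph, with no standalone compute().
```

Because `dask.delayed` understands dask collections, `value` can be a delayed object or a dask scalar (e.g. the result of a dataframe `quantile`).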
Finally, `persist` seems to be sprinkled liberally across the original code. To me, `persist` is a bad idea in most cases because it means you're now manually managing memory: dask can't stream operations in constant memory, and release chunks that are no longer needed, because you've pinned them all. But I feel like I see many users add `persist`s all over their code, and I want to know why. Is there something in the docs? It feels like a knob people reach for when they're not happy with performance (because performance is a black box, see above), but why?
> It feels like a knob people reach for when they're not happy with performance (because performance is a black box, see above), but why?
I agree that using `persist` can be problematic in many cases, but it is also something that the documentation recommends: https://docs.dask.org/en/stable/best-practices.html#persist-when-you-can
This is a really nice write-up!
What do people here think about caching? I agree that when computations happen can be surprising to an end user. At a minimum, workflows could avoid repeating a computation unnecessarily.

I frequently point users (even experienced developers) to graphchain. Perhaps Dask would benefit from baking this in.
cc @lsorber (in case you have thoughts on any of this)
> @crusaderky this is great. Maybe a good blogpost?
It'd also make a pretty fascinating short talk for one of the dask demo days
About the pure-python code: @gjoseph92 I wonder if it would be feasible to offer a simple flag in coiled.Cluster, e.g. `low_level_profile=True`, which would let us populate a grafana plot of time spent waiting in GIL contention, broken down by user function (the lowest-level callable directly visible in the dask graph)? Of course enabling such a flag would come at a ~2x performance cost. I don't know if we're capable of real-time C-level profiling though.
Executive summary
Today, the user experience of a typical novice to intermediate dask.dataframe user can be very poor. Building a workflow that is supposedly very straightforward can result in an extremely non-performant output with a tendency to randomly kill off workers. At the end of this post you'll find 13 remedial actions - 10 of them sensibly achievable in a few weeks - which can drastically improve the user experience.
Introduction
I recently went through a demo notebook, written by a data scientist, whose purpose is to showcase `dask.dataframe` to new dask users through a real-life use case. The notebook does what I would call light data preprocessing on a 40 GiB parquet dataset of NYC taxis, with the purpose of later feeding it into a machine learning algorithm.

The first time I ran it, the notebook ran in 25 minutes and required hosts mounting a bare minimum of 32 GiB RAM each. After a long day of tweaking it, I brought it down to 2 minutes runtime and 8 GiB per-host RAM requirements.
The problem is this: the workflow implemented by the notebook is not rocket science. Getting to a performant implementation is something you would expect to be a stress-free exercise for an average data scientist; instead it took a day's worth of anger-debugging from a dask maintainer to make it work properly.
This thread is a post-mortem of my experience with it, detailing all the roadblocks that both the original coder and I hit, with proposals on how to make it smooth sailing in the future.
The algorithm
What's implemented is a standard machine learning pre-processing flow:
Implementation
Data loading and column manipulation
This first part loads up the dataframe, generates a few extra columns as a function of other columns, and drops unnecessary columns. The dataset is publicly accessible - you may reproduce this on your own.
Original code
To an intermediate user's eye, this looks OK. But it is very, very bad:
The `print` statement on the second line reads the whole dataset into memory and then discards it.

The partitions are oversized; if you `repartition()` them into smaller chunks, the whole thing becomes a lot more manageable, even if the initial load still requires computing everything at once.

Revised code
Before starting the cluster (this is Coiled-specific; other clusters will require you to manually set the config on all workers):
Note that the call to `repartition` reads the whole thing in memory and then discards it. This takes a substantial amount of time, but it's the best I could do. It is wholly avoidable though. I also feel that a novice user should not be bothered with having to deal with oversized chunks themselves.

When `repartition()` is unavoidable, because it's in the middle of a computation, it could avoid computing everything.
As for PyArrow strings: I am strongly convinced they should be the default. I understand that setting PyArrow strings on by default would cause dask to deviate from pandas. I think it's a case where the deviation is worthwhile - pandas doesn't need to cope with the GIL and serialization!
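For reference, recent dask versions expose an opt-in config for this (the exact key, `dataframe.convert-string`, is from memory and may differ between versions - check your version's docs):

```python
import dask

# Opt in to PyArrow-backed strings for dask.dataframe, so object-dtype
# string columns become string[pyarrow] on load. Requires pyarrow installed.
dask.config.set({"dataframe.convert-string": True})
```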
Being forced to think ahead with column slicing is another interesting discussion. It could be avoided if each column were a separate dask key. For the record, this is exactly what `xarray.Dataset` does. Alternatively, high level expressions (dask#7933) would allow rewriting the dask graph on the fly. With either of these changes, the columns you don't need would never leave the disk (as long as you set `chunk_size=1`). Also, it would mean that `len(ddf)` would only have to load a single column (seconds) instead of the whole thing (minutes).

I appreciate that introducing splitting by column in dask.dataframe would be a very major effort - but I think it's very likely worth the price.
A much cheaper fix to `len()`:
Drop rows
After column manipulation, we move to row filtering:
Original code
This snippet recomputes everything so far (load from s3 AND column preprocessing), :no_mouth: FIVE :no_mouth: TIMES :no_mouth::

- Three of the five are calls to `len(ddf.index)`. Again, if the graph was split by columns or was rewritten on the fly by high level expressions, these three would be much less of a problem.
- `repartition(partition_size=...)` under the hood calls `compute()` and then discards everything.
- `persist()` recomputes everything from the beginning one more time.

The one thing that is inexpensive is the final call to `len(ddf.index)`, because it's immediately after a `persist()`.

Revised code
The `repartition()` call is no longer there, since I already called it once and now partitions are guaranteed to be smaller than or equal to before.
Like before, I've outright removed the `print()` statements. If I had to retain them, I would push them further down, immediately after a call to `persist()`, so that the computation is only done once.

Joins
Download the "Taxi Zone Lookup Table (CSV)" from https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page and save it to `data/taxi+_zone_lookup.csv`.

Original code
Again, this sneaks more pure-python strings into the dataframe. This would be solved by force-casting them to `string[pyarrow]` in `from_pandas()` (which is called under the hood by `merge()`).

It also makes the typical mistake of using pure-python code for the sake of simplicity instead of going through a pandas join. While this is already nonperformant in pandas, when you move to dask you have the added problem that it hogs the GIL.
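As an illustration of the difference (toy data; the real lookup is the taxi zone table), compare a row-wise Python function with a vectorized `merge`:

```python
import pandas as pd

trips = pd.DataFrame({"PULocationID": [1, 2, 1, 3]})
zones = pd.DataFrame({"LocationID": [1, 2, 3],
                      "Borough": ["Manhattan", "Queens", "Bronx"]})

# Pure-python row-wise lookup: runs python bytecode per row, holds the
# GIL, and scales badly once dask spreads it across threads.
lookup = dict(zip(zones.LocationID, zones.Borough))
slow = trips.PULocationID.map(lambda i: lookup[i])

# Vectorized pandas join: the same result without python-level loops,
# and it translates directly into a dask merge.
fast = trips.merge(zones, left_on="PULocationID",
                   right_on="LocationID", how="left")["Borough"]

assert list(slow) == list(fast) == ["Manhattan", "Queens", "Manhattan", "Bronx"]
```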
The sensible approach to limiting this problem would be to show the user a plot of how much time they spent on each task:
In addition to the above ticket, we need documentation & evangelism to teach users that to debug a non-performant workflow they can look at Prometheus, and which metrics they should look at first. At the moment, Prometheus is a single, hard-to-notice page in the dask docs.
The second join is performed on a much bigger dataset than necessary. There's no real solution to this - this is an abstract algorithmic optimization that the developer could have noticed themselves.
Finally, the PUSuperborough and DOSuperborough columns could be dropped. I know they're unnecessary by reading the next section below about categorization. Again, nothing we (dask devs) can do here.
Revised code
Categorization
This dataset is going to be fed into a machine learning engine, so everything that can be converted into a domain, should:
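A sketch of what that conversion looks like in pandas when the categories are known up front (made-up borough values; supplying the categories explicitly avoids the full-data scan that `categorize()` performs):

```python
import pandas as pd

# Declaring the domain ahead of time...
boroughs = pd.CategoricalDtype(
    categories=["Manhattan", "Queens", "Bronx", "Brooklyn", "Staten Island", "EWR"]
)

s = pd.Series(["Queens", "Bronx", "Queens"])

# ...means astype() needs no pass over the data to discover categories.
# In dask, the same astype(known_dtype) call is lazy and cheap, whereas
# categorize() must compute the whole dataset to find the unique values.
cat = s.astype(boroughs)
assert cat.dtype == boroughs
assert list(cat) == ["Queens", "Bronx", "Queens"]
```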
The first call to `repartition()` computes everything since the previous call to `persist()` - this includes the pure-python joins - and then discards it. Then, `persist()` computes it again.

The call to `astype("category")` is unnecessary.

`categorize()`, while OK in this case since it's (accidentally?) just after a `persist()`, is a major pain point on its own.

Revised code
Disk write
Original code
The final call to `repartition()` recomputes everything since `persist()` and then discards it - in this case, not that much.

Revised code
Action points recap
Low effort
Intermediate effort
High effort
Very high effort
Split `dask.Dataframe` by column