dask / community

For general discussion and community planning. Discussion issues welcome.

Post-mortem: why an easy workflow was horribly non-performant, and what we could do to make it easier for users to write fast dask code #301

Open crusaderky opened 1 year ago

crusaderky commented 1 year ago

Executive summary

Today, the user experience of a typical novice to intermediate dask.dataframe user can be very poor. Building a workflow that is supposedly very straightforward can result in an extremely non-performant output with a tendency to randomly kill off workers. At the end of this post you'll find 13 remedial actions, 10 of which can be sensibly achieved in a few weeks, which can drastically improve the user experience.

Introduction

I recently went through a demo notebook, written by a data scientist, whose purpose is to showcase dask.dataframe to new dask users through a real-life use case. The notebook does what I would call light data preprocessing on a 40 GiB parquet dataset of NYC taxis, with the purpose of later feeding them into a machine learning algorithm.

The first time I ran it, the notebook ran in 25 minutes and required hosts mounting a bare minimum of 32 GiB RAM each. After a long day of tweaking it, I brought it down to 2 minutes runtime and 8 GiB per host RAM requirements.

The problem is this: the workflow implemented by the notebook is not rocket science. Getting to a performant implementation is something you would expect to be a stress-free exercise for an average data scientist; instead it took a day's worth of anger-debugging from a dask maintainer to make it work properly.

This thread is a post-mortem of my experience with it, detailing all the roadblocks that both the original coder and myself hit, with proposals on how to make it a smooth sail in the future.

The algorithm

What's implemented is a standard machine learning pre-processing flow:

  1. Load a 16 billion rows parquet dataset from s3
  2. Discard unneeded columns
  3. Discard rows containing malformed data
  4. Discard rows containing outliers
  5. Join with a tiny pandas.DataFrame (265 rows) containing domain mapping
  6. Convert all domain-based columns to categories, with global domains
  7. Write to parquet on s3

Implementation

Data loading and column manipulation

This first part loads up the dataframe, generates a few extra columns as a function of other columns, and drops unnecessary columns. The dataset is publicly accessible - you may reproduce this on your own.

Original code

client = distributed.Client(...)
ddf = dd.read_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/processed_data.parquet"
)
print(f"size of the total dataset is:  {len(ddf.index)}")

ddf = ddf.assign(accessible_vehicle=1)
ddf.accessible_vehicle = ddf.accessible_vehicle.where(ddf.on_scene_datetime.isnull(), 0)
ddf = ddf.assign(pickup_month=ddf.pickup_datetime.dt.month)
ddf = ddf.assign(pickup_dow=ddf.pickup_datetime.dt.dayofweek)
ddf = ddf.assign(pickup_hour=ddf.pickup_datetime.dt.hour)

ddf = ddf.drop(
    columns=[
        "on_scene_datetime",
        "request_datetime",
        "pickup_datetime",
        "dispatching_base_num",
        "originating_base_num",
        "shared_request_flag",
        "shared_match_flag",
        "dropoff_datetime",
        "base_passenger_fare",
        "bcf",
        "sales_tax",
        "tips",
        "driver_pay",
        "access_a_ride_flag",
        "wav_match_flag",
        "wav_request_flag",
    ]
)
ddf = ddf.reset_index(drop=True)

ddf["airport_fee"] = ddf["airport_fee"].replace("None", 0)
ddf["airport_fee"] = ddf["airport_fee"].replace("nan", 0)
ddf["airport_fee"] = ddf["airport_fee"].astype(float)
ddf["airport_fee"] = ddf["airport_fee"].fillna(0)

To an intermediate user's eye, this looks OK. But it is very, very bad:

  1. All string data is read in Python object format, which is excruciatingly slow to process. Switching to PyArrow was not painless: https://github.com/dask/dask/issues/9840.
  2. That innocent-looking print statement on the second line reads the whole dataset into memory and then discards it.
  3. The code reads the whole thing into memory and then drops unneeded columns. However, parquet makes it possible to efficiently cherry-pick individual columns, while leaving the rest untouched on disk.
  4. Last but not least: the partitions on disk are very inhomogeneous, with the smallest being 22 MiB and the largest weighing a whopping 836 MiB. This is what caused the memory requirements of 32 GiB per host. However, if you repartition() them into smaller chunks, the whole thing becomes a lot more manageable, even if the initial load still requires computing everything at once.

Revised code

Before starting the cluster (this is Coiled-specific; other clusters will require you to manually set the config on all workers):

import dask
import dask.dataframe as dd
import distributed
import pandas as pd
from distributed import WorkerPlugin

dask.config.set({"dataframe.dtype_backend": "pyarrow"})
client = distributed.Client(...)

# Workaround to https://github.com/dask/dask/issues/9840
class SetPandasOptions(WorkerPlugin):
    def setup(self, worker):
        pd.set_option("string_storage", "pyarrow")

pd.set_option("string_storage", "pyarrow")  # Set on the client
_ = client.register_worker_plugin(SetPandasOptions())  # Set on the workers
# End workaround

ddf = dd.read_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/processed_data.parquet",
    index=False,
    columns=[
        "hvfhs_license_num",
        "PULocationID",
        "DOLocationID",
        "trip_miles",
        "trip_time",
        "tolls",
        "congestion_surcharge",
        "airport_fee",
        "wav_request_flag",
        "on_scene_datetime",
        "pickup_datetime",
    ],
)
ddf = ddf.repartition(partition_size="100MB")
ddf = ddf.assign(
    accessible_vehicle=ddf.on_scene_datetime.isnull(),
    pickup_month=ddf.pickup_datetime.dt.month,
    pickup_dow=ddf.pickup_datetime.dt.dayofweek,
    pickup_hour=ddf.pickup_datetime.dt.hour,
)
ddf = ddf.drop(columns=["on_scene_datetime", "pickup_datetime"])
ddf["airport_fee"] = ddf["airport_fee"].replace("None", 0).astype(float).fillna(0)

Note that the call to repartition reads the whole thing in memory and then discards it. This takes a substantial amount of time, but it's the best I could do. It is wholly avoidable though. I also feel that a novice user should not be bothered with having to deal with oversized chunks themselves:

When repartition() is unavoidable, because it's in the middle of a computation, it could avoid computing everything:

As for PyArrow strings: I am strongly convinced they should be the default. I understand that setting PyArrow strings on by default would cause dask to deviate from pandas. I think it's a case where the deviation is worthwhile - pandas doesn't need to cope with the GIL and serialization!

Being forced to think ahead with column slicing is another interesting discussion. It could be avoided if each column was a separate dask key. For the record, this is exactly what xarray.Dataset does. Alternatively, High level expressions (dask#7933) would allow rewriting the dask graph on the fly. With either of these changes, the columns you don't need would never leave the disk (as long as you set chunk_size=1). Also, it would mean that len(ddf) would have to load a single column (seconds) instead of the whole thing (minutes).

I appreciate that introducing splitting by column in dask.dataframe would be a very major effort - but I think it's very likely worth the price.

A much cheaper fix to len():

Drop rows

After column manipulation, we move to row filtering:

Original code

ddf = ddf.dropna(how="any")

original_rowcount = len(ddf.index)

# Remove outliers
lower_bound = 0
Q3 = ddf["trip_time"].quantile(0.75)
print(f"Q3 is:  {Q3.compute()}")
upper_bound = Q3 + (1.5 * (Q3 - lower_bound))
print(f"Upper bound is:  {upper_bound.compute()}")

ddf = ddf.loc[(ddf["trip_time"] >= lower_bound) & (ddf["trip_time"] <= upper_bound)]

ddf = ddf.repartition(partition_size="100MB").persist()
print(
    "Fraction of dataset left after removing outliers:",
    len(ddf.index) / original_rowcount,
)

This snippet recomputes everything so far (load from s3 AND column preprocessing), 😶 FIVE 😶 TIMES 😶:

  1. It performs yet another call to len(ddf.index).
  2. It computes the Q3 on a column
  3. It computes upper_bound

Again, if the graph was split by columns or was rewritten on the fly by high level expressions, these three would be much less of a problem.

  4. repartition(partition_size=...) under the hood calls compute() and then discards everything.
  5. persist() recomputes everything from the beginning one more time.

The one thing that is inexpensive is the final call to len(ddf.index), because it's immediately after a persist().
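The cost of repeated compute() calls can be reproduced on a toy graph. The sketch below is not the notebook's code: `load`, `count_rows` and `max_value` are hypothetical stand-ins for the S3 read and two downstream reductions, and a plain counter records how often the expensive step actually runs. It assumes only `dask.delayed` and the synchronous scheduler.

```python
import dask

calls = {"load": 0}  # how many times the "expensive" step has run


@dask.delayed
def load():
    """Hypothetical stand-in for reading the dataset from S3."""
    calls["load"] += 1
    return [3, 1, 4, 1, 5, 9, 2, 6]


@dask.delayed
def count_rows(data):
    return len(data)


@dask.delayed
def max_value(data):
    return max(data)


data = load()  # lazy: nothing runs yet
n_rows = count_rows(data)
biggest = max_value(data)

# Two separate compute() calls: the shared `load` task runs once per call.
n_rows.compute(scheduler="synchronous")
biggest.compute(scheduler="synchronous")
calls_separate = calls["load"]

# One combined compute() call: the shared task runs only once.
calls["load"] = 0
rows, top = dask.compute(n_rows, biggest, scheduler="synchronous")
calls_combined = calls["load"]
```

Nothing is cached between compute() calls unless the intermediate result is explicitly kept alive, so each print statement in the original snippet pays for the whole upstream graph again.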

Revised code

ddf = ddf.dropna(how="any")

# Remove outliers
lower_bound = 0
Q3 = ddf['trip_time'].quantile(0.75)
upper_bound = Q3 + (1.5 * (Q3 - lower_bound))
ddf = ddf.loc[(ddf["trip_time"] >= lower_bound) & (ddf["trip_time"] <= upper_bound)]

The repartition() call is no longer there, since I already called it once and now partitions are guaranteed to be smaller or equal to before.

Like before, I've outright removed the print() statements. If I had to retain them, I would push them further down, immediately after a call to persist(), so that the computation is only done once:

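# .size is a lazy property (rows x columns), not a method; since the columns
# do not change between the two calls, the ratio equals the row ratio.
original_rowcount = ddf.size
# ... drop rows
new_rowcount = ddf.size
# ...
ddf = ddf.persist()
# Evaluate both lazy scalars in a single pass
original_rowcount, new_rowcount = dask.compute(original_rowcount, new_rowcount)
print(
    "Fraction of dataset left after removing outliers:",
    new_rowcount / original_rowcount,
)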

Joins

Download the "Taxi Zone Lookup Table (CSV)" from https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page, and save it to data/taxi+_zone_lookup.csv.

Original code

taxi_df = pd.read_csv("data/taxi+_zone_lookup.csv", usecols=["LocationID", "Borough"])

ddf = dd.merge(ddf, taxi_df, left_on="PULocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "PUBorough"})
ddf = ddf.drop(columns="LocationID")

ddf = dd.merge(ddf, taxi_df, left_on="DOLocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "DOBorough"})
ddf = ddf.drop(columns="LocationID")

BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}

def make_cross_borough_cat(df):
    PUSuperborough = [BOROUGH_MAPPING.get(i) for i in df.PUBorough.tolist()]
    DOSuperborough = [BOROUGH_MAPPING.get(i) for i in df.DOBorough.tolist()]
    PUSuperborough_DOSuperborough_Pair = [
        f"{i}-{j}" for i, j in zip(PUSuperborough, DOSuperborough)
    ]
    return df.assign(PUSuperborough_DOSuperborough=PUSuperborough_DOSuperborough_Pair)

ddf = ddf.map_partitions(lambda df: make_cross_borough_cat(df))

Again, this sneaks in more pure-python strings into the dataframe. This would be solved by force-casting them to string[pyarrow] in from_pandas() (which is called under the hood by merge()):

It also makes the typical mistake of using pure-python code for the sake of simplicity instead of going through a pandas join. While this is already non-performant in pandas, when you move to dask you have the additional problem that it hogs the GIL.
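As a small illustration of the difference (a sketch on made-up data, not the notebook's dataset), the element-by-element list comprehension can be replaced by `Series.map` and `Series.str.cat`, which produce the same values without a row-by-row loop in user code. The `df` frame below is hypothetical, and `BOROUGH_MAPPING` is abbreviated from the original code.

```python
import pandas as pd

BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
}

# Hypothetical toy data standing in for one partition after the two merges.
df = pd.DataFrame(
    {
        "PUBorough": ["Manhattan", "Queens", "Bronx"],
        "DOBorough": ["Brooklyn", "Queens", "Manhattan"],
    }
)

# Pure-Python version, as in the original make_cross_borough_cat().
pairs_python = [
    f"{BOROUGH_MAPPING.get(i)}-{BOROUGH_MAPPING.get(j)}"
    for i, j in zip(df.PUBorough.tolist(), df.DOBorough.tolist())
]

# pandas version: map each column through the dict, then concatenate.
pu_super = df["PUBorough"].map(BOROUGH_MAPPING)
do_super = df["DOBorough"].map(BOROUGH_MAPPING)
pairs_pandas = pu_super.str.cat(do_super, sep="-")
```

One behavioural difference to be aware of: for a borough missing from the mapping, the pure-Python version produces the literal text "None" inside the pair, while the pandas version yields a missing value.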

The sensible approach to limiting this problem would be to show the user a plot of how much time they spent on each task:

In addition to the above ticket, we need documentation & evangelism to teach the users that to debug a non-performant workflow they can look at Prometheus, and which metrics they should look at first. At the moment, Prometheus is a single, hard-to-notice page in the dask docs.

The second join is performed on a much bigger dataset than necessary. There's no real solution to this - this is an abstract algorithmic optimization that the developer could have noticed themselves.

Finally, the PUSuperborough and DOSuperborough columns could be dropped. I know they're unnecessary by reading the next section below about categorization. Again, nothing we (dask devs) can do here.

Revised code

taxi_zone_lookup = pd.read_csv(
    "data/taxi+_zone_lookup.csv", usecols=["LocationID", "Borough"]
)
BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}

taxi_zone_lookup["Superborough"] = [
    BOROUGH_MAPPING[k] for k in taxi_zone_lookup["Borough"]
]
taxi_zone_lookup = taxi_zone_lookup.astype(
    {"Borough": "string[pyarrow]", "Superborough": "string[pyarrow]"}
)

ddf = dd.merge(ddf, taxi_zone_lookup, left_on="PULocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "PUBorough", "Superborough": "PUSuperborough"})
ddf = ddf.drop(columns="LocationID")

ddf = dd.merge(ddf, taxi_zone_lookup, left_on="DOLocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "DOBorough", "Superborough": "DOSuperborough"})
ddf = ddf.drop(columns="LocationID")

ddf["PUSuperborough_DOSuperborough"] = ddf.PUSuperborough.str.cat(
    ddf.DOSuperborough, sep="-"
)
ddf = ddf.drop(columns=["PUSuperborough", "DOSuperborough"])

Categorization

This dataset is going to be fed into a machine learning engine, so everything that can be converted into a domain should be:

ddf = ddf.repartition(partition_size="100MB").persist()

categories = [
    "hvfhs_license_num",
    "PULocationID",
    "DOLocationID",
    "accessible_vehicle",
    "pickup_month",
    "pickup_dow",
    "pickup_hour",
    "PUBorough",
    "DOBorough",
    "PUSuperborough_DOSuperborough",
]
ddf[categories] = ddf[categories].astype("category")
ddf = ddf.categorize(columns=categories)

The first call to repartition() computes everything since the previous call to persist() - this includes the pure-python joins - and then discards it. Then, persist() computes it again.

The call to astype("category") is unnecessary. categorize(), while OK in this case since it's (accidentally?) just after a persist(), is a major pain point on its own:

Revised code

categories = [
    "hvfhs_license_num",
    "PULocationID",
    "DOLocationID",
    "wav_request_flag",
    "accessible_vehicle",
    "pickup_month",
    "pickup_dow",
    "pickup_hour",
    "PUBorough",
    "DOBorough",
    "PUSuperborough_DOSuperborough",
]

# Read https://github.com/dask/dask/issues/9847
ddf = ddf.astype(dict.fromkeys(categories, "category"))
ddf = ddf.persist()
ddf = ddf.categorize(categories)
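What "global domains" means in step 6 of the algorithm can be shown in plain pandas. In this sketch the two small Series are hypothetical stand-ins for two partitions: converted independently, each gets its own category set and concatenating them loses the categorical dtype; converted with one shared `CategoricalDtype`, the result stays categorical with consistent integer codes.

```python
import pandas as pd

# Hypothetical stand-ins for the same column in two different partitions.
part_a = pd.Series(["Bronx", "Queens", "Bronx"])
part_b = pd.Series(["Queens", "Manhattan"])

# Per-partition conversion: each partition only knows its own values.
local_a = part_a.astype("category")
local_b = part_b.astype("category")
mixed = pd.concat([local_a, local_b], ignore_index=True)

# Shared ("global") domain: one dtype listing every known category.
borough_dtype = pd.CategoricalDtype(["Bronx", "Manhattan", "Queens"])
global_a = part_a.astype(borough_dtype)
global_b = part_b.astype(borough_dtype)
combined = pd.concat([global_a, global_b], ignore_index=True)
```

Finding the full set of categories requires looking at every partition, which is presumably why categorize() is expensive unless the data is already persisted.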

Disk write

Original code

ddf = ddf.repartition(partition_size="100MB")
ddf.to_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/feature_table.parquet",
    overwrite=True,
)

The final call to repartition() recomputes everything since persist() and then discards it - in this case, not that much.

Revised code

ddf = ddf.persist().repartition(partition_size="100MB")

# Workaround to https://github.com/apache/arrow/issues/33727
ddf = ddf.astype(
    {
        col: pd.CategoricalDtype(dt.categories.astype(object))
        for col, dt in ddf.dtypes.items()
        if isinstance(dt, pd.CategoricalDtype)
        and dt.categories.dtype == "string[pyarrow]"
    }
)

ddf.to_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/feature_table.parquet",
    overwrite=True,
)

Action points recap

Low effort

Intermediate effort

mrocklin commented 1 year ago

@crusaderky this is great. Maybe a good blogpost?

gjoseph92 commented 1 year ago

@crusaderky this is an excellent and very valuable writeup. Thanks for taking the time to do this.

> Being forced to think ahead with column slicing is another interesting discussion. It could be avoided if each column was a separate dask key

We've also talked about doing column pruning automatically. This is one of the optimizations we hope would be enabled by high level expressions / revising high level graphs https://github.com/dask/dask/issues/7933 cc @rjzamora. Might be worth adding that to the list.

crusaderky commented 1 year ago

> We've also talked about doing column pruning automatically. This is one of the optimizations we hope would be enabled by high level expressions / revising high level graphs dask/dask#7933 cc @rjzamora. Might be worth adding that to the list.

Updated post.

martindurant commented 1 year ago

We have mentioned this before, but dask-awkward has a nice example of following the metadata through many layers to be able to prune column loads. It's enabled by the rich awkward "typetracer" (a no-data array) that they implemented specifically for this purpose and just works through many operations where fake data or zero-length array might not.

gjoseph92 commented 1 year ago

I'm curious about a couple human / API design aspects of this that I think are also worth looking into.

Two things we see a lot of, that don't need to be there:

  1. repeated re-computation
  2. persist

There are 4 unnecessary recomputations for print statements, and 3 recomputations for repartition operations.

The print statements I find really interesting. I'm going to assume the author didn't realize how much time these added. (As in, they weren't so important for the workflow that they were worth leaving in.) Just removing those, without any of the other changes or fixes you've suggested, would still speed up the code a ton!

So how could we make it easier to recognize how expensive this is?

  1. We could track recently-released keys on the scheduler. Using some heuristic, if you submit a graph that contains a lot of recently-released keys, we send a warning message to the client that you may be recomputing data, and suggesting changes to make (use persist, don't call compute multiple times, etc.). I'm skeptical how useful/effective this would be.
  2. In general, it's hard to map what dask is doing back to user code. When dask is running a task, you don't know what line number in your code that task came from. This can make debugging dask performance feel pretty impenetrable. Now obviously here, running any Python profiler on the script would highlight these compute calls.

    But perhaps because it's so impenetrable, people don't think to even try to understand or profile performance? (Not to mention that many dask users might not be familiar with profiling in the first place.) So the more we make it easier to understand performance, the more users will feel empowered to think about it and tweak things themselves. I imagine that right now, a lot of users call compute and hope for the best. How the time is spent and why during that compute feels like a black box (or a bunch of colors flying around on a dashboard that are hard to understand).

    I wonder if we could make something (on the dashboard?) that gives a profile of your own code and how much dask work each line does. I'm imagining maybe line-profiler style, like the scalene GUI? We wouldn't actually run a profiler on the client—we'd basically have a symbol table mapping from dask keys to lines in user code, then show runtime / memory? / task count / transfers / spill / etc aggregated over tasks for that line. We'd also have a way of splitting it up by compute call, which would highlight repeated computations. (Note that a natural place to put this symbol table would also be https://github.com/dask/dask/issues/7933).

    This is a big thing that would be very generally useful far beyond unnecessary print statements. But if we had it, it would probably help you find the unnecessary print statements.
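To make the idea concrete, here is a rough sketch of the aggregation such a view might perform. Nothing in it is an existing dask API: the `symbols` table, the `task_log` records, and every number in them are invented purely for illustration.

```python
from collections import defaultdict

# Hypothetical symbol table: dask key prefix -> line of user code that created it.
symbols = {
    "read_parquet": "notebook.py:12",
    "len": "notebook.py:21",
    "quantile": "notebook.py:30",
}

# Hypothetical task records: (key prefix, compute call id, seconds spent).
task_log = [
    ("read_parquet", 1, 120.0),
    ("len", 1, 2.0),
    ("read_parquet", 2, 118.0),
    ("quantile", 2, 9.0),
]

seconds_per_line = defaultdict(float)
calls_per_key = defaultdict(set)
for key, call_id, seconds in task_log:
    seconds_per_line[symbols[key]] += seconds  # line-profiler style total
    calls_per_key[key].add(call_id)  # which compute calls ran this key

# Keys that ran in more than one compute call hint at repeated work.
recomputed = sorted(k for k, ids in calls_per_key.items() if len(ids) > 1)
```

In this made-up log, the line that reads the parquet file accounts for almost all of the runtime and shows up under two separate compute calls, which is exactly the signal that would have exposed the print statements.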

Another thought: print statements are useful. Even in a world where the author recognized these repeated computes were adding a lot of time, they still might want to print out Q3 and original_rowcount just to see them. Maybe we could have a dask.log or dask.print function to make it easier to log delayed values without calling compute?
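One way to approximate this today, as a rough sketch rather than an existing dask API: wrap the logging in a `dask.delayed` function that prints its input and passes it through, so the message is emitted as a side effect of the one compute call that needs the value anyway. `log_value` and `expensive_quantile` are hypothetical names, and the value 1200.0 is made up.

```python
import dask

logged = []  # kept alongside print() so the side effect is easy to inspect


@dask.delayed
def log_value(label, value):
    """Print a value when the graph runs, then pass it through unchanged."""
    message = f"{label} is: {value}"
    print(message)
    logged.append(message)
    return value


@dask.delayed
def expensive_quantile():
    """Hypothetical stand-in for ddf["trip_time"].quantile(0.75)."""
    return 1200.0


q3 = log_value("Q3", expensive_quantile())  # still lazy, nothing printed yet
upper_bound = q3 * 2.5  # Q3 + 1.5 * (Q3 - 0), built lazily on top of q3

logged_before = list(logged)
result = upper_bound.compute(scheduler="synchronous")
```

This sketch runs everything in the local process. On a distributed cluster the function would run on a worker, so the printed text would appear wherever that worker's output goes unless it is forwarded to the client.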

Finally, persist seems to be sprinkled liberally across the original code. To me, persist is a bad idea in most cases because it means you're now manually managing memory: dask can't stream operations in constant memory, and release chunks that are no longer needed, because you've pinned them all. But I feel like I see many users add persists all over their code, and I want to know why. Is there something in the docs? It feels like a knob people reach for when they're not happy with performance (because performance is a black box, see above), but why?

rjzamora commented 1 year ago

> It feels like a knob people reach for when they're not happy with performance (because performance is a black box, see above), but why?

I agree that using persist can be problematic in many cases, but is also something that the documentation recommends: https://docs.dask.org/en/stable/best-practices.html#persist-when-you-can

jakirkham commented 1 year ago

This is a really nice write-up!

What do people here think about caching? Agree when computations happen could be surprising to an end user. At a minimum, workflows could avoid repeating a computation unnecessarily.

Frequently point users (even experienced developers) to graphchain. Perhaps Dask would benefit by baking this in.

cc @lsorber (in case you have thoughts on any of this)

GenevieveBuckley commented 1 year ago

> @crusaderky this is great. Maybe a good blogpost?

It'd also make a pretty fascinating short talk for one of the dask demo days

crusaderky commented 1 year ago

About the pure-python code: @gjoseph92 I wonder if it would be feasible to offer a simple flag in coiled.Cluster, e.g. low_level_profile=True, which would let us populate a grafana plot of time spent waiting in GIL contention, broken down by user function (the lowest level callable directly visible in the dask graph)? Of course, enabling such a flag would come at a ~2x performance cost. I don't know if we're capable of real-time C-level profiling though.