rapidsai / cudf

cuDF - GPU DataFrame Library
https://docs.rapids.ai/api/cudf/stable/
Apache License 2.0

[BUG] TDIGEST_MERGE group by aggregation scales very badly #16625

Closed revans2 closed 1 month ago

revans2 commented 2 months ago

Describe the bug

The current implementation of TDIGEST_MERGE, when used in a group-by context, launches separate GPU operations (kernel launches / memory copies) on the order of the number of groups in the output aggregation.

https://github.com/rapidsai/cudf/blob/58799d698d861866b5650d368f5195174fc9644e/cpp/src/quantiles/tdigest/tdigest_aggregation.cu#L1024-L1033

Specifically, the per-group part is here:

https://github.com/rapidsai/cudf/blob/58799d698d861866b5650d368f5195174fc9644e/cpp/src/quantiles/tdigest/tdigest_aggregation.cu#L1055-L1086

If I run a Spark query like:

spark.time(spark.range(0, 1000000L, 1, 2).selectExpr("id % 500000 as k", "id").groupBy("k").agg(percentile_approx(col("id"), lit(0.95), lit(100000))).write...)

I can see the merge operator taking a massive amount of time and launching 500,000 kernels to merge the compacted items in the digest. We could skip all of this if we just had a segmented merge. We don't have one, but we do have a segmented sort, which is probably good enough given how long the current code takes to run:

https://github.com/rapidsai/cudf/blob/58799d698d861866b5650d368f5195174fc9644e/cpp/include/cudf/sorting.hpp#L258-L264
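As a sketch of what I mean, something shaped like the call below would keep the kernel count independent of the number of groups. Names like `centroid_means`, `centroid_weights`, and `group_offsets` are hypothetical placeholders for intermediates the tdigest code already materializes in some form; they are not existing identifiers in `tdigest_aggregation.cu`.

```cpp
// Hypothetical sketch: sort every group's centroids by mean in one call,
// instead of launching per-group merges. All names here are placeholders.
#include <cudf/sorting.hpp>
#include <cudf/table/table_view.hpp>

#include <memory>

std::unique_ptr<cudf::table> sort_centroids_per_group(
  cudf::column_view centroid_means,    // all groups' centroid means, concatenated
  cudf::column_view centroid_weights,  // matching centroid weights
  cudf::column_view group_offsets)     // offsets delimiting each group's centroids
{
  // A single segmented sort reorders the (mean, weight) pairs within each
  // group; the number of kernel launches no longer depends on the group count.
  return cudf::segmented_sort_by_key(
    cudf::table_view({centroid_means, centroid_weights}),  // values to reorder
    cudf::table_view({centroid_means}),                     // sort key: the centroid mean
    group_offsets,
    {cudf::order::ASCENDING});
}
```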

Steps/Code to reproduce bug

spark.time(spark.range(0, 1000000L, 1, 2).selectExpr("id % 500000 as k", "id").groupBy("k").agg(percentile_approx(col("id"), lit(0.95), lit(100000))).orderBy("k").show())

Sorry that this is for Spark, but it can be replicated in C++. You just want to do a group-by aggregation for TDIGEST_MERGE where there are a large number of output groups and only a few items to actually merge in each group. It is still bad if all of the values are unique, but in that case it does not launch any kernels to do the merge; it just does a few memcpy calls.
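For reference, here is a rough, untested libcudf sketch of that C++ repro, mirroring the Spark query above: it builds ~500,000 groups, produces two partial tdigests per group (as two Spark partitions would), and then runs the MERGE_TDIGEST group-by aggregation, which is where the per-group work shows up. Treat the parameters (row count, modulus, delta of 100,000) as stand-ins for the Spark ones.

```cpp
// Rough, untested repro sketch against the public libcudf API.
#include <cudf/aggregation.hpp>
#include <cudf/binaryop.hpp>
#include <cudf/concatenate.hpp>
#include <cudf/filling.hpp>
#include <cudf/groupby.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>

#include <vector>

int main()
{
  cudf::size_type const num_rows = 1'000'000;

  // id = 0 .. num_rows-1
  cudf::numeric_scalar<int64_t> init{0};
  cudf::numeric_scalar<int64_t> step{1};
  auto id = cudf::sequence(num_rows, init, step);

  // k = id % 500000  ->  ~500,000 distinct groups with ~2 rows each
  cudf::numeric_scalar<int64_t> divisor{500'000};
  auto k = cudf::binary_operation(id->view(), divisor, cudf::binary_operator::MOD,
                                  cudf::data_type{cudf::type_id::INT64});

  // Stage 1: build a tdigest per group (what a Spark partial aggregation produces).
  // Run it twice to mimic two partitions, so each key has two digests to merge.
  auto run_partial = [&](cudf::column_view vals) {
    cudf::groupby::groupby gb(cudf::table_view({k->view()}));
    std::vector<cudf::groupby::aggregation_request> reqs(1);
    reqs[0].values = vals;
    reqs[0].aggregations.push_back(
      cudf::make_tdigest_aggregation<cudf::groupby_aggregation>(100'000));
    return gb.aggregate(reqs);
  };
  auto [keys_a, res_a] = run_partial(id->view());
  auto [keys_b, res_b] = run_partial(id->view());

  auto merged_keys    = cudf::concatenate(std::vector<cudf::column_view>{
    keys_a->get_column(0).view(), keys_b->get_column(0).view()});
  auto merged_digests = cudf::concatenate(std::vector<cudf::column_view>{
    res_a[0].results[0]->view(), res_b[0].results[0]->view()});

  // Stage 2: MERGE_TDIGEST over ~500,000 output groups -- this is the path
  // that does per-group work inside tdigest_aggregation.cu.
  cudf::groupby::groupby gb2(cudf::table_view({merged_keys->view()}));
  std::vector<cudf::groupby::aggregation_request> reqs2(1);
  reqs2[0].values = merged_digests->view();
  reqs2[0].aggregations.push_back(
    cudf::make_merge_tdigest_aggregation<cudf::groupby_aggregation>(100'000));
  auto [out_keys, out_results] = gb2.aggregate(reqs2);

  return 0;
}
```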

Expected behavior

The GPU should destroy the CPU like it does for a reduction:

spark.time(spark.range(0, 50000000L, 1, 2).select(percentile_approx(col("id"), lit(0.95), lit(100000))).show())
nvdbaranec commented 2 months ago

I don't even think this needs to be a segmented sort (thrust doesn't have one in any case; the sort-by-key functionality is a bit of a misnomer). But I think if we just glommed everything into one big array and did a regular sort, we'd get the same effect. Off the top of my head, I don't know whether radix sort (what thrust uses internally) has any horrible performance problems when handed already-almost-sorted (or, in this case, already-sorted) inputs. Probably not.
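Roughly what I mean (untested, and all the variable names are made up): tag every centroid with its output-group index and sort the whole thing once. One caveat: with a tuple key like this, thrust falls back to a comparison sort as far as I know; staying on the radix path would mean packing (group, mean) into a single scalar key.

```cpp
// Untested sketch of the "one big sort" idea; all variable names are made up.
#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sort.h>
#include <thrust/tuple.h>

void sort_all_centroids(thrust::device_vector<int>&    group_ids,  // output-group index per centroid
                        thrust::device_vector<double>& means,      // centroid means
                        thrust::device_vector<double>& weights)    // centroid weights
{
  auto keys_begin =
    thrust::make_zip_iterator(thrust::make_tuple(group_ids.begin(), means.begin()));
  auto keys_end =
    thrust::make_zip_iterator(thrust::make_tuple(group_ids.end(), means.end()));

  // One sort over all centroids: afterwards centroids are contiguous per output
  // group and ordered by mean within each group, which is what the per-group
  // merge step was producing one group at a time.
  thrust::sort_by_key(keys_begin, keys_end, weights.begin());
}
```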