NVIDIA / spark-rapids

Spark RAPIDS plugin - accelerate Apache Spark with GPUs
https://nvidia.github.io/spark-rapids
Apache License 2.0

[FEA] Explore `CollectLimit` use in DPP queries on Databricks #11764

Open mythrocks opened 1 day ago

mythrocks commented 1 day ago

This is a follow-on issue from the work in #11750, which fixes the test failures in aqe_test.py pertaining to DPP on Databricks 14.3.

We found a condition where Databricks seems to insert a CollectLimit (with a very large collection limit) on a join query, with DPP enabled.

Here's the query (paraphrased from aqe_test.py::test_aqe_join_with_dpp):

WITH tmp AS
(SELECT
   site_id, day, test_id
 FROM infoA
 UNION ALL
 SELECT
   CASE
     WHEN site_id = 'LONG_SITE_NAME_0' THEN 'site_0'
     WHEN site_id = 'LONG_SITE_NAME_1' THEN 'site_1'
     ELSE site_id
   END AS site_id, day, test_spec AS test_id
 FROM infoB)
SELECT *
FROM tmp a JOIN tests b ON a.test_id = b.test_id AND a.site_id = b.site_id
WHERE day = '1990-01-01'
AND a.site_id IN ('site_0', 'site_1')

The resultant plan looks thus:

+- Project [toprettystring(site_id#6, Some(UTC)) AS toprettystring(site_id)#42, toprettystring(day#5, Some(UTC)) AS toprettystring(day)#43, toprettystring(test_id#7, Some(UTC)) AS toprettystring(test_id)#44, toprettystring(test_id#0, Some(UTC)) AS toprettystring(test_id)#45, toprettystring(site_id#1, Some(UTC)) AS toprettystring(site_id)#46]
   +- BroadcastHashJoin [test_id#7, site_id#6], [test_id#0, site_id#1], Inner, BuildRight, false, true
      :- Union
      :  :- Project [site_id#6, day#5, test_id#7]
      :  :  +- FileScan parquet [day#5,site_id#6,test_id#7] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/myth/PARQUET_DATA/infoA], PartitionFilters: [isnotnull(day#5), (day#5 = 1990-01-01), site_id#6 IN (site_0,site_1), isnotnull(test_id#7), isno..., PushedFilters: [], ReadSchema: struct<>
      :  :        :- SubqueryBroadcast dynamicpruning#53, [0], [test_id#0], true, [id=#322]
      :  :        :  +- AdaptiveSparkPlan isFinalPlan=false
      :  :        :     +- CollectLimit 3000001
      :  :        :        +- HashAggregate(keys=[test_id#0], functions=[], output=[test_id#0])
      :  :        :           +- ShuffleQueryStage 2, Statistics(sizeInBytes=942.0 B, rowCount=34, ColumnStat: N/A, isRuntime=true)
      :  :        :              +- ReusedExchange [test_id#0, site_id#1], GpuColumnarExchange gpusinglepartitioning$(), EXECUTOR_BROADCAST, [plan_id=238], [loreId=17]
      :  :        +- SubqueryBroadcast dynamicpruning#54, [0], [site_id#1], true, [id=#335]
      :  :           +- AdaptiveSparkPlan isFinalPlan=false
      :  :              +- CollectLimit 3000001
      :  :                 +- HashAggregate(keys=[site_id#1], functions=[], output=[site_id#1])
      :  :                    +- ShuffleQueryStage 4, Statistics(sizeInBytes=942.0 B, rowCount=34, ColumnStat: N/A, isRuntime=true)
      :  :                       +- ReusedExchange [test_id#0, site_id#1], GpuColumnarExchange gpusinglepartitioning$(), EXECUTOR_BROADCAST, [plan_id=238], [loreId=17]
      :  +- Project [CASE WHEN (site_id#15 = LONG_SITE_NAME_0) THEN site_0 WHEN (site_id#15 = LONG_SITE_NAME_1) THEN site_1 ELSE site_id#15 END AS site_id#30, day#14, test_spec#12 AS test_id#31]
      :     +- Filter isnotnull(test_spec#12)
      :        +- FileScan parquet [test_spec#12,day#14,site_id#15] Batched: true, DataFilters: [isnotnull(test_spec#12), dynamicpruningexpression(test_spec#12 IN dynamicpruning#53)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/myth/PARQUET_DATA/infoB], PartitionFilters: [isnotnull(day#14), (day#14 = 1990-01-01), CASE WHEN (site_id#15 = LONG_SITE_NAME_0) THEN site_0 ..., PushedFilters: [IsNotNull(test_spec)], ReadSchema: struct<test_spec:string>
      :              :- ReusedSubquery SubqueryBroadcast dynamicpruning#54, [0], [site_id#1], true, [id=#335]
      :              +- ReusedSubquery SubqueryBroadcast dynamicpruning#53, [0], [test_id#0], true, [id=#322]
      +- ShuffleQueryStage 0, Statistics(sizeInBytes=942.0 B, rowCount=34, ColumnStat: N/A, isRuntime=true)
         +- GpuColumnarExchange gpusinglepartitioning$(), EXECUTOR_BROADCAST, [plan_id=238], [loreId=17]
            +- GpuCoalesceBatches targetsize(1073741824), [loreId=16]
               +- GpuFilter ((gpuisnotnull(test_id#0) AND gpuisnotnull(site_id#1)) AND site_id#1 INSET site_0, site_1), [loreId=15]
                  +- GpuFileGpuScan parquet [test_id#0,site_id#1] Batched: true, DataFilters: [isnotnull(test_id#0), isnotnull(site_id#1), site_id#1 IN (site_0,site_1)], Format: Parquet, Location: InMemoryFileIndex[file:/tmp/myth/PARQUET_DATA/tests], PartitionFilters: [], PushedFilters: [IsNotNull(test_id), IsNotNull(site_id), In(site_id, [site_0,site_1])], ReadSchema: struct<test_id:string,site_id:string>

Note the CollectLimit 3000001 node under each SubqueryBroadcast. It appears to be the result of pruning (sub)partitions as part of DPP.

It would be good to understand this optimization better, and the expected behaviour if there were more records/partitions than the limit.

revans2 commented 1 day ago

I really want to understand a few things about this feature. The main one is that it looks like we might produce an incorrect answer if the aggregation returns more results than the limit allows. Essentially, we have a table that is small enough to be turned into a broadcast join. That broadcast is then turned into a single-partition task through the AQE transition code. It goes through a hash aggregate to get the distinct values, those distinct values are then limited, and finally they are sent to the DPP filter. If that CollectLimit removed some values before the result was turned into a new broadcast, then DPP would produce the wrong answer, because we are broadcasting an allow list. If the list is missing a value that should be allowed, we will not read everything we should.
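To make the concern concrete, here is a minimal pure-Python sketch of an allow list that gets silently truncated. It is not Spark or spark-rapids code: the helper names (`distinct_keys`, `dpp_scan`, `join`) and the sample rows are hypothetical, and the limit is shrunk to 2 so the effect is visible.

```python
# Toy model of the concern: DPP broadcasts an allow list of join keys.
# All names here are hypothetical; this is not Spark or spark-rapids code.

def distinct_keys(build_rows, limit=None):
    """Distinct join keys from the build side, optionally truncated."""
    keys = sorted({row["test_id"] for row in build_rows})
    return keys if limit is None else keys[:limit]

def dpp_scan(probe_rows, allow_list):
    """Keep only probe rows whose key is in the allow list."""
    allowed = set(allow_list)
    return [row for row in probe_rows if row["test_id"] in allowed]

def join(probe_rows, build_rows):
    """Inner join reduced to 'probe rows that have a matching build key'."""
    build_keys = {row["test_id"] for row in build_rows}
    return [row for row in probe_rows if row["test_id"] in build_keys]

build = [{"test_id": k} for k in ("t0", "t1", "t2")]
probe = [{"test_id": k} for k in ("t0", "t1", "t2", "t9")]

# Pruning with the complete allow list does not change the join result.
full = join(dpp_scan(probe, distinct_keys(build)), build)

# Pruning with a truncated allow list drops a row that should have joined.
truncated = join(dpp_scan(probe, distinct_keys(build, limit=2)), build)

print(len(full))       # 3
print(len(truncated))  # 2, the row for "t2" was pruned by mistake
```

The point of the sketch is only that a truncated allow list is unsafe unless something downstream detects the truncation.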

I suspect that the code that reads the data from the collect is looking at how many rows are returned, and if there are too many rows, it might not apply a DPP filter at all. (Note that they are probably asking for a limit of 3,000,000 rows, and CollectLimit itself is adding the 1 extra row so that it can tell whether the rows were cut off.) It does not appear that AQE is looking at the number of rows in the broadcast: that is only 34, so AQE would know that the limit is not needed, or that it might violate something. I also don't think that the code knows the number of partitions in the table and is limiting things that way, as there are nowhere near 3 million partitions in this test. But who knows.
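A minimal sketch of the suspected "limit + 1" check, in pure Python. This is a guess at the pattern, not Databricks' actual logic: `build_pruning_filter`, `scan`, and `max_keys` are hypothetical names, and the sketch simply requests `max_keys + 1` rows without claiming where the extra row is really added.

```python
# Sketch of the suspected overflow check. Hypothetical names throughout;
# the real Databricks behaviour is exactly what this issue wants to learn.

def collect_limit(rows, n):
    """Return at most n rows, like a CollectLimit n."""
    return rows[:n]

def build_pruning_filter(distinct_keys, max_keys):
    """Collect max_keys + 1 rows; fall back to no pruning on overflow."""
    collected = collect_limit(distinct_keys, max_keys + 1)
    if len(collected) > max_keys:
        return None  # too many keys: skip DPP rather than drop rows
    return set(collected)

def scan(rows, pruning_filter):
    """Apply the pruning filter if there is one, else read everything."""
    if pruning_filter is None:
        return list(rows)
    return [r for r in rows if r in pruning_filter]

probe = ["a", "b", "c", "d"]

under_limit = scan(probe, build_pruning_filter(["a", "b"], max_keys=3))
over_limit = scan(probe, build_pruning_filter(["a", "b", "c", "d"], max_keys=3))

print(under_limit)  # ['a', 'b']
print(over_limit)   # ['a', 'b', 'c', 'd']
```

If the real code behaves this way, exceeding the limit would cost performance (no pruning) but not correctness.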

I just want to understand it so I know if we need to fix some bugs somewhere to deal with it.

The other question I have is really about performance, and whether it is worth us trying to optimize CollectLimitExec. We need to convert the data back to rows and send it to the CPU, so I don't want to write our own CollectLimitExec, but we could insert a GpuLocalLimit in front of it.
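As a rough illustration of the local-limit idea, the pure-Python sketch below caps each partition at n rows before the final collect. It is only a model under stated assumptions: `local_limit` and `collect_limit` are hypothetical stand-ins for a per-partition limit and the collecting limit, not the plugin's operators.

```python
from itertools import chain, islice

# Hypothetical stand-ins: a per-partition limit followed by a global
# collecting limit. This models the idea only, not the plugin's operators.

def local_limit(partition, n):
    """Keep at most n rows of one partition."""
    return list(islice(partition, n))

def collect_limit(partitions, n):
    """Concatenate partitions in order and keep the first n rows."""
    return list(islice(chain.from_iterable(partitions), n))

partitions = [[1, 2, 3, 4], [5, 6], [7, 8, 9]]
n = 3

without_local = collect_limit(partitions, n)
limited = [local_limit(p, n) for p in partitions]
with_local = collect_limit(limited, n)

rows_before = sum(len(p) for p in partitions)  # rows reaching the collect
rows_after = sum(len(p) for p in limited)      # rows after the local limit

print(without_local == with_local)  # True
print(rows_before, rows_after)      # 9 8
```

The result is unchanged because the first n rows of the concatenation can only come from the first n rows of each partition. What shrinks is the number of rows that would need converting to CPU rows. With a limit as large as 3000001 and only 34 rows in this test, the saving here would be nil, which is part of the question of whether it is worth doing.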