delta-io / delta-rs

A native Rust library for Delta Lake, with bindings into Python
https://delta-io.github.io/delta-rs/
Apache License 2.0

Large Memory Spike on Merge #2802

Closed: rob-harrison closed this issue 2 weeks ago

rob-harrison commented 3 weeks ago

Environment

Delta-rs version: 0.18.2

Binding: Python

Environment:


Bug

What happened:

  1. delta table with ±8 million rows, partitioned, with the largest partition at 800k rows; size ±550GiB according to describe detail.
  2. table optimized with .compact() and checkpoint every 100 commits.
  3. merge operation with 1k rows and a single partition in predicate: part_col IN ('123456789')
  4. memory spikes from ±200MiB to ±14GiB on .execute() - see memory_profiler output:
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
35    200.7 MiB    200.7 MiB           1       @profile
36                                             def _perform_merge(self, detections: List[Detection], class_blacklist: List[str] = None):
37                                                 # upsert detections to datastore
38    202.2 MiB      1.5 MiB           1           batch = build_pyarrow_batch(detections=detections, class_blacklist=class_blacklist)
39    202.2 MiB      0.0 MiB           1           logging.info(f'upserting batch of [{batch.num_rows}] detections to delta table')
40    202.2 MiB      0.0 MiB           1           if batch.num_rows > 0:
41    209.2 MiB      7.0 MiB           2               dt = get_delta_table(table_path=self.delta_table_path,
42    202.2 MiB      0.0 MiB           1                                    dynamo_table_name=self.dynamo_table_name)
43    209.2 MiB      0.0 MiB           1               h3_indices_column = batch.column('h3_indices')
44    209.2 MiB      0.0 MiB           1               partition_key_values = self._get_partition_keys(h3_indices=h3_indices_column.to_pylist())
45    209.2 MiB      0.0 MiB           1               merge_predicate = f'{TARGET_ALIAS}.{HD_MAP_FRAME_DETECTION_PARTITION_BY_COLUMN} in ({partition_key_values}) and {TARGET_ALIAS}.{HD_MAP_FRAME_DETECTION_KEY_COLUMN} = {SOURCE_ALIAS}.{HD_MAP_FRAME_DETECTION_KEY_COLUMN}'
46    209.2 MiB      0.0 MiB           1               update_predicate = f'{SOURCE_ALIAS}.{HD_MAP_FRAME_DETECTION_INFERENCE_AT_MS_COLUMN} >= {TARGET_ALIAS}.{HD_MAP_FRAME_DETECTION_INFERENCE_AT_MS_COLUMN}'
47  14466.7 MiB      0.0 MiB           1               metrics = (
48    209.2 MiB      0.0 MiB           2                   dt.merge(
49    209.2 MiB      0.0 MiB           1                       source=batch,
50    209.2 MiB      0.0 MiB           1                       predicate=merge_predicate,
51    209.2 MiB      0.0 MiB           1                       source_alias=SOURCE_ALIAS,
52    209.2 MiB      0.0 MiB           1                       target_alias=TARGET_ALIAS,
53    209.2 MiB      0.0 MiB           1                       large_dtypes=False
54                                                         )
55    209.2 MiB      0.0 MiB           1                   .when_matched_update_all(predicate=update_predicate)
56    209.2 MiB      0.0 MiB           1                   .when_not_matched_insert_all()
57  14466.7 MiB  14257.5 MiB           1                   .execute()
58                                                     )
59  14466.7 MiB      0.0 MiB           1               logging.info(f'merged with metrics {metrics}...')
60  14466.7 MiB      0.0 MiB           1               if dt.version() % OPTIMIZE_FREQUENCY == 0:
61                                                         try:
62                                                             self._optimize(dt)
63                                                         except Exception as e:
64                                                             logging.warning(f'error optimizing [{dt.table_uri}], will SKIP... [{e}]')
  5. Running on a pod with 16GiB memory, these spikes result in OOM (Screenshot 2024-08-20 at 15 47 41)

What you expected to happen: Memory to remain within reasonable limits

How to reproduce it:
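A minimal sketch of the pattern described in the steps above (table URI, schema, and values are illustrative placeholders, not the real ones from this deployment):

import pyarrow as pa
from deltalake import DeltaTable

# hypothetical source batch of detections to upsert
batch = pa.RecordBatch.from_pydict({
    "key": ["frame-0001"],
    "part_col": ["123456789"],            # partition column
    "inference_at_ms": [1724162148000],
})

dt = DeltaTable("s3://some-bucket/detections")  # placeholder table URI

metrics = (
    dt.merge(
        source=batch,
        # single partition value expressed via IN, as in step 3 above
        predicate="t.part_col in ('123456789') and t.key = s.key",
        source_alias="s",
        target_alias="t",
    )
    .when_matched_update_all(predicate="s.inference_at_ms >= t.inference_at_ms")
    .when_not_matched_insert_all()
    .execute()
)
print(metrics)
# memory spikes on .execute() even though only one partition is referenced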

More details:

rtyler commented 3 weeks ago

are you able to test with 0.19.0? That release contains a number of performance and memory improvements which will also benefit merge operations

ion-elgreco commented 3 weeks ago

Also if you still see issues with 0.19+, can you then use this branch and compile it: https://github.com/ion-elgreco/delta-rs/tree/debug/merge_explain

And then share the output that gets printed to stdout; I would like to see the plan with the execution stats.

rob-harrison commented 3 weeks ago

seeing the same spike with 0.19.1 I'm afraid

2024-08-20 13:55:48,872 INFO     [detection_dao.py:39] upserting batch of [889] detections to delta table
2024-08-20 13:55:49,940 INFO     [detection_dao.py:97] found 1 partition keys
2024-08-20 13:56:22,970 INFO     [detection_dao.py:59] merged with metrics {'num_source_rows': 889, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 889, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 38913, 'num_output_rows': 39802, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': 32323, 'scan_time_ms': 0, 'rewrite_time_ms': 32272}...

produces

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
35    125.9 MiB    125.9 MiB           1       @profile
36                                             def _perform_merge(self, detections: List[Detection], class_blacklist: List[str] = None):
37                                                 # upsert detections to datastore
38    137.0 MiB     11.1 MiB           1           batch = build_pyarrow_batch(detections=detections, class_blacklist=class_blacklist)
39    137.0 MiB      0.0 MiB           1           logging.info(f'upserting batch of [{batch.num_rows}] detections to delta table')
40    137.0 MiB      0.0 MiB           1           if batch.num_rows > 0:
41    166.3 MiB     29.3 MiB           2               dt = get_delta_table(table_path=self.delta_table_path,
42    137.0 MiB      0.0 MiB           1                                    dynamo_table_name=self.dynamo_table_name)
43    166.3 MiB      0.0 MiB           1               h3_indices_column = batch.column('h3_indices')
44    168.3 MiB      2.0 MiB           1               partition_key_values = self._get_partition_keys(h3_indices=h3_indices_column.to_pylist())
45    168.3 MiB      0.0 MiB           1               merge_predicate = f'{TARGET_ALIAS}.{HD_MAP_FRAME_DETECTION_PARTITION_BY_COLUMN} in ({partition_key_values}) and {TARGET_ALIAS}.{HD_MAP_FRAME_DETECTION_KEY_COLUMN} = {SOURCE_ALIAS}.{HD_MAP_FRAME_DETECTION_KEY_COLUMN}'
46    168.3 MiB      0.0 MiB           1               update_predicate = f'{SOURCE_ALIAS}.{HD_MAP_FRAME_DETECTION_INFERENCE_AT_MS_COLUMN} >= {TARGET_ALIAS}.{HD_MAP_FRAME_DETECTION_INFERENCE_AT_MS_COLUMN}'
47  12133.3 MiB      0.0 MiB           1               metrics = (
48    168.3 MiB      0.0 MiB           2                   dt.merge(
49    168.3 MiB      0.0 MiB           1                       source=batch,
50    168.3 MiB      0.0 MiB           1                       predicate=merge_predicate,
51    168.3 MiB      0.0 MiB           1                       source_alias=SOURCE_ALIAS,
52    168.3 MiB      0.0 MiB           1                       target_alias=TARGET_ALIAS,
53    168.3 MiB      0.0 MiB           1                       large_dtypes=False
54                                                         )
55    168.3 MiB      0.0 MiB           1                   .when_matched_update_all(predicate=update_predicate)
56    168.3 MiB      0.0 MiB           1                   .when_not_matched_insert_all()
57  12133.3 MiB  11965.0 MiB           1                   .execute()
58                                                     )
59  12133.3 MiB      0.0 MiB           1               logging.info(f'merged with metrics {metrics}...')
60  12133.3 MiB      0.0 MiB           1               if dt.version() % OPTIMIZE_FREQUENCY == 0:
61                                                         try:
62                                                             self._optimize(dt)
63                                                         except Exception as e:
64                                                             logging.warning(f'error optimizing [{dt.table_uri}], will SKIP... [{e}]')
Screenshot 2024-08-20 at 16 55 46

> Also if you still see issues with 0.19+, can you then use this branch and compile it

will do 🙏

ion-elgreco commented 3 weeks ago

@rob-harrison can you check the memory performance with this branch: https://github.com/ion-elgreco/delta-rs/tree/fix/set_greedy_mem_pool

rob-harrison commented 3 weeks ago

@ion-elgreco firstly - having not built Rust before - the Building Custom Wheel notes were super clear and easy to follow 👍

  1. set_greedy_mem_pool -> same spikes/no improvement. See yellow (0.19.1) vs green (greedy) vs orange (merge_explain)
  2. merge_explain output attached from a couple of merge iterations - please let me know if you need more 🙏 (attachments: merge-explain.txt, Screenshot 2024-08-20 at 20 05 16)

rtyler commented 3 weeks ago

Thanks for the detailed analysis @rob-harrison. Do you have an idea of what the working data set in memory for the merge might be, i.e. how many rows are being merged? There have been cases where the source/target data was simply too large for a merge to happen in memory with Python/Rust, and we had to drop out to Spark to do the job, since it has the facilities to spread that load across machines.

rob-harrison commented 3 weeks ago

@rtyler please see typical merge metrics below:

2024-08-20 13:55:48,872 INFO     [detection_dao.py:39] upserting batch of [889] detections to delta table
2024-08-20 13:55:49,940 INFO     [detection_dao.py:97] found 1 partition keys
2024-08-20 13:56:22,970 INFO     [detection_dao.py:59] merged with metrics {'num_source_rows': 889, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 889, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 38913, 'num_output_rows': 39802, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': 32323, 'scan_time_ms': 0, 'rewrite_time_ms': 32272}

We're talking between 1k-5k source rows max. The largest target partition is ±1 million rows, with an average of ±120k. I should mention there's also a z-order index on this table by (key, frame_key) - key is used in the merge predicate:

merge_predicate = f't.h3_id_res in ({partition_key_values}) and t.key = s.key'

It doesn't feel to me like we should be anywhere near the limits requiring a move to Spark. I'll be interested to hear what insights you guys gather from the explain log - are we loading too many (all) partitions?

rob-harrison commented 3 weeks ago

formatted_merge-1.txt

hd-map-detections-consumer-5f494d87c6-vxj6w hd-map-detections-consumer 2024-08-20 16:36:16,884 INFO     [detection_dao.py:59] merged with metrics {'num_source_rows': 1559, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 1559, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 1143273, 'num_output_rows': 1144832, 'num_target_files_added': 2, 'num_target_files_removed': 2, 'execution_time_ms': 78484, 'scan_time_ms': 0, 'rewrite_time_ms': 40094}...

going over the attached plan from the above merge and reading it backwards (correct?), the following seems apparent:

MetricObserverExec id=merge_target_count, metrics=[num_target_rows=7380466]

DeltaScan, metrics=[files_scanned=69, files_pruned=0]
  1. we start with a DeltaScan of the entire table (all 62 partitions with no pruning) and load the entire table to memory as target (7380466 rows)
MetricObserverExec id=merge_source_count, metrics=[num_source_rows=1559]
  2. we load the source to memory (1559 rows)
HashJoinExec: mode=Partitioned, join_type=Full, on=[(key@0, key@0)], filter=h3_id_res9@0 = 617723686386925567,
metrics=[output_rows=7380466, input_rows=7380466, build_input_rows=1559, output_batches=750, input_batches=750,
build_input_batches=2, build_mem_used=1381708, join_time=4.703425867s, build_time=401.443µs]
  3. we do a FULL hash join, resulting in (same) 7380466 rows and using 1381708 build_mem_used -> is that in k == 13GiB?
MetricObserverExec id=merge_output_count, metrics=[num_target_updated_rows=1559, num_target_inserted_rows=0, num_target_deleted_rows=0, num_copied_rows=1143273]

MergeBarrier, metrics=[]
  4. we end up going through the MergeBarrier and copying rows back to the partition (1143273)

If I'm reading the above correctly, the issue seems to stem from not pushing down the partition predicate to the initial delta scan.
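One way to sanity-check partition pruning outside of merge (a sketch; the column name h3_id_res9 and value are taken from the join filter in the plan above, and the table URI is a placeholder):

from deltalake import DeltaTable

dt = DeltaTable("s3://some-bucket/detections")  # placeholder table URI

# partition filter values are passed as strings; this should list only the
# files under the single h3_id_res9 partition, i.e. far fewer than the 69
# files the merge plan scanned
matching_files = dt.files(
    partition_filters=[("h3_id_res9", "=", "617723686386925567")]
)
print(len(matching_files), matching_files[:3])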

ion-elgreco commented 3 weeks ago

@rob-harrison thanks for sharing the explain output; indeed, files are not being pruned, and this is due to the IN predicate. Also reported here: https://github.com/delta-io/delta-rs/issues/2726

I will take a look at this!

rob-harrison commented 3 weeks ago

@ion-elgreco formatted_merge-2.txt

just tried changing the partition predicate from an IN list to a series of OR conditions - can confirm pushdown works! (and memory is as expected)
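For reference, the rewrite was along these lines (a sketch; the helper is hypothetical, and the column/alias names mirror the predicates quoted earlier - the plan shows the partition column as h3_id_res9):

from typing import List

def build_merge_predicate(partition_keys: List[str]) -> str:
    # instead of "t.h3_id_res9 in (<k1>, <k2>, ...)", emit a chain of
    # equality conditions that currently gets pushed down to the DeltaScan:
    #   (t.h3_id_res9 = <k1> OR t.h3_id_res9 = <k2> OR ...)
    partition_clause = " OR ".join(f"t.h3_id_res9 = {key}" for key in partition_keys)
    return f"({partition_clause}) and t.key = s.key"

print(build_merge_predicate(["617723686386925567"]))
# (t.h3_id_res9 = 617723686386925567) and t.key = s.key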

MetricObserverExec id=merge_target_count, metrics=[num_target_rows=572175]

DeltaScan, metrics=[files_pruned=73, files_scanned=2]

ion-elgreco commented 3 weeks ago

@rob-harrison yeah, generalize_filter() doesn't look at BETWEEN or IN predicates yet, but I'm working on that to address it - just need to get acquainted with this area of the code ^^

ion-elgreco commented 3 weeks ago

@rob-harrison I've pushed a PR, should land in 0.19.2 soonish

rob-harrison commented 3 weeks ago

many thanks @ion-elgreco 🙏