mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Criteo jax OOMs on host #357

Closed priyakasimbeg closed 1 year ago

priyakasimbeg commented 1 year ago

Criteo jax runs out of host memory after about 5K - 7K steps.

Description

Total RAM on the machine: 250GB

Traceback:

I0318 07:48:39.290826 139788360083264 checkpoints.py:356] Saving checkpoint at step: 6925
I0318 07:49:18.171090 139788360083264 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing/criteo1tb_jax/trial_1/checkpoint_6925
I0318 07:49:18.647214 139788360083264 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing/criteo1tb_jax/trial_1/checkpoint_6925.
I0318 07:50:22.682096 139483971307264 logging_writer.py:48] [7000] global_step=7000, grad_norm=0.0072265081107616425, loss=0.12346725910902023
I0318 07:52:55.673196 139483434436352 logging_writer.py:48] [7100] global_step=7100, grad_norm=0.0090188542380929, loss=0.12465228140354156
I0318 07:58:27.361398 139788360083264 spec.py:298] Evaluating on the training split.
I0318 08:09:34.849290 139788360083264 spec.py:310] Evaluating on the validation split.
I0318 08:13:38.652261 139788360083264 spec.py:326] Evaluating on the test split.
I0318 08:16:59.559698 139788360083264 submission_runner.py:362] Time since start: 28759.92s,    Step: 7124,     {'train/loss': 0.12170044394100414, 'validation/loss': 0.12363606741573034, 'validation/num_examples': 89000000, 'test/loss': 0.1263589119942319, 'test/num_examples': 89274637}
I0318 08:16:59.582683 139483971307264 logging_writer.py:48] [7124] global_step=7124, preemption_count=0, score=9349.471796, test/loss=0.126359, test/num_examples=89274637, total_duration=28759.921956, train/loss=0.121700, validation/loss=0.123636, validation/num_examples=89000000
I0318 08:17:05.913339 139788360083264 checkpoints.py:356] Saving checkpoint at step: 7124
Traceback (most recent call last):
  File "submission_runner.py", line 602, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 575, in main
    score = score_submission_on_workload(workload,
  File "submission_runner.py", line 510, in score_submission_on_workload
    timing, metrics = train_once(workload, global_batch_size,
  File "submission_runner.py", line 377, in train_once
    checkpoint_utils.save_checkpoint(
  File "/algorithmic-efficiency/algorithmic_efficiency/checkpoint_utils.py", line 225, in save_checkpoint
    flax_checkpoints.save_checkpoint(
  File "/usr/local/lib/python3.8/dist-packages/flax/training/checkpoints.py", line 366, in save_checkpoint
    target = serialization.msgpack_serialize(target)
  File "/usr/local/lib/python3.8/dist-packages/flax/serialization.py", line 334, in msgpack_serialize
    return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
  File "/usr/local/lib/python3.8/dist-packages/msgpack/__init__.py", line 38, in packb
    return Packer(**kwargs).pack(o)
  File "msgpack/_packer.pyx", line 294, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 304, in msgpack._cmsgpack.Packer.pack

Steps to Reproduce

docker run -t -d \
-v /home/kasimbeg/data/:/data/ \
-v /home/kasimbeg/experiment_runs/:/experiment_runs \
-v /home/kasimbeg/experiment_runs/logs:/logs \
--gpus all --ipc=host \
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/base_image \
-d criteo1tb \
-f jax \
-s reference_algorithms/target_setting_algorithms/jax_nadamw.py \
-w criteo1tb \
-t reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json \
-e timing \
-m 10666 \
-b true

Source or Possible Fix

The RAM usage at the first step is already quite high (~150GB, compared to ~100GB for the PyTorch workload). So far we've ruled out the checkpointing function as the cause by manually deleting variables inside the checkpointing method. I noticed that RAM usage increases significantly just from initializing and iterating over the dataset: about 50GB from initialization, then about 20GB per 1000 batches. Interestingly, the PyTorch workload does not OOM on host; for comparison, over the course of a run its RAM use increases from 100GB (after initialization etc.) to 150GB.
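
One way to take per-phase host RAM readings like the `RAM USED (GB)` lines in this thread is a small helper that reads `/proc/meminfo`. This is a minimal sketch assuming a Linux host; treating "used" as `MemTotal - MemAvailable`, and the helper names `parse_meminfo`, `used_ram_gb` and `log_ram`, are assumptions and not necessarily how the numbers in this issue were produced.

```python
def parse_meminfo(text: str) -> dict:
    """Parse /proc/meminfo-style text into a {field: bytes} dict."""
    info = {}
    for line in text.splitlines():
        if ":" not in line:
            continue
        key, rest = line.split(":", 1)
        parts = rest.split()
        if not parts:
            continue
        value = int(parts[0])
        # Most /proc/meminfo fields are reported in kB (kibibytes).
        if len(parts) > 1 and parts[1] == "kB":
            value *= 1024
        info[key.strip()] = value
    return info


def used_ram_gb(meminfo_text: str) -> float:
    """Host RAM in use, in decimal GB (1e9 bytes), assumed to be MemTotal - MemAvailable."""
    info = parse_meminfo(meminfo_text)
    return (info["MemTotal"] - info["MemAvailable"]) / 1e9


def log_ram(tag: str, path: str = "/proc/meminfo") -> float:
    """Print and return the current host RAM usage (Linux only)."""
    with open(path) as f:
        gb = used_ram_gb(f.read())
    print(f"{tag}: RAM USED (GB) {gb}")
    return gb
```

Calling, for example, `log_ram("After Initializing dataset")` right after dataset initialization and then every N batches should make dataset-driven growth like the ~20GB per 1000 batches described above visible without a full training run.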

priyakasimbeg commented 1 year ago

I profiled the JAX and PyTorch workloads without Docker for comparison. Memory increases between evals, but mostly during evals, when the model is evaluated on the train_eval, validation and test sets. For both workloads the biggest jump happens at the first eval call (~100GB for JAX vs. ~50GB for PyTorch). The next-largest jumps also happen during evals (initially ~15GB for JAX vs. ~5GB for PyTorch), and these increments seem to shrink slightly over time.

Jax:

I0321 20:50:12.885326 140324442691392 submission_runner.py:231] Starting train once: RAM USED (GB) 3.931410432                                                                     
I0321 20:50:12.885688 140324442691392 submission_runner.py:240] After Initializing dataset: RAM USED (GB) 3.931410432                                                              
I0321 20:50:18.567030 140324442691392 submission_runner.py:252] After Initializing model: RAM USED (GB) 7.714209792                                                                
I0321 20:50:21.465607 140324442691392 submission_runner.py:261] After Initializing metrics bundle: RAM USED (GB) 7.714996224                                                       
I0321 20:50:22.133624 140324442691392 submission_runner.py:312] Before starting training loop and logger metrics bundle: RAM USED (GB) 7.751073792                                 
I0321 20:52:38.932373 140324442691392 submission_runner.py:334] After dataselection batch at step 0: RAM USED (GB) 45.78729984                                                     
I0321 20:52:59.151939 140324442691392 submission_runner.py:372] Before eval at step 1: RAM USED (GB) 53.63308544                                                                                                                                                                                         
I0321 21:10:50.153927 140324442691392 submission_runner.py:391] After eval at step 1: RAM USED (GB) 147.673493504                                                                  
I0321 21:20:33.790335 140324442691392 submission_runner.py:372] Before eval at step 487: RAM USED (GB) 153.171914752                                                               
I0321 21:38:35.544191 140324442691392 submission_runner.py:391] After eval at step 487: RAM USED (GB) 170.46554624                                                                
I0321 21:48:17.558385 140324442691392 submission_runner.py:372] Before eval at step 967: RAM USED (GB) 173.58379008                                                               
I0321 22:05:10.913231 140324442691392 submission_runner.py:391] After eval at step 967: RAM USED (GB) 184.060829696                                                                                                    
I0321 22:14:55.042828 140324442691392 submission_runner.py:372] Before eval at step 1449: RAM USED (GB) 183.876345856

PyTorch:

I0321 21:36:13.839403 139839682864960 submission_runner.py:230] Starting train once: RAM USED (GB) 5.049655296                                                                     
I0321 21:36:13.839647 139839682864960 submission_runner.py:239] After Initializing dataset: RAM USED (GB) 5.049954304 
I0321 21:36:26.142259 139839682864960 submission_runner.py:251] After Initializing model: RAM USED (GB) 14.66904576                                                                                                                             
I0321 21:36:26.143074 139839682864960 submission_runner.py:260] After Initializing metrics bundle: RAM USED (GB) 14.66904576                                                       
I0321 21:38:38.603623 139839682864960 submission_runner.py:333] After dataselection batch at step 0: RAM USED (GB) 53.544239104                                                    
I0321 21:38:40.801852 139839682864960 submission_runner.py:371] Before eval at step 1: RAM USED (GB) 57.1206656                                                                    
I0321 21:57:13.853280 139839682864960 submission_runner.py:390] After eval at step 1: RAM USED (GB) 102.980071424                                                                                                                              
I0321 22:06:23.800114 139839682864960 submission_runner.py:371] Before eval at step 467: RAM USED (GB) 112.054808576
I0321 22:26:06.325435 139839682864960 submission_runner.py:390] After eval at step 467: RAM USED (GB) 116.150771712
I0321 22:35:20.300642 139839682864960 submission_runner.py:371] Before eval at step 909: RAM USED (GB) 117.905657856
I0321 22:54:42.722589 139839682864960 submission_runner.py:390] After eval at step 909: RAM USED (GB) 120.790904832
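
To quantify how much RAM each eval adds, the `Before eval` / `After eval` lines above can be paired by step. This is a minimal sketch assuming the log format shown in this thread; `eval_ram_deltas` is a hypothetical helper, not part of the repository.

```python
import re

# Matches lines such as
#   "... submission_runner.py:372] Before eval at step 487: RAM USED (GB) 153.171914752"
EVAL_RE = re.compile(r"(Before|After) eval at step (\d+): RAM USED \(GB\) ([\d.]+)")


def eval_ram_deltas(log_text: str) -> dict:
    """Return {step: GB added during that eval} for steps with both readings."""
    before, after = {}, {}
    for match in EVAL_RE.finditer(log_text):
        phase, step, gb = match.group(1), int(match.group(2)), float(match.group(3))
        (before if phase == "Before" else after)[step] = gb
    # Steps that only have a "Before" reading (e.g. an eval still running) are skipped.
    return {step: after[step] - before[step] for step in sorted(before) if step in after}
```

On the JAX log above this gives about 94GB for the first eval and about 17GB and 10GB for the evals at steps 487 and 967; on the PyTorch log it gives about 46GB, 4GB and 3GB.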

priyakasimbeg commented 1 year ago

Reducing the JAX eval batch size so that it matches the PyTorch eval batch size seems to help close the gap:

I0321 23:54:21.646273 139729336072000 submission_runner.py:299] Saving flags to /home/kasimbeg/experiments/debug_criteo_small_eval_batch/criteo1tb_jax/trial_1/flags_0.json.
I0321 23:54:21.674989 139729336072000 submission_runner.py:304] After checkpoint and logger metrics bundle: RAM USED (GB) 7.71524608
I0321 23:54:21.675219 139729336072000 submission_runner.py:311] Before starting training loop and logger metrics bundle: RAM USED (GB) 7.71524608
I0321 23:54:21.675282 139729336072000 submission_runner.py:312] Starting training loop.
I0321 23:57:01.052532 139729336072000 submission_runner.py:333] After dataselection batch at step 0: RAM USED (GB) 45.890441216
I0321 23:57:21.980293 139525796914944 logging_writer.py:48] [0] global_step=0, grad_norm=7.941814422607422, loss=0.7395287752151489
I0321 23:57:21.994342 139729336072000 submission_runner.py:350] After update parameters step 0: RAM USED (GB) 53.810757632
I0321 23:57:21.994553 139729336072000 submission_runner.py:371] Before eval at step 1: RAM USED (GB) 53.810757632
I0321 23:57:21.994624 139729336072000 spec.py:298] Evaluating on the training split.
I0322 00:06:08.399080 139729336072000 spec.py:310] Evaluating on the validation split.
I0322 00:10:16.798121 139729336072000 spec.py:326] Evaluating on the test split.
I0322 00:14:07.295299 139729336072000 submission_runner.py:380] Time since start: 180.32s,      Step: 1,        {'train/loss': 0.7395033004794426, 'validation/loss': 0.7445032359550562, 'validation/num_examples': 89000000, 'test/loss': 0.7416469248707223, 'test/num_examples': 89274637}
I0322 00:14:07.295766 139729336072000 submission_runner.py:390] After eval at step 1: RAM USED (GB) 95.224762368
priyakasimbeg commented 1 year ago

The interleave, shuffle and prefetch TF operations in the pipeline use RAM. I reduced the number of prefetch batches to 2 and also tried setting it to AUTOTUNE, but neither slowed down the memory growth.

priyakasimbeg commented 1 year ago

Using tcmalloc instead of the default glibc malloc (see the TF issue) seems to eliminate the memory growth in a standalone data iteration test:

[Image: RAM usage during the standalone data iteration test]

Will verify with a full run.

Note: to use tcmalloc, install:

sudo apt-get install libtcmalloc-minimal4

Then run program with:

LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4 python ...

The libtcmalloc path passed via LD_PRELOAD may be different on your VM.
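
Because the library path varies between machines, a small launcher can search a few common locations and prepend the library to `LD_PRELOAD` before starting the run. This is a hypothetical helper, not part of the repository; the candidate directories in `DEFAULT_DIRS` are assumptions to adjust for your VM.

```python
import glob
import os
import subprocess

# Assumed search locations for the shared library; adjust for your distribution/VM.
DEFAULT_DIRS = ("/usr/lib/x86_64-linux-gnu", "/usr/lib64", "/usr/lib")


def find_tcmalloc(search_dirs=DEFAULT_DIRS):
    """Return the first libtcmalloc_minimal shared object found, or None."""
    for directory in search_dirs:
        pattern = os.path.join(directory, "libtcmalloc_minimal.so*")
        matches = sorted(glob.glob(pattern))
        if matches:
            return matches[0]
    return None


def env_with_preload(lib_path, base_env=None):
    """Return a copy of the environment with lib_path prepended to LD_PRELOAD."""
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("LD_PRELOAD", "")
    # The dynamic loader accepts a colon- or space-separated list in LD_PRELOAD.
    env["LD_PRELOAD"] = f"{lib_path}:{existing}" if existing else lib_path
    return env


def run_with_tcmalloc(cmd, search_dirs=DEFAULT_DIRS):
    """Run cmd (a list of arguments) with tcmalloc preloaded; return its exit code."""
    lib = find_tcmalloc(search_dirs)
    if lib is None:
        raise FileNotFoundError(
            "libtcmalloc_minimal not found; try: sudo apt-get install libtcmalloc-minimal4")
    return subprocess.call(cmd, env=env_with_preload(lib))
```

For example, `run_with_tcmalloc(["python", "submission_runner.py"] + flags)`, where `flags` is your usual list of arguments. The allocator is set in the environment of a new child process because `LD_PRELOAD` is read by the dynamic loader at process startup; setting it from inside an already-running Python interpreter has no effect on that interpreter's allocator.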

priyakasimbeg commented 1 year ago

Tested on a full run: memory usage plateaued at ~122GB around step 400 and stayed at about ~122GB for the rest of the run:

Before eval train evaluation at step 8000: RAM USED (GB) 122.759159808
Before iteration over eval_train split for eval at step 8000 over 170: RAM USED (GB) 122.759159808
After iteration over eval_train split for eval at step 8000 over 170: RAM USED (GB) 122.744086528
Before eval train at step 8000: RAM USED (GB) 122.744086528
Before valid evaluation at step 8000: RAM USED (GB) 122.744086528
Before iteration over validation split for eval at step 8000 over 170: RAM USED (GB) 122.744086528
After iteration over validation split for eval at step 8000 over 170: RAM USED (GB) 122.706014208
After valid evaluation at step 8000: RAM USED (GB) 122.706014208
Before test evaluaton at step 8000: RAM USED (GB) 122.706014208
Before iteration over test split for eval at step 8000 over 171: RAM USED (GB) 122.706014208
After iteration over test split for eval at step 8000 over 171: RAM USED (GB) 122.811183104
After test evaluaton at step 8000: RAM USED (GB) 122.811183104

Action items:

priyakasimbeg commented 1 year ago

Fixed in https://github.com/mlcommons/algorithmic-efficiency/pull/376