I profiled the JAX and PyTorch workloads without Docker for comparison. Memory seems to increase between evals, but mostly during evals, when the model is evaluated on the train_eval, validation, and test splits. The biggest jump for both workloads happens at the first eval (~100 GB for JAX vs ~50 GB for PyTorch). The next-biggest jumps happen during subsequent evals (initially ~15 GB for JAX vs ~5 GB for PyTorch), and these seem to decrease slightly over time.
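For reference, the "RAM USED (GB)" lines below could be produced by a small psutil-based helper along these lines; this is only a sketch and may not match the exact logging code in submission_runner.py:
# Sketch of a RAM-logging helper; the actual call in submission_runner.py may differ.
import logging
import psutil

def log_ram_used(msg: str) -> None:
  used_gb = psutil.virtual_memory().used / 1e9  # system-wide bytes used -> GB
  logging.info('%s: RAM USED (GB) %s', msg, used_gb)

log_ram_used('Before eval at step 1')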
JAX:
I0321 20:50:12.885326 140324442691392 submission_runner.py:231] Starting train once: RAM USED (GB) 3.931410432
I0321 20:50:12.885688 140324442691392 submission_runner.py:240] After Initializing dataset: RAM USED (GB) 3.931410432
I0321 20:50:18.567030 140324442691392 submission_runner.py:252] After Initializing model: RAM USED (GB) 7.714209792
I0321 20:50:21.465607 140324442691392 submission_runner.py:261] After Initializing metrics bundle: RAM USED (GB) 7.714996224
I0321 20:50:22.133624 140324442691392 submission_runner.py:312] Before starting training loop and logger metrics bundle: RAM USED (GB) 7.751073792
I0321 20:52:38.932373 140324442691392 submission_runner.py:334] After dataselection batch at step 0: RAM USED (GB) 45.78729984
I0321 20:52:59.151939 140324442691392 submission_runner.py:372] Before eval at step 1: RAM USED (GB) 53.63308544
I0321 21:10:50.153927 140324442691392 submission_runner.py:391] After eval at step 1: RAM USED (GB) 147.673493504
I0321 21:20:33.790335 140324442691392 submission_runner.py:372] Before eval at step 487: RAM USED (GB) 153.171914752
I0321 21:38:35.544191 140324442691392 submission_runner.py:391] After eval at step 487: RAM USED (GB) 170.46554624
I0321 21:48:17.558385 140324442691392 submission_runner.py:372] Before eval at step 967: RAM USED (GB) 173.58379008
I0321 22:05:10.913231 140324442691392 submission_runner.py:391] After eval at step 967: RAM USED (GB) 184.060829696
I0321 22:14:55.042828 140324442691392 submission_runner.py:372] Before eval at step 1449: RAM USED (GB) 183.876345856
PyTorch:
I0321 21:36:13.839403 139839682864960 submission_runner.py:230] Starting train once: RAM USED (GB) 5.049655296
I0321 21:36:13.839647 139839682864960 submission_runner.py:239] After Initializing dataset: RAM USED (GB) 5.049954304
I0321 21:36:26.142259 139839682864960 submission_runner.py:251] After Initializing model: RAM USED (GB) 14.66904576
I0321 21:36:26.143074 139839682864960 submission_runner.py:260] After Initializing metrics bundle: RAM USED (GB) 14.66904576
I0321 21:38:38.603623 139839682864960 submission_runner.py:333] After dataselection batch at step 0: RAM USED (GB) 53.544239104
I0321 21:38:40.801852 139839682864960 submission_runner.py:371] Before eval at step 1: RAM USED (GB) 57.1206656
I0321 21:57:13.853280 139839682864960 submission_runner.py:390] After eval at step 1: RAM USED (GB) 102.980071424
I0321 22:06:23.800114 139839682864960 submission_runner.py:371] Before eval at step 467: RAM USED (GB) 112.054808576
I0321 22:26:06.325435 139839682864960 submission_runner.py:390] After eval at step 467: RAM USED (GB) 116.150771712
I0321 22:35:20.300642 139839682864960 submission_runner.py:371] Before eval at step 909: RAM USED (GB) 117.905657856
I0321 22:54:42.722589 139839682864960 submission_runner.py:390] After eval at step 909: RAM USED (GB) 120.790904832
Reducing the JAX eval batch size so that it equals the PyTorch eval batch size seems to help close the gap:
I0321 23:54:21.646273 139729336072000 submission_runner.py:299] Saving flags to /home/kasimbeg/experiments/debug_criteo_small_eval_batch/criteo1tb_jax/trial_1/flags_0.json.
I0321 23:54:21.674989 139729336072000 submission_runner.py:304] After checkpoint and logger metrics bundle: RAM USED (GB) 7.71524608
I0321 23:54:21.675219 139729336072000 submission_runner.py:311] Before starting training loop and logger metrics bundle: RAM USED (GB) 7.71524608
I0321 23:54:21.675282 139729336072000 submission_runner.py:312] Starting training loop.
I0321 23:57:01.052532 139729336072000 submission_runner.py:333] After dataselection batch at step 0: RAM USED (GB) 45.890441216
I0321 23:57:21.980293 139525796914944 logging_writer.py:48] [0] global_step=0, grad_norm=7.941814422607422, loss=0.7395287752151489
I0321 23:57:21.994342 139729336072000 submission_runner.py:350] After update parameters step 0: RAM USED (GB) 53.810757632
I0321 23:57:21.994553 139729336072000 submission_runner.py:371] Before eval at step 1: RAM USED (GB) 53.810757632
I0321 23:57:21.994624 139729336072000 spec.py:298] Evaluating on the training split.
I0322 00:06:08.399080 139729336072000 spec.py:310] Evaluating on the validation split.
I0322 00:10:16.798121 139729336072000 spec.py:326] Evaluating on the test split.
I0322 00:14:07.295299 139729336072000 submission_runner.py:380] Time since start: 180.32s, Step: 1, {'train/loss': 0.7395033004794426, 'validation/loss': 0.7445032359550562, 'validation/num_examples': 89000000, 'test/loss': 0.7416469248707223, 'test/num_examples': 89274637}
I0322 00:14:07.295766 139729336072000 submission_runner.py:390] After eval at step 1: RAM USED (GB) 95.224762368
The interleave, shuffle, and prefetch tf.data operations in the input pipeline use host RAM. I reduced the number of prefetched batches to 2 and also tried setting it to AUTOTUNE, but neither slowed down the memory growth.
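For context, a rough sketch of the kind of tf.data pipeline involved; the file pattern, parser, and batch size are placeholders, and the buffered stages (interleave, shuffle, prefetch) are where host RAM goes:
# Rough sketch of a Criteo-style tf.data pipeline; file pattern, parse_example,
# and batch_size are placeholders. The buffered stages hold data in host RAM.
import tensorflow as tf

def parse_example(line):
  return line  # placeholder for the actual Criteo line parser

batch_size = 262_144  # placeholder

files = tf.data.Dataset.list_files('/data/criteo/train-*', shuffle=True)
ds = files.interleave(
    tf.data.TextLineDataset,
    cycle_length=16,                      # files read in parallel, each buffered
    num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.shuffle(buffer_size=10_000)       # shuffle buffer lives in host RAM
ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(batch_size)
ds = ds.prefetch(2)                       # tried 2 and tf.data.AUTOTUNE here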
Using tcmalloc instead of the default malloc (see the related TF issue) seems to eliminate the memory growth in a standalone data-iteration test. Will verify with a full run.
Note: to use tcmalloc, install it with:
sudo apt-get install libtcmalloc-minimal4
Then run the program with:
LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4 python ...
The libtcmalloc path passed to LD_PRELOAD may differ on your VM.
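To sanity-check that tcmalloc was actually preloaded into the Python process, something like this can be run inside the program (a small sketch, Linux-only):
# Sketch: check /proc/self/maps for libtcmalloc in the current process.
with open('/proc/self/maps') as maps:
    print('tcmalloc loaded:', any('libtcmalloc' in line for line in maps))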
Tested on a full run: the memory usage flattened out at around step 400 at ~122 GB and stayed there for the rest of the run:
Before eval train evaluation at step 8000: RAM USED (GB) 122.759159808
Before iteration over eval_train split for eval at step 8000 over 170: RAM USED (GB) 122.759159808
After iteration over eval_train split for eval at step 8000 over 170: RAM USED (GB) 122.744086528
Before eval train at step 8000: RAM USED (GB) 122.744086528
Before valid evaluation at step 8000: RAM USED (GB) 122.744086528
Before iteration over validation split for eval at step 8000 over 170: RAM USED (GB) 122.744086528
After iteration over validation split for eval at step 8000 over 170: RAM USED (GB) 122.706014208
After valid evaluation at step 8000: RAM USED (GB) 122.706014208
Before test evaluation at step 8000: RAM USED (GB) 122.706014208
Before iteration over test split for eval at step 8000 over 171: RAM USED (GB) 122.706014208
After iteration over test split for eval at step 8000 over 171: RAM USED (GB) 122.811183104
After test evaluation at step 8000: RAM USED (GB) 122.811183104
Action items:
Criteo JAX runs out of host memory after about 5K-7K steps.
Description
Total RAM on the machine: 250GB
Traceback:
Steps to Reproduce
Source or Possible Fix
The RAM usage at the first step is already quite high (~150 GB, compared to ~100 GB for the PyTorch workload). So far we've ruled out the checkpointing function by manually deleting variables inside the checkpointing method. I noticed that RAM usage increases significantly just from initializing and iterating over the dataset: about 50 GB from initialization, and then at a rate of roughly 20 GB per 1000 batches. Interestingly, the PyTorch workload does not OOM on the host; for comparison, over the course of a run its RAM use increases from ~100 GB (after initialization, etc.) to ~150 GB.
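A minimal sketch of the kind of standalone data-iteration test mentioned above, assuming the workload's input iterator can be built independently of training; iterate_and_report and input_iterator are placeholder names:
# Sketch of a standalone data-iteration memory test; pulls batches without
# touching the model and reports process RSS every log_every steps.
import psutil

def iterate_and_report(input_iterator, num_batches=5000, log_every=1000):
  proc = psutil.Process()
  for step in range(num_batches):
    _ = next(input_iterator)  # iterate the input pipeline only
    if step % log_every == 0:
      rss_gb = proc.memory_info().rss / 1e9
      print(f'step {step}: process RSS (GB) {rss_gb:.1f}')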