hjmshi opened this issue 4 months ago
What is the version of the GPU driver? Can you print the output of nvidia-smi in your environment?
We had a similar issue a few months ago https://github.com/mlcommons/algorithmic-efficiency/issues/497
Our error is deterministically reproducible.
Here is the nvidia-smi output in the Docker container:
```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla V100-SXM2-16GB           Off | 00000000:00:17.0   Off |                    0 |
| N/A   65C    P0            296W /  300W |    7213MiB / 16384MiB  |     81%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-16GB           Off | 00000000:00:18.0   Off |                    0 |
| N/A   54C    P0            237W /  300W |    9027MiB / 16384MiB  |     82%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  Tesla V100-SXM2-16GB           Off | 00000000:00:19.0   Off |                    0 |
| N/A   53C    P0            146W /  300W |    8951MiB / 16384MiB  |     95%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  Tesla V100-SXM2-16GB           Off | 00000000:00:1A.0   Off |                    0 |
| N/A   64C    P0            103W /  300W |    8921MiB / 16384MiB  |     88%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  Tesla V100-SXM2-16GB           Off | 00000000:00:1B.0   Off |                    0 |
| N/A   58C    P0             79W /  300W |    8991MiB / 16384MiB  |     79%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  Tesla V100-SXM2-16GB           Off | 00000000:00:1C.0   Off |                    0 |
| N/A   57C    P0             80W /  300W |    9007MiB / 16384MiB  |     73%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  Tesla V100-SXM2-16GB           Off | 00000000:00:1D.0   Off |                    0 |
| N/A   52C    P0             73W /  300W |    9091MiB / 16384MiB  |     63%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  Tesla V100-SXM2-16GB           Off | 00000000:00:1E.0   Off |                    0 |
| N/A   60C    P0             82W /  300W |    9053MiB / 16384MiB  |     93%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory  |
|        ID   ID                                                              Usage       |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
```
Ok I see, sorry I did not catch that it was consistently OOMing.
We can try to clear the cache between trials.
Could you try adding `_reset_cuda_mem()` in `submission_runner.py` at line 542 (https://github.com/mlcommons/algorithmic-efficiency/blob/main/submission_runner.py#L542) and let me know if that resolves the issue?
@chandramouli-sastry @pomonam @msaroufim do you have any other ideas?
Thank you for the suggestion! We've confirmed that adding `_reset_cuda_mem()` solves this issue. However, we are now seeing that trial 5 consistently OOMs, even when run on its own (and regardless of whether `_reset_cuda_mem()` is added). Here's what the error looks like:
```
I0301 22:57:47.301353 140438874752832 submission_runner.py:314] Starting training loop.
I0301 22:57:51.364732 140412236850944 logging_writer.py:48] [0] global_step=0, grad_norm=35.902340, loss=31.931053
I0301 22:57:51.381769 140438874752832 pytorch_nadamw_full_budget.py:296] 0) loss = 31.931, grad_norm = 35.902
I0301 22:57:51.775843 140438874752832 spec.py:321] Evaluating on the training split.
I0301 22:57:51.776754 140438874752832 input_pipeline.py:20] Loading split = train-clean-100
I0301 22:57:51.803838 140438874752832 input_pipeline.py:20] Loading split = train-clean-360
I0301 22:57:51.906127 140438874752832 input_pipeline.py:20] Loading split = train-other-500
I0301 22:58:06.571050 140438874752832 spec.py:333] Evaluating on the validation split.
I0301 22:58:06.572633 140438874752832 input_pipeline.py:20] Loading split = dev-clean
I0301 22:58:06.576136 140438874752832 input_pipeline.py:20] Loading split = dev-other
I0301 22:58:17.383088 140438874752832 spec.py:349] Evaluating on the test split.
I0301 22:58:17.384322 140438874752832 input_pipeline.py:20] Loading split = test-clean
I0301 22:58:22.858831 140438874752832 submission_runner.py:414] Time since start: 35.56s, Step: 1, {'train/ctc_loss': 31.21902909067492, 'train/wer': 2.106045848546352, 'validation/ctc_loss': 30.200058752570424, 'validation/wer': 1.494539661082412, 'validation/num_examples': 5348, 'test/ctc_loss': 30.30062109410089, 'test/wer': 1.5860296955294213, 'test/num_examples': 2472, 'score': 4.081728935241699, 'total_duration': 35.55763244628906, 'accumulated_submission_time': 4.081728935241699, 'accumulated_eval_time': 31.082834243774414, 'accumulated_logging_time': 0}
I0301 22:58:22.872175 140410752063232 logging_writer.py:48] [1] accumulated_eval_time=31.082834, accumulated_logging_time=0, accumulated_submission_time=4.081729, global_step=1, preemption_count=0, score=4.081729, test/ctc_loss=30.300621, test/num_examples=2472, test/wer=1.586030, total_duration=35.557632, train/ctc_loss=31.219029, train/wer=2.106046, validation/ctc_loss=30.200059, validation/num_examples=5348, validation/wer=1.494540
I0301 22:58:23.482473 140438874752832 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/shampoo/tmp/librispeech_conformer_pytorch/trial_5/checkpoint_1.
I0301 22:58:25.199719 140410533967616 logging_writer.py:48] [1] global_step=1, grad_norm=35.834045, loss=31.748371
I0301 22:58:25.202767 140438874752832 pytorch_nadamw_full_budget.py:296] 1) loss = 31.748, grad_norm = 35.834
I0301 22:58:26.362531 140410752063232 logging_writer.py:48] [2] global_step=2, grad_norm=43.194149, loss=31.490442
I0301 22:58:26.365588 140438874752832 pytorch_nadamw_full_budget.py:296] 2) loss = 31.490, grad_norm = 43.194
```
```
Traceback (most recent call last):
  File "submission_runner.py", line 698, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 666, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 577, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
  File "submission_runner.py", line 336, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/submissions/baseline_submission/pytorch_nadamw_full_budget.py", line 276, in update_params
    loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.51 GiB. GPU 4 has a total capacty of 15.77 GiB of which 767.31 MiB is free. Process 610974 has 15.01 GiB memory in use. Of the allocated memory 6.26 GiB is allocated by PyTorch, and 7.79 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
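As a side note, the error itself points to fragmentation (7.79 GiB reserved but unallocated), and the `max_split_size_mb` hint can be applied via the environment variable the message names. A sketch; the 128 MiB cap is an illustrative value, not a tuned recommendation:

```shell
# Cap the size of splittable cached allocator blocks to reduce
# fragmentation. 128 is an illustrative value; tune per workload.
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128

# Then launch submission_runner.py in the container as usual.
```

This does not fix a true leak, but it can turn a fragmentation-induced OOM into a successful (if slightly slower) run.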
This trial corresponds to the following hyperparameter set:
```json
{
  "dropout_rate": 0.1,
  "label_smoothing": 0.0,
  "learning_rate": 0.0017486387539278373,
  "one_minus_beta1": 0.06733926164,
  "beta2": 0.9955159689799007,
  "weight_decay": 0.08121616522670176,
  "warmup_factor": 0.02
}
```
We observed that the OOM disappears when `dropout_rate` is set to 0.0, and also when the batch size is halved. However, we haven't been able to run the baseline with this set of hyperparameters as-is.
Ok, thanks for the update. I will make a fix to reset the CUDA memory in our repo, and we will try to reproduce and investigate the dropout issue on our end.
Also, I assume you're running these to produce some baseline logs? I checked in logs on our dev branch under `prize_qualification_baselines/external_tuning/logs`. FYI, these were produced with JAX, which is probably why we haven't noticed this issue. But I hope this will unblock you for the time being.
Thank you so much! Yes, we wanted to sanity check the baseline on our setup. Thank you for pointing us towards the logs, they will indeed be helpful :)
It seems like @pomonam was unable to reproduce this issue with dropout.
We probably have a fix though: could you modify https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py and pass `inplace=True` in the `nn.Dropout` initializations? This should reduce the memory footprint of the dropout layers.
If this fixes things for you, we'll make this change as well.
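For illustration, the change amounts to constructing the dropout layers with `inplace=True`, which overwrites the input activation instead of allocating a fresh tensor. A minimal sketch, not the actual Conformer code; note that in-place ops are only safe when the overwritten tensor is not needed for the backward pass of a preceding op:

```python
import torch
from torch import nn

# inplace=True makes dropout overwrite its input buffer, saving one
# activation-sized allocation per call. Only safe when the input tensor
# is not needed elsewhere (e.g. by autograd for an earlier operation).
dropout = nn.Dropout(p=0.1, inplace=True)

x = torch.ones(4, 8)
y = dropout(x)  # in training mode: zeroes ~10% of entries, rescales the rest
assert y.data_ptr() == x.data_ptr()  # same storage: no extra tensor allocated
```

With `inplace=False` (the default), `y` would be a new tensor and `x` would stay live until released, roughly doubling the dropout layer's activation footprint.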
From offline discussion with @anana10c and team, it sounds like setting `inplace=True` resolved the OOM issue.
@pomonam could you send over a PR with:
1. the `_reset_cuda_mem()` addition at the beginning of each trial, and
2. `inplace=True` in the PyTorch `nn.Dropout` layer initializations.
Sorry for the delayed response - unfortunately, we've found that while `inplace=True` lets the trial 5 run continue for longer, it still ends up OOMing at around iteration 3600.
We consistently observe an OOM error when running one of the NAdamW baselines on LibriSpeech Conformer with multiple trials in PyTorch on 8 V100s with 16GB each. This is run for the external ruleset. The first trial runs through successfully, but any subsequent trial will OOM.
If we try to resume a multi-trial run, we observe a NCCL error. This occurs even if we delete the `trial_2` folder (while the `trial_1` folder remains intact).

Description
As discussed above, we observe an OOM when running LibriSpeech Conformer with the NAdamW baseline with multiple trials on 8 V100s with 16GB each. This is an example of an OOM we observe on a subsequent trial:
Alternatively, if we try to resume the multi-trial run, we observe the NCCL error:
cc @anana10c @mikerabbat @tsunghsienlee @yuchenhao @shintaro-iwasaki
Steps to Reproduce
In the Docker container, run:
Source or Possible Fix
We are not aware of a possible fix for this issue. We suspect there may be a memory leak in the PyTorch LibriSpeech Conformer workload. Please let us know how to proceed. Thanks in advance!
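One way to narrow down a suspected leak like this is to log allocator statistics between trials: if allocated memory keeps climbing even after `torch.cuda.empty_cache()`, some reference to trial state is being retained. A minimal sketch (the `log_cuda_mem` helper name is ours, not part of the repo):

```python
import torch


def log_cuda_mem(tag: str) -> None:
    # Print allocated vs. reserved memory per device. A steady climb in
    # "allocated" across trials (even after empty_cache) points to tensors
    # that are still referenced, i.e. a leak rather than fragmentation.
    if not torch.cuda.is_available():
        print(f"{tag}: no CUDA device visible")
        return
    for d in range(torch.cuda.device_count()):
        alloc_mib = torch.cuda.memory_allocated(d) / 2**20
        reserved_mib = torch.cuda.memory_reserved(d) / 2**20
        print(f"{tag} gpu{d}: allocated={alloc_mib:.0f}MiB "
              f"reserved={reserved_mib:.0f}MiB")
```

Calling this at the start and end of each trial (and diffing the numbers) should show whether memory survives the trial boundary.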