Smerity / sha-rnn

Single Headed Attention RNN - "Stop thinking with your head"

how to control GPU ram usage #9

Open jprobichaud opened 4 years ago

jprobichaud commented 4 years ago

Thanks for sharing this code! I'd like to try on my own training dataset, but I keep getting GPU OOM problems:

RuntimeError: CUDA out of memory. Tried to allocate 11.59 GiB (GPU 0; 11.91 GiB total capacity; 0 bytes already allocated; 11.43 GiB free; 0 bytes cached)

I've cut down the batch size to 8, emb size to 512, nhid to 2048 and nlayers to 2 and I still get the exact same message.

My training dataset is 3.3GB (that's 1/10 of the data I would like to throw at it), so I'm already well beyond the size of the enwik8 dataset (173MB). I wonder where I should tweak the model/code...

Smerity commented 4 years ago

At a guess, the likely issue is the vocabulary size of your dataset. What's the vocabulary size of your 3.3GB dataset? The dataset isn't actually kept in the GPU device's memory, so its size shouldn't impact GPU memory usage.

The solutions would include an adaptive softmax, which this codebase used to have but which I removed, or reducing the vocabulary size through wordpieces or similar.
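For a word-level vocabulary, PyTorch ships a built-in adaptive softmax, `torch.nn.AdaptiveLogSoftmaxWithLoss`. The sketch below is not the implementation this codebase used to have; it only illustrates the idea with hypothetical sizes (a 512-dim hidden state, a 50,000-word vocabulary) and hypothetical cutoffs. It assumes token ids are sorted by frequency (id 0 is the most common word), which is what the cutoffs rely on.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for a word-level model; a character vocabulary (< 100) doesn't need this.
HIDDEN_SIZE = 512
VOCAB_SIZE = 50_000

# Ids below 2,000 form the full-width "head"; rarer ids fall into tail clusters whose
# projections are shrunk by div_value, so no full VOCAB_SIZE-wide logit row is built per token.
criterion = nn.AdaptiveLogSoftmaxWithLoss(
    in_features=HIDDEN_SIZE,
    n_classes=VOCAB_SIZE,
    cutoffs=[2_000, 10_000],
    div_value=4.0,
)

# Flattened RNN outputs (timesteps * batch, hidden) and their next-token targets.
hidden = torch.randn(64, HIDDEN_SIZE, requires_grad=True)
targets = torch.randint(0, VOCAB_SIZE, (64,))

result = criterion(hidden, targets)  # namedtuple: .output (target log-probs) and .loss
result.loss.backward()
```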

If you have a large vocabulary then the GPU memory will balloon quite rapidly as it's required for the softmax output of each and every timestep.
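A rough way to see the vocabulary effect is to size the logits tensor alone: one float32 score per vocabulary entry, for every timestep of every sequence in the batch. The batch size of 16 and sequence length of 1024 are assumptions for illustration, and the real peak is higher because the backward pass keeps activations and gradients of similar size.

```python
def logits_bytes(batch_size, seq_len, vocab_size, bytes_per_float=4):
    """Bytes for one float32 logits tensor covering every timestep of a batch."""
    return batch_size * seq_len * vocab_size * bytes_per_float

GIB = 2 ** 30
# Assumed settings: batch 16, 1024-step sequences (check your own --bptt).
char_level = logits_bytes(16, 1024, 100)        # character vocabulary (< 100 symbols)
word_level = logits_bytes(16, 1024, 267_735)    # WikiText-103-sized word vocabulary

print(f"char: {char_level / 2**20:.2f} MiB, word: {word_level / GIB:.2f} GiB")
# prints: char: 6.25 MiB, word: 16.34 GiB
```

With a character-level vocabulary, the softmax output is tiny, so it is unlikely to explain an 11.59 GiB allocation.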

jprobichaud commented 4 years ago

It's a char-based LM, and the data is lowercased, so aside from the 26 letters, some apostrophes and dashes, and some monetary symbols, there is nothing else. The vocab size is less than 100.

How can I diagnose this issue?

Smerity commented 4 years ago

That's quite odd. Are you able to replicate the initial results on enwik8? I would try doing that first. My GPU only had ~12GB of RAM, so as far as I'm aware there's no reason you shouldn't be able to do this, assuming your data is character level. If you can replicate them, try a 100MB chunk of your dataset. If that still works, then potentially I do have a line of code that unexpectedly puts the dataset in GPU memory. If that's the case, it's an easy fix: find that line (like a .cuda()), remove it from the massive dataset, and put a .cuda() where the snippets of data are loaded for training.
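One quick way to run that check is to print where the batchified corpus lives and how large it is right after it is created. The helper name `tensor_footprint` and the variable `train_data` are hypothetical; on a CUDA machine, printing `torch.cuda.memory_allocated()` before and after batchifying gives a similar signal.

```python
import torch

def tensor_footprint(name, t):
    """Print and return where a tensor lives and how many bytes its elements use."""
    nbytes = t.element_size() * t.nelement()
    print(f"{name}: device={t.device}, dtype={t.dtype}, {nbytes / 2**30:.2f} GiB")
    return t.device.type, nbytes

# Hypothetical usage: call it on whatever batchify() returned.
# A CUDA device plus a multi-GiB size means the whole corpus was moved to the GPU.
train_data = torch.zeros(1000, 8, dtype=torch.long)  # stand-in for the real corpus
tensor_footprint("train_data", train_data)

if torch.cuda.is_available():
    print(f"allocated on GPU: {torch.cuda.memory_allocated() / 2**30:.2f} GiB")
```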

jprobichaud commented 4 years ago

I was able to reproduce the enwik8 results without problems (not the exact published BPC, but very close).

I will try with a smaller sample of my dataset and see. If need be, I'll go and see if there is a .cuda() put in the wrong place.

I had added some print statements in the data loading method; here are the numbers I'm getting (for the entire dataset):

train.txt: 1555434404 tokens
valid.txt: 1978645700 tokens
test.txt: 2375699684 tokens
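Those counts are enough to explain the error message at the top of the thread. Assuming the token ids are stored as 64-bit integers (a `torch.LongTensor`, i.e. int64), the train.txt split alone needs about 11.59 GiB. That matches the failed allocation almost exactly, which suggests the whole training set was being copied to the GPU in one go.

```python
TRAIN_TOKENS = 1_555_434_404   # train.txt count reported above
BYTES_PER_TOKEN = 8            # assumption: ids stored as int64 (torch.long)

train_bytes = TRAIN_TOKENS * BYTES_PER_TOKEN
print(f"{train_bytes / 2**30:.2f} GiB")  # prints 11.59 GiB
```

Batchifying trims at most `batch_size - 1` tokens, so it doesn't change this figure meaningfully.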

Smerity commented 4 years ago

Ah, I was wrong. The dataset is loaded into GPU memory; it was a previous version of the codebase that I optimized that for, sorry.

The fix is to take out the dataset transfer to GPU in batchify and add it to data and target in get_batch.
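A minimal sketch of that change, modelled on AWD-LSTM-style helpers rather than copied from this repository: `batchify` leaves the reshaped corpus in CPU RAM, and `get_batch` moves only the current slice to the device. The explicit `bptt` and `device` parameters stand in for the `args` object such helpers usually take and are assumptions.

```python
import torch

def batchify(data, bsz):
    """Reshape a 1-D token tensor into (num_steps, bsz) columns, kept in CPU RAM."""
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)      # drop tokens that don't fill a column
    return data.view(bsz, -1).t().contiguous()  # note: no .cuda() here any more

def get_batch(source, i, bptt, device):
    """Slice one (data, target) pair and copy only that slice to `device`."""
    seq_len = min(bptt, source.size(0) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data.to(device), target.to(device)

# Usage: the corpus stays on the CPU; each step transfers one small slice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data = batchify(torch.arange(100), bsz=4)
for i in range(0, train_data.size(0) - 1, 8):
    data, target = get_batch(train_data, i, bptt=8, device=device)
```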

This may slow training down a little (I'm not certain), since small batches of data will be copied from CPU to GPU at every step, but it will allow you to train without keeping the dataset in GPU RAM. You'll obviously need to store it in CPU RAM instead.

jprobichaud commented 4 years ago

Wonderful, thanks, that seems to do the trick!

With a smaller dataset and without the fix, I'm getting the following "throughput"

Dec 10 16:51:40 | epoch   0 |    10/ 1162 batches | lr 0.00003 | ms/batch 689.93 | loss  4.57 | ppl    96.23 | bpc    6.588
Dec 10 16:51:47 | epoch   0 |    20/ 1162 batches | lr 0.00005 | ms/batch 651.88 | loss  3.65 | ppl    38.31 | bpc    5.260
Dec 10 16:51:53 | epoch   0 |    30/ 1162 batches | lr 0.00008 | ms/batch 653.78 | loss  3.12 | ppl    22.68 | bpc    4.503
Dec 10 16:52:00 | epoch   0 |    40/ 1162 batches | lr 0.00010 | ms/batch 657.68 | loss  3.01 | ppl    20.19 | bpc    4.336
Dec 10 16:52:07 | epoch   0 |    50/ 1162 batches | lr 0.00013 | ms/batch 661.37 | loss  2.99 | ppl    19.88 | bpc    4.313
Dec 10 16:52:13 | epoch   0 |    60/ 1162 batches | lr 0.00015 | ms/batch 634.67 | loss  3.00 | ppl    20.03 | bpc    4.324
Dec 10 16:52:20 | epoch   0 |    70/ 1162 batches | lr 0.00018 | ms/batch 662.47 | loss  2.97 | ppl    19.54 | bpc    4.289
Dec 10 16:52:26 | epoch   0 |    80/ 1162 batches | lr 0.00020 | ms/batch 671.82 | loss  2.88 | ppl    17.74 | bpc    4.149
Dec 10 16:52:33 | epoch   0 |    90/ 1162 batches | lr 0.00023 | ms/batch 670.94 | loss  2.76 | ppl    15.81 | bpc    3.983
Dec 10 16:52:40 | epoch   0 |   100/ 1162 batches | lr 0.00025 | ms/batch 673.17 | loss  2.66 | ppl    14.26 | bpc    3.834
Dec 10 16:52:46 | epoch   0 |   110/ 1162 batches | lr 0.00028 | ms/batch 672.23 | loss  2.58 | ppl    13.18 | bpc    3.720
Dec 10 16:52:53 | epoch   0 |   120/ 1162 batches | lr 0.00030 | ms/batch 674.66 | loss  2.47 | ppl    11.80 | bpc    3.560
Dec 10 16:53:00 | epoch   0 |   130/ 1162 batches | lr 0.00033 | ms/batch 674.38 | loss  2.37 | ppl    10.70 | bpc    3.419
Dec 10 16:53:07 | epoch   0 |   140/ 1162 batches | lr 0.00035 | ms/batch 676.15 | loss  2.32 | ppl    10.15 | bpc    3.343
Dec 10 16:53:14 | epoch   0 |   150/ 1162 batches | lr 0.00038 | ms/batch 709.25 | loss  2.24 | ppl     9.42 | bpc    3.236

so 1.6 batches per sec.

With the larger dataset and the fix you suggested:

Dec 10 16:50:05 | epoch   0 |    10/94936 batches | lr 0.00003 | ms/batch 1056.88 | loss  4.57 | ppl    96.52 | bpc    6.593
Dec 10 16:50:15 | epoch   0 |    20/94936 batches | lr 0.00005 | ms/batch 938.93 | loss  3.65 | ppl    38.34 | bpc    5.261
Dec 10 16:50:21 | epoch   0 |    30/94936 batches | lr 0.00008 | ms/batch 673.42 | loss  3.11 | ppl    22.46 | bpc    4.489
Dec 10 16:50:28 | epoch   0 |    40/94936 batches | lr 0.00010 | ms/batch 677.31 | loss  3.01 | ppl    20.21 | bpc    4.337
Dec 10 16:50:35 | epoch   0 |    50/94936 batches | lr 0.00013 | ms/batch 683.31 | loss  2.99 | ppl    19.96 | bpc    4.319
Dec 10 16:50:42 | epoch   0 |    60/94936 batches | lr 0.00015 | ms/batch 688.04 | loss  3.00 | ppl    20.01 | bpc    4.323
Dec 10 16:50:49 | epoch   0 |    70/94936 batches | lr 0.00018 | ms/batch 711.40 | loss  2.99 | ppl    19.79 | bpc    4.307
Dec 10 16:50:56 | epoch   0 |    80/94936 batches | lr 0.00020 | ms/batch 713.90 | loss  2.88 | ppl    17.80 | bpc    4.154
Dec 10 16:51:03 | epoch   0 |    90/94936 batches | lr 0.00023 | ms/batch 714.09 | loss  2.75 | ppl    15.60 | bpc    3.964
Dec 10 16:51:10 | epoch   0 |   100/94936 batches | lr 0.00025 | ms/batch 717.38 | loss  2.67 | ppl    14.45 | bpc    3.853
Dec 10 16:51:17 | epoch   0 |   110/94936 batches | lr 0.00028 | ms/batch 713.74 | loss  2.58 | ppl    13.20 | bpc    3.722
Dec 10 16:51:25 | epoch   0 |   120/94936 batches | lr 0.00030 | ms/batch 711.81 | loss  2.49 | ppl    12.08 | bpc    3.595
Dec 10 16:51:31 | epoch   0 |   130/94936 batches | lr 0.00033 | ms/batch 682.43 | loss  2.43 | ppl    11.33 | bpc    3.502
Dec 10 16:51:38 | epoch   0 |   140/94936 batches | lr 0.00035 | ms/batch 672.41 | loss  2.34 | ppl    10.36 | bpc    3.372
Dec 10 16:51:45 | epoch   0 |   150/94936 batches | lr 0.00038 | ms/batch 713.46 | loss  2.27 | ppl     9.65 | bpc    3.270

So about 1.5 batches per sec.
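The two logs can be compared more directly by averaging the ms/batch column. This sketch skips the first two entries of each run as warm-up (an assumption, motivated by the 1056 ms and 938 ms readings of the large run) and converts the averages into batches per second, the relative per-batch slowdown, and a rough time for one epoch of the large dataset. The runs were on different GPUs, so the gap is not purely the cost of the extra CPU-to-GPU copies.

```python
from statistics import mean

# ms/batch values copied from the two logs above, minus the first two entries of each.
small_ms = [653.78, 657.68, 661.37, 634.67, 662.47, 671.82, 670.94,
            673.17, 672.23, 674.66, 674.38, 676.15, 709.25]
large_ms = [673.42, 677.31, 683.31, 688.04, 711.40, 713.90, 714.09,
            717.38, 713.74, 711.81, 682.43, 672.41, 713.46]

small_rate = 1000 / mean(small_ms)              # batches per second
large_rate = 1000 / mean(large_ms)
overhead = mean(large_ms) / mean(small_ms) - 1  # relative per-batch slowdown
epoch_hours = 94936 * mean(large_ms) / 1000 / 3600

print(f"small: {small_rate:.2f} batches/s, large: {large_rate:.2f} batches/s")
print(f"per-batch overhead: {overhead:.1%}, one large epoch: ~{epoch_hours:.1f} h")
```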

Not bad. Both experiments use --emsize 512 --nhid 4096 --nlayers 4 --batch_size 16.

The large dataset runs on GPU 0 and the "small dataset" on GPU 1; nvidia-smi reports:

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     36533      C   python                                      6063MiB |
|    1     36623      C   python                                      8795MiB |
+-----------------------------------------------------------------------------+

Smerity commented 4 years ago

I'm so glad! Sorry about the wild goose / bug chase =]

It appears that the overhead isn't all that substantial which is reassuring. The technique of loading individual batches to GPU memory was the approach I used for WikiText-103 as RAM was scarce. Various optimizations could be made, such as loading a number of batches at the same time, but that's likely a little over the top. There are big gains to come from all directions as the model really deserves some optimization love.
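Loading several batches at once could look like the generator below. It is a sketch that assumes a fixed `bptt` (no sequence-length jitter) and a batchified CPU tensor shaped `(steps, batch_size)`. It copies `batches_per_chunk` batches to the device in one transfer and then slices them there, yielding the same `(data, target)` pairs as a per-batch loop. The function name and its parameters are hypothetical.

```python
import torch

def iter_chunked_batches(source, bptt, batches_per_chunk, device):
    """Yield the same (data, target) pairs as a per-batch loop over
    range(0, len(source) - 1, bptt), but move `batches_per_chunk`
    batches to `device` per transfer instead of one batch at a time."""
    chunk_len = bptt * batches_per_chunk
    n = source.size(0)
    for start in range(0, n - 1, chunk_len):
        # One extra row so the last target inside the chunk is available on the device.
        end = min(start + chunk_len + 1, n)
        chunk = source[start:end].to(device)  # single host-to-device copy per chunk
        for i in range(0, chunk.size(0) - 1, bptt):
            seq_len = min(bptt, chunk.size(0) - 1 - i)
            yield chunk[i:i + seq_len], chunk[i + 1:i + 1 + seq_len].reshape(-1)
```

Because `chunk_len` is a multiple of `bptt`, the slice boundaries line up with the per-batch loop; larger chunks mean fewer transfers at the cost of more GPU memory per chunk.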

For your experiment I would note that the embedding size of 512 will likely limit your model as that's the size of the LSTM hidden state as well. LSTMs are not as efficient when working with smaller hidden states due to the forget mask recurrence limiting their expressiveness. You should still get reasonable results but it may require some tweaking.

If you're interested in telling me more about what dataset / task you're exploring I'd love to hear it, online or offline :)