XiangLi1999 / PrefixTuning

Prefix-Tuning: Optimizing Continuous Prompts for Generation

OOM error #8

Closed taineleau closed 2 years ago

taineleau commented 2 years ago

Hi, I tried the seq2seq prefixtuning and found:

RuntimeError: CUDA out of memory. Tried to allocate 1.20 GiB (GPU 0; 15.90 GiB total capacity; 4.63 GiB already allocated; 797.50 MiB free; 5.81 GiB reserved in total by PyTorch)

I ran the experiment on a 16GB GPU. Am I supposed to use a 32GB GPU instead? Thanks!

XiangLi1999 commented 2 years ago

Hi,

I used a 32GB GPU for the XSUM experiments. You could either switch to a GPU with more memory, or reduce the bsz and increase gradient_accumulation_steps correspondingly, which keeps the effective batch size the same.
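
For reference, a minimal gradient-accumulation sketch in plain PyTorch (the model, optimizer, and loader below are toy stand-ins, not this repo's training loop):

```python
import torch
from torch import nn

# Toy stand-ins for the real seq2seq model and data loader.
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]

accum_steps = 4  # effective bsz = per-step bsz * accum_steps

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = nn.functional.mse_loss(model(x), y)
    (loss / accum_steps).backward()  # scale so accumulated grads match one large batch
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

With bsz halved and gradient_accumulation_steps doubled, each optimizer step sees the same number of examples, but peak activation memory is roughly halved.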

StevenTang1998 commented 2 years ago

Hi, I used one GPU (Tesla V100 SXM2, 32GB) and ran the command from the homepage. However, I still hit the OOM problem with bsz=16 or bsz=12; with bsz=8 the OOM disappears. So, is the command on the homepage the one used to reproduce the paper?

XiangLi1999 commented 2 years ago

Hi,

It's the command to reproduce the paper.

Could you check whether you have --fp16 yes and whether it actually turns on half precision? With half precision on, bsz=16 should fit.

Side note: I used a single GPU on AWS (I think it was an A100) to run all the XSUM experiments.

StevenTang1998 commented 2 years ago

I do have --fp16 yes, but how can I tell whether it turns on half precision?

Side note: the A100 on AWS has 40GB of GPU memory rather than 32GB.

XiangLi1999 commented 2 years ago

Maybe check if your stdout contains this: Using native 16bit precision.
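(For context, that message typically indicates the trainer enabled PyTorch's native automatic mixed precision. A minimal sketch of what it does under the hood, with a toy model and data, assuming a CUDA device; this is not this repo's training loop:)

```python
import torch
from torch import nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # rescales the loss to avoid fp16 gradient underflow

x, y = torch.randn(16, 10).cuda(), torch.randn(16, 1).cuda()

with torch.cuda.amp.autocast():       # forward pass runs in half precision
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)                # unscales gradients, then steps
scaler.update()
```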

StevenTang1998 commented 2 years ago

Thanks! My stdout does contain: Using native 16bit precision. So it may just be the GPU memory difference (32GB vs. 40GB), and I will reduce the bsz to train the model.

By the way, what does --mid_dim mean?

XiangLi1999 commented 2 years ago

It means the dim of the MLP's middle layer! (We use an MLP to re-parametrize the prefix.)
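
For intuition, a minimal sketch of that re-parametrization (illustrative sizes; the names below do not mirror the repo's actual modules): --mid_dim is the hidden width of the MLP that expands a trainable prefix embedding into key/value prefixes for every layer.

```python
import torch
from torch import nn

# Illustrative sizes, not the repo's defaults.
preseqlen, n_embd, mid_dim, n_layer = 5, 1024, 512, 12

prefix_embedding = nn.Embedding(preseqlen, n_embd)    # trainable prefix P'
reparam_mlp = nn.Sequential(                          # hidden width set by --mid_dim
    nn.Linear(n_embd, mid_dim),
    nn.Tanh(),
    nn.Linear(mid_dim, n_layer * 2 * n_embd),         # a key and a value prefix per layer
)

idx = torch.arange(preseqlen)
past_key_values = reparam_mlp(prefix_embedding(idx))  # shape: (preseqlen, n_layer * 2 * n_embd)
```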

StevenTang1998 commented 2 years ago

Got it! Thanks for your answer!

taineleau commented 2 years ago

Thanks! My mistake. I just figured out that I wasn't changing the bsz in the right place, which is why I still got OOM even after setting bsz to 1.