The bsz is actually the input batch_size.
Seems like the crash happens during the forward pass and is not directly related to the transformers library. This does not seem like a bug but rather a discussion topic; feel free to ask on the forum.
System Info
When I use transformers' OPTModel to load the opt-13b model for training with PyTorch FSDP, I found that the whole training is limited by batch_size. Although FSDP can offload parameters to CPU memory to reduce GPU memory pressure, the batch size still determines the size of the intermediate tensors created during the forward pass, so GPU memory overflows when those tensors are allocated on the GPU.
Who can help?
@sgugger @ArthurZucker @younesbelkada
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
Training code
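The original training script is not reproduced above; a minimal sketch of the setup described in the issue (assumed details: the facebook/opt-13b checkpoint, fp16 weights, FSDP with CPU offload, launched with torchrun) might look like this:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from transformers import OPTModel

# Assumes the script is launched with torchrun so the process-group env vars are set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Load opt-13b in fp16 and shard it with FSDP, offloading parameters to CPU memory.
model = OPTModel.from_pretrained("facebook/opt-13b", torch_dtype=torch.float16)
model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))

# A forward pass with bsz=100 and seq_len=1024 still materializes large
# intermediate tensors on the GPU, independent of where FSDP keeps the parameters.
input_ids = torch.randint(0, model.config.vocab_size, (100, 1024), device="cuda")
outputs = model(input_ids=input_ids)
loss = outputs.last_hidden_state.float().mean()  # placeholder loss for illustration
loss.backward()
```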
Expected behavior
The shape of attn_weights is (bsz: 100, self.num_heads: 40, tgt_len: 1024, src_len: 1024). Even though its data type is fp16, this single tensor is close to 8 GB, which directly causes the GPU memory to overflow.
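For reference, a quick back-of-the-envelope check of that figure (one attention-weights tensor, per layer, per forward pass):

```python
bsz, num_heads, tgt_len, src_len = 100, 40, 1024, 1024
bytes_per_fp16 = 2
size_gb = bsz * num_heads * tgt_len * src_len * bytes_per_fp16 / 1e9
print(f"{size_gb:.1f} GB")  # ~8.4 GB for a single attention-weights tensor
```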