tatp22 / linformer-pytorch

My take on a practical implementation of Linformer for Pytorch.
https://arxiv.org/pdf/2006.04768.pdf
MIT License

Any performance tests on different checkpoint levels? #5

Closed: phongnhhn92 closed this issue 4 years ago

phongnhhn92 commented 4 years ago

Hello, thanks for the code! I am testing it with different checkpoint levels. I see a massive drop in required GPU memory when I use "C1" or "C2" (about 50% in my case). Oddly, C1 and C2 report the same amount of allocated memory. So my first question is: what is the difference between C1 and C2?

Looking at the checkpoint function here, increasing the checkpoint level seems to affect only the backward pass. So my other question is: does using C2 instead of C0 hurt overall performance?

tatp22 commented 4 years ago

Hi @phongnhhn92! To answer your first question: the difference between the C1 and C2 checkpoint levels is that with C2, each individual head in every MultiHead attention layer is also checkpointed, while C1 does not do this.

As for why both give the same allocated memory, I have a hypothesis. Each LinearAttentionHead creates only 3 nn.Linear layers, each of size (dim_d, dim_d). These matrices are usually small (typically 64x64), so if all weights are stored as float32 (4 bytes each), running C2 instead of C1 saves only 3*64*64*4 = 49,152 bytes (48 KiB) of extra memory per head. With nhead=4, that is 196,608 bytes (192 KiB) in this toy example, which is likely too small to show up as a difference in allocated memory.
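The back-of-the-envelope arithmetic above can be checked with a few lines of Python. This is a minimal sketch that, like the estimate, counts only the three (dim_d, dim_d) weight matrices per head stored as float32 and ignores bias vectors; the function name `head_weight_bytes` is hypothetical and not part of linformer-pytorch.

```python
def head_weight_bytes(dim_d: int, n_linear: int = 3, bytes_per_value: int = 4) -> int:
    """Bytes used by n_linear (dim_d x dim_d) weight matrices; biases are ignored."""
    return n_linear * dim_d * dim_d * bytes_per_value


per_head = head_weight_bytes(64)  # one head, float32 weights
four_heads = 4 * per_head         # nhead=4

print(per_head, per_head / 1024)      # 49152 48.0   (bytes, KiB)
print(four_heads, four_heads / 1024)  # 196608 192.0 (bytes, KiB)
```

At these sizes the C1 vs C2 difference is measured in kilobytes, which fits the hypothesis that it disappears next to the memory used by activations.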

As for the second question: yes, it will hurt performance, in the sense of time spent computing. How much, I'm not sure. The reason is that the forward activations inside each checkpointed segment are not cached; they have to be recomputed during every backward pass, so each call to loss.backward() uses more compute.
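To make the trade-off concrete, here is a minimal pure-Python simulation of activation checkpointing on a hypothetical 16-layer chain of scalar tanh layers. It is not PyTorch's `torch.utils.checkpoint` and not linformer-pytorch's code; the weights, function names, and segment size are illustrative assumptions. The no-checkpoint path stands in for C0, and the segmented path stands in for the idea behind C1/C2: only each segment's input is stored, and the segment's forward pass is rerun during backward.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical 16-layer scalar chain; layer i computes tanh(w_i * x).
WEIGHTS: List[float] = [(-1) ** i * (0.8 + 0.02 * i) for i in range(16)]


def layer(i: int, x: float, stats: Dict[str, int]) -> float:
    """Forward pass of layer i; every evaluation is counted."""
    stats["forward_calls"] += 1
    return math.tanh(WEIGHTS[i] * x)


def local_grad(i: int, y: float, upstream: float) -> float:
    """Backward pass of layer i from its saved output y: dy/dx = w * (1 - y**2)."""
    return upstream * WEIGHTS[i] * (1.0 - y * y)


def grad_no_checkpoint(x: float) -> Tuple[float, Dict[str, int]]:
    """C0-style: keep every activation from the forward pass."""
    stats = {"forward_calls": 0, "peak_saved": 0}
    acts = [x]
    for i in range(len(WEIGHTS)):
        acts.append(layer(i, acts[-1], stats))
    stats["peak_saved"] = len(acts)
    g = 1.0  # d(output)/d(output)
    for i in reversed(range(len(WEIGHTS))):
        g = local_grad(i, acts[i + 1], g)
    return g, stats


def grad_checkpointed(x: float, segment: int) -> Tuple[float, Dict[str, int]]:
    """Checkpointed: keep only segment inputs, recompute the rest in backward."""
    n = len(WEIGHTS)
    stats = {"forward_calls": 0, "peak_saved": 0}
    boundaries = {0: x}
    h = x
    for i in range(n):
        h = layer(i, h, stats)
        if (i + 1) % segment == 0 and i + 1 < n:
            boundaries[i + 1] = h  # input of the next segment
    g = 1.0
    for start in sorted(boundaries, reverse=True):
        end = min(start + segment, n)
        seg_acts = [boundaries[start]]
        for i in range(start, end):  # second forward pass over this segment
            seg_acts.append(layer(i, seg_acts[-1], stats))
        # seg_acts[0] is already counted in boundaries, hence the -1
        saved = len(boundaries) + len(seg_acts) - 1
        stats["peak_saved"] = max(stats["peak_saved"], saved)
        for i in reversed(range(start, end)):
            g = local_grad(i, seg_acts[i - start + 1], g)
    return g, stats


g0, s0 = grad_no_checkpoint(0.3)
g1, s1 = grad_checkpointed(0.3, segment=4)
print(math.isclose(g0, g1), s0, s1)
# True {'forward_calls': 16, 'peak_saved': 17} {'forward_calls': 32, 'peak_saved': 8}
```

The gradients match, the number of activations held at once drops from 17 to 8, and forward evaluations double from 16 to 32: less memory in exchange for more compute on every backward pass.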

phongnhhn92 commented 4 years ago

Thanks for your answer!