pbloem / former

Simple transformer implementation from scratch in pytorch.
http://peterbloem.nl/blog/transformers
MIT License

Calculation of memory required in "Going big" #34

Closed: oleksandrasaskia closed this issue 1 year ago

oleksandrasaskia commented 1 year ago

For a sequence length t, this is a dense matrix containing t² elements. At standard 32-bit precision, and with t=1000 a batch of 16 such matrices takes up about 250Mb of memory.

Just a small question: how do we arrive at 250 MB?

A single matrix would be 1 million elements, each of which is 4 bytes; is my understanding correct? So a single matrix would require 4 million bytes, or 4 MB. And then we would need 16 of them for a batch, correct? Wouldn't this result in 64 MB in total?
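A quick sanity check of that arithmetic in PyTorch (just a sketch; the tensor name and shapes are illustrative, not code from the repo):

```python
import torch

# Batch of 16 attention-weight matrices for sequences of length t=1000,
# at standard 32-bit (4-byte) float precision.
dot = torch.zeros(16, 1000, 1000, dtype=torch.float32)

# element_size() is 4 bytes for float32; nelement() is 16 * 1000 * 1000.
print(dot.nelement() * dot.element_size() / 1e6, "MB")  # -> 64.0 MB
```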

Just trying to check my understanding here. Thanks so much for the great article!

pbloem commented 1 year ago

Yes, you're totally right. Well spotted!

Not sure what happened. My guess is that I worked backward from the memory requirements of my model, and forgot to account for the heads.
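For what it's worth, scaling the same estimate by a number of heads h shows where a larger figure could come from (the head counts below are purely illustrative, not necessarily the configuration behind the number in the post):

```python
# Rough estimate: b * h attention matrices of t * t four-byte floats.
b, t, bytes_per_el = 16, 1000, 4
for h in (1, 4, 8):
    print(f"h={h}: {b * h * t * t * bytes_per_el / 1e6:.0f} MB")
# h=1 -> 64 MB; h=4 -> 256 MB, which is in the ballpark of the 250 quoted.
```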

oleksandrasaskia commented 1 year ago

All good, I just thought maybe I was missing some conceptual piece. Thanks for confirming!