fkodom / dilated-attention-pytorch

(Unofficial) Implementation of dilated attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens" (https://arxiv.org/abs/2307.02486)
MIT License
50 stars · 9 forks

Running Time and Other Questions #2

Closed · MHarris021 closed this issue 1 year ago

MHarris021 commented 1 year ago

@fkodom this is great work!!! :smile: I'm very interested in trying out what you've built, and I'm hoping you could answer some of my questions. I saw that you were using a GTX 2080 for the GPU; what were its memory limits? Based on your ability to process 32M context tokens, can you extrapolate how much memory would be required for the full 1B context tokens? How long did it take to run? I'm trying to get an upfront idea of the costs involved in using this technique. Did you choose not to attempt distributed training because you only had one GPU available, or was there something particularly challenging about creating a distributed version? Thank you in advance for taking the time to read this issue, and for your contributions to LLMs.

fkodom commented 1 year ago

Hi, @MHarris021! Trying to answer all of your questions above 😅

I saw that you were using a GTX 2080 for the GPU; what were its memory limits?

It has 11 GB VRAM.

can you extrapolate how much memory would be required for the full 1B Context Tokens?

Complexity and memory footprint are $O(N d)$, where $N$ is the sequence length and $d$ is the hidden dimension. So, as usual, the memory footprint also depends on model size. But if 32M tokens require about 10 GB VRAM, then 1B tokens would require roughly 310 GB VRAM -- hence the need for distributed computing.

^^ The authors aren't clear what hyperparameters (i.e. hidden dimension, number of heads) were used in their version of the 1B-token benchmark. I suspect those parameters were very small, otherwise they would need quite a few A100 80GB GPUs just to pull off the benchmark. But their point was to show it's feasible to scale to 1B tokens, not actually train a model with 1B context windows.
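As a sanity check on the numbers above, here is a back-of-the-envelope version of that extrapolation. The 32M-token / ~10 GB data point is the one discussed in this thread; everything else is just linear scaling under the assumed $O(N d)$ memory behavior, so treat the exact figures as rough:

```python
# Rough linear extrapolation of memory vs. sequence length,
# assuming O(N * d) scaling with all other hyperparameters held fixed.
measured_tokens = 32 * 2**20      # ~32M tokens from the single-GPU benchmark
measured_vram_gb = 10.0           # ~10 GB VRAM observed at that length

target_tokens = 1_000_000_000     # 1B-token context from the LongNet paper
estimated_vram_gb = measured_vram_gb * target_tokens / measured_tokens

print(f"Estimated VRAM for 1B tokens: ~{estimated_vram_gb:.0f} GB")
# -> ~298 GB, i.e. on the order of the ~310 GB mentioned above,
#    which is why a single GPU won't cut it.
```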

Did you choose not to attempt distributed training because you only had one GPU available, or was there something particularly challenging about creating a distributed version?

So far, this is purely for cost reasons. This is a side project, so I don't have a large compute budget. 😂 There's nothing inherently difficult about distributed training -- libraries like huggingface/accelerate and Lightning-AI/lightning make distributed jobs fairly easy, IMO. Here's an example training script from another project, which uses Lightning and works for both local and distributed training.

I was planning to write a similar script for LongNet. But it would likely use a smaller toy dataset, so I can easily train without the cloud costs. 😂
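For readers following along, here is a minimal sketch of the kind of Lightning setup referred to above. The model, dataset, and hyperparameters are placeholders of my own choosing, not anything from the linked script; the point is only that the same code runs locally or distributed by changing `Trainer` arguments:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning as L


class ToyLanguageModel(L.LightningModule):
    """Placeholder LightningModule -- swap in a LongNet/dilated-attention model here."""

    def __init__(self, vocab_size: int = 1000, d_model: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def training_step(self, batch, batch_idx):
        (tokens,) = batch
        logits = self.head(self.embed(tokens))
        # Next-token prediction loss on a toy random-token dataset.
        loss = nn.functional.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            tokens[:, 1:].reshape(-1),
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-4)


if __name__ == "__main__":
    data = TensorDataset(torch.randint(0, 1000, (256, 128)))
    loader = DataLoader(data, batch_size=8)
    # Local run; for multi-GPU, something like Trainer(devices=4, strategy="ddp")
    # is all that changes -- Lightning handles process launching and gradient sync.
    trainer = L.Trainer(max_steps=100, accelerator="auto", devices=1)
    trainer.fit(ToyLanguageModel(), loader)
```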

MHarris021 commented 1 year ago

Hi @fkodom, thank you for such a swift response. A colleague and I are trying to replicate your results. We think you have something very useful here, and it's much easier to follow than the implementation created by kyegomez. Would you be interested in working with us? Nothing paid yet, purely research in our free time, but...

fkodom commented 1 year ago

@MHarris021 Interesting, I wasn't aware of kyegomez's implementation. I agree, it's a bit difficult to tell what's going on there. Just from eyeballing it, I'm certain our two implementations are not equivalent. Operations like these scale very badly on GPUs, and I think there are quite a few other, more fundamental differences. They have definitely done a better job of socializing their implementation. 😂

What is your goal in trying to replicate the repo? Aside from educational purposes, which is usually why I work on these things haha. 🙃

MHarris021 commented 1 year ago

@fkodom

Current project goals are:

Also, I extended you an invitation to a private repo on GitHub. If you aren't interested at this time, no worries.

fkodom commented 1 year ago

Training a 1B-parameter model with 1B context window would require an astronomical cloud budget. 😰

To confirm/verify my results, would it help to look at the benchmark script? https://github.com/fkodom/dilated-attention-pytorch/blob/87cda7579874b6485ea81a742b6a0dc51ffad6cc/benchmark.py
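For context, the gist of that kind of benchmark is to time a forward pass and record peak GPU memory at increasing sequence lengths. Below is a rough, self-contained sketch of that pattern (requires a CUDA GPU). It uses a plain `nn.MultiheadAttention` as a stand-in at small sizes; the actual script in this repo benchmarks the dilated attention module instead, and its exact settings may differ:

```python
import time

import torch
from torch import nn


@torch.no_grad()
def benchmark_forward(module: nn.Module, x: torch.Tensor, warmup: int = 2, iters: int = 5):
    """Return (mean seconds per forward pass, peak memory in GB) for `module(x, x, x)`."""
    device = x.device
    torch.cuda.reset_peak_memory_stats(device)
    for _ in range(warmup):
        module(x, x, x)
    torch.cuda.synchronize(device)

    start = time.perf_counter()
    for _ in range(iters):
        module(x, x, x)
    torch.cuda.synchronize(device)
    elapsed = (time.perf_counter() - start) / iters

    peak_gb = torch.cuda.max_memory_allocated(device) / 1e9
    return elapsed, peak_gb


if __name__ == "__main__":
    device = torch.device("cuda")
    # Stand-in module; the real benchmark would use the dilated attention from this repo.
    attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).to(device).half()
    for seq_len in (2**12, 2**13, 2**14):
        x = torch.randn(1, seq_len, 64, device=device, dtype=torch.half)
        seconds, peak_gb = benchmark_forward(attn, x)
        print(f"seq_len={seq_len:>8d}  time={seconds * 1000:8.2f} ms  peak_mem={peak_gb:6.2f} GB")
```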

MHarris021 commented 1 year ago

@fkodom yes, we were able to successfully verify your results. We used an RTX A5000 with 24 GB of VRAM; it took us about 10 minutes once we got everything set up. We're now benchmarking against 64M tokens to see whether it fits in the 24 GB of VRAM on the GPU.

MHarris021 commented 1 year ago

64 million tokens were processed successfully! Runtimes increased by a factor of 2, if I modified the benchmark correctly. Now, processing in parallel (multiple GPUs) will need to be tested. [attached plot: benchmark, 64M tokens]

fkodom commented 1 year ago

@MHarris021 good to hear 👍

MHarris021 commented 1 year ago

@fkodom We've been able to expand to 288 million (1.125 × 2^28) tokens on a single NVIDIA A100 (80 GB VRAM) on Runpod. This is the maximum number of tokens that seems to fit into GPU memory. We are noticing a significant increase in the base processing time every time the number of tokens doubles, and the note you made about the smaller dilated sequence lengths taking longer to process also becomes more pronounced. I'm assuming this is expected, given the processing time per token.

We estimate that we would need to distribute the benchmark across four A100s (320 GB of VRAM) to process the 1 billion context tokens, which is in line with your earlier prediction. I suspect the distributed runtime will stay similar to the 288 million-token case: since the processing should proceed in parallel, the total time should stay roughly constant even though the token count goes up by 4x, aside from some coordination overhead. This also suggests that we may be able to get similar results with eight GPUs of 40 GB VRAM each, and that GPU memory may be the primary bottleneck.

Benchmarking data follows (attached plots):

- benchmark-64M-tokens-2023-15-08
- benchmark-128M-tokens-2023-15-08
- benchmark-256M-tokens-2023-15-08
- benchmark-288M-tokens-2023-15-08
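To make the parallel-processing estimate above concrete: since the benchmark is forward-only, the simplest multi-GPU version just gives each rank its own slice of the total token budget and runs the same single-GPU measurement independently. A hedged sketch of that launch scaffolding follows, assuming a `torchrun` launch; `run_single_gpu_benchmark` is a hypothetical placeholder for the per-GPU timing loop, and this embarrassingly-parallel split ignores any cross-GPU attention the full distributed LongNet algorithm would require:

```python
# Launch with: torchrun --nproc_per_node=4 parallel_benchmark.py
import os

import torch
import torch.distributed as dist


def run_single_gpu_benchmark(num_tokens: int, device: torch.device) -> float:
    """Placeholder: in a real run, this would build the inputs on `device` and time
    the dilated-attention forward pass, as in the single-GPU benchmark above."""
    return float(num_tokens)


def main():
    dist.init_process_group(backend="nccl")  # torchrun provides rank/world-size env vars
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)

    total_tokens = 2**30                          # ~1B tokens overall
    tokens_per_gpu = total_tokens // world_size   # ~268M per GPU when world_size == 4

    result = run_single_gpu_benchmark(tokens_per_gpu, device)

    # Collect per-rank results so rank 0 can report them in one place.
    results = [None] * world_size
    dist.all_gather_object(results, result)
    if rank == 0:
        print(f"{world_size} GPUs x {tokens_per_gpu} tokens each: {results}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```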

fkodom commented 1 year ago

@MHarris021 Closing this issue for now. Feel free to open another, if you have questions about a particular part of the code.