fkodom / dilated-attention-pytorch

(Unofficial) Implementation of dilated attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens" (https://arxiv.org/abs/2307.02486)
MIT License

Backward pass #6

Open Coluding opened 12 months ago

Coluding commented 12 months ago

Hi!

First of all, thanks for your great implementation; I like it a lot. I was wondering whether you have also implemented a backward pass for the model somewhere, since this repo only seems to show the forward pass (please correct me if I am wrong). I am asking because I want to train a reversible dilated encoder model from scratch, and your code seems very well suited for the attention mechanism.

Thanks in advance and kind regards!

fkodom commented 12 months ago

@Coluding Not sure I understand. Could you elaborate a bit?

The forward pass is implemented here, and the backward pass is handled automatically by PyTorch's autograd. Are you thinking there is a more efficient way to perform the backward pass? If so, it could make sense to implement it manually here, too. Or is there some other reason I'm overlooking?
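
Roughly, what I mean is something like the sketch below, where PyTorch's autograd produces the backward pass. The constructor arguments, import path, and (batch, seq_len, heads, head_dim) tensor layout here are illustrative rather than exact, so check the README for the real signatures:

```python
import torch

# Illustrative usage -- the constructor arguments and tensor layout below are
# a sketch; see the README for the exact API (the real module may also expect
# CUDA / fp16 inputs depending on the attention backend).
from dilated_attention_pytorch import DilatedAttention

attn = DilatedAttention(segment_lengths=[2048, 4096], dilation_rates=[1, 2])

# Assumed input layout: (batch, seq_len, num_heads, head_dim)
q = torch.randn(1, 4096, 8, 64, requires_grad=True)
k = torch.randn(1, 4096, 8, 64, requires_grad=True)
v = torch.randn(1, 4096, 8, 64, requires_grad=True)

out = attn(q, k, v)   # forward pass (implemented in this repo)
out.sum().backward()  # backward pass is generated automatically by autograd

print(q.grad.shape)   # gradients propagate back to the inputs
```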

Coluding commented 12 months ago

Hi @fkodom !

I was just wondering whether you have tested the backward pass, and what implications it has for memory usage and maximum sequence length.

Best regards!

fkodom commented 11 months ago

@Coluding Yes, the backward pass works and scales roughly the same as the forward pass (linear with sequence length). You can test that with a slightly modified benchmark.py script:

[Screenshots of the modified benchmark.py script and its output, 2023-11-29]
INFO:root:Benchmark dilated attention...
INFO:root:Sequence length 4096: (5.918e-02 ± 2.241e-04) s
INFO:root:Sequence length 8192: (5.902e-02 ± 3.026e-05) s
INFO:root:Sequence length 16384: (5.900e-02 ± 3.374e-05) s
INFO:root:Sequence length 32768: (5.903e-02 ± 3.294e-05) s
INFO:root:Sequence length 65536: (5.905e-02 ± 3.805e-05) s
INFO:root:Sequence length 131072: (5.898e-02 ± 3.360e-05) s
INFO:root:Sequence length 262144: (5.897e-02 ± 2.324e-05) s
INFO:root:Sequence length 524288: (5.896e-02 ± 2.756e-05) s

^^ In that script, I dynamically choose the batch size so that the total number of tokens is constant across all sequence lengths. So a roughly constant runtime, as above, means the forward/backward pass scales linearly with sequence length.
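
For anyone who wants to reproduce this: the modification boils down to fixing a total token budget, deriving the batch size from it, and adding a backward call after the forward. A rough sketch (not the actual benchmark.py; the attention module and the timing loop itself are left out):

```python
import torch

TOTAL_TOKENS = 2 ** 19  # constant token budget shared by all sequence lengths

def forward_backward(attn: torch.nn.Module, seq_len: int, heads: int = 8, dim: int = 64):
    # Choose the batch size so that batch_size * seq_len stays constant.
    batch_size = TOTAL_TOKENS // seq_len
    q = torch.randn(batch_size, seq_len, heads, dim, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    out = attn(q, k, v)       # forward pass
    out.sum().backward()      # the extra step compared to the stock benchmark
    torch.cuda.synchronize()  # make sure timing covers the full forward/backward
```

Timing forward_backward for each sequence length (e.g. with timeit) should give roughly constant numbers like the log above whenever the pass is linear in sequence length.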

I haven't explicitly profiled memory usage, but AFAIK it should scale the same as the forward pass as well.
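
A quick way to check the memory side would be PyTorch's peak-memory counters, along these lines (again a sketch, with the same caveats about the exact module API):

```python
import torch
from dilated_attention_pytorch import DilatedAttention  # adjust to the actual import path

attn = DilatedAttention(segment_lengths=[2048, 4096], dilation_rates=[1, 2])

for seq_len in (4096, 8192, 16384, 32768):
    torch.cuda.reset_peak_memory_stats()

    # Keep the total token count fixed, as in the runtime benchmark above.
    batch_size = (2 ** 19) // seq_len
    q = torch.randn(batch_size, seq_len, 8, 64, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    out = attn(q, k, v)
    out.sum().backward()

    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print(f"seq_len={seq_len}: peak CUDA memory {peak_gb:.2f} GB")
```

If the peak stays roughly flat across sequence lengths at a fixed token budget, memory scales linearly with sequence length, just like the runtime.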