Coluding opened this issue 12 months ago
@Coluding Not sure I understand. Could you elaborate a bit?
The forward pass is implemented here, and the backward pass is handled automatically by PyTorch's autograd. Are you thinking there is a more efficient way to perform the backward pass? In that case, it could make sense to implement it manually here, too. Or is there some other reason that I'm overlooking?
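To illustrate what "automatically" means here, below is a minimal sketch: calling `.backward()` on any scalar derived from the output makes autograd produce gradients through the forward implementation. The import path, constructor arguments, and the CUDA/float16 setup are assumptions for the example, not a definitive usage guide; check the README for the actual signature.

```python
import torch
from dilated_attention_pytorch import DilatedAttention  # import path assumed

# Segment lengths / dilation rates chosen arbitrarily for this example.
attn = DilatedAttention(segment_lengths=[1024, 2048], dilation_rates=[1, 2])

b, n, h, d = 1, 4096, 8, 64  # (batch, seq_length, heads, head_dim)
q = torch.randn(b, n, h, d, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(b, n, h, d, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(b, n, h, d, device="cuda", dtype=torch.float16, requires_grad=True)

out = attn(q, k, v)   # forward pass implemented in this repo
out.sum().backward()  # backward pass generated automatically by autograd
print(q.grad.shape)   # gradients flow back to q, k, v
```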
Hi @fkodom !
I was just wondering if you have tested the backward pass, and what implications it has for memory usage and sequence length.
Best regards!
@Coluding Yes, the backward pass works and scales roughly the same as the forward pass (linear with sequence length). You can test that with a slightly modified benchmark.py script:
INFO:root:Benchmark dilated attention...
INFO:root:Sequence length 4096: (5.918e-02 ± 2.241e-04) s
INFO:root:Sequence length 8192: (5.902e-02 ± 3.026e-05) s
INFO:root:Sequence length 16384: (5.900e-02 ± 3.374e-05) s
INFO:root:Sequence length 32768: (5.903e-02 ± 3.294e-05) s
INFO:root:Sequence length 65536: (5.905e-02 ± 3.805e-05) s
INFO:root:Sequence length 131072: (5.898e-02 ± 3.360e-05) s
INFO:root:Sequence length 262144: (5.897e-02 ± 2.324e-05) s
INFO:root:Sequence length 524288: (5.896e-02 ± 2.756e-05) s
^^ In that script, I dynamically choose the batch size so that the total number of tokens is the same for all sequence lengths. A roughly constant runtime therefore indicates that the forward/backward pass scales linearly with sequence length, which is what the numbers above show.
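For reference, the modification amounts to timing forward + backward together under a fixed token budget. A rough sketch of that idea is below; it is not the actual benchmark.py, and the module configuration, dtype, and token budget are assumptions for illustration.

```python
import torch
from torch.utils import benchmark
from dilated_attention_pytorch import DilatedAttention  # import path assumed

TOTAL_TOKENS = 2**17  # constant token budget across all sequence lengths
HEADS, HEAD_DIM = 8, 64

def fwd_bwd(attn, q, k, v):
    out = attn(q, k, v)
    out.sum().backward()  # include the backward pass in the timing

# Example configuration; the real benchmark may use different segment lengths.
attn = DilatedAttention(segment_lengths=[2048, 4096], dilation_rates=[1, 2])

for seq_length in [4096, 8192, 16384, 32768]:
    batch_size = TOTAL_TOKENS // seq_length  # keep batch * seq_length constant
    q = torch.randn(batch_size, seq_length, HEADS, HEAD_DIM,
                    device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    timer = benchmark.Timer(
        stmt="fwd_bwd(attn, q, k, v)",
        globals={"fwd_bwd": fwd_bwd, "attn": attn, "q": q, "k": k, "v": v},
    )
    print(seq_length, timer.timeit(10))
```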
I haven't explicitly profiled memory, but AFAIK it should scale the same as the forward pass as well.
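If you want to verify the memory side yourself, a quick check (not something included in this repo) is to record peak CUDA memory around one forward + backward pass at each sequence length, again with a constant token budget:

```python
import torch

def peak_memory_gb(attn, q, k, v) -> float:
    """Peak CUDA memory (GB) for one forward + backward pass."""
    torch.cuda.reset_peak_memory_stats()
    out = attn(q, k, v)
    out.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9

# Usage: call this inside the benchmark loop above; roughly constant peaks
# across sequence lengths would confirm that memory also scales linearly.
```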
Hi!
First of all, thanks for your great implementation. I think it is awesome and I like it a lot. I was wondering if you have also implemented a backward pass for the model somewhere, since you have only shown the forward pass in this repo (please correct me if I am wrong). The reason I am asking is that I want to train a reversible dilated encoder model from scratch, and your code seems very well suited for the attention mechanism.
Thanks in advance and kind regards!