fkodom / dilated-attention-pytorch

(Unofficial) Implementation of dilated attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens" (https://arxiv.org/abs/2307.02486)
MIT License

Q: Attention Calculation #5

Open mohamedelbahnasawi opened 1 year ago

mohamedelbahnasawi commented 1 year ago

Hi @fkodom,

I really like your implementation, and I wanted to use dilated attention in a vanilla transformer model to see how it works.

Right now, I am facing a problem with the attention calculation. You use flash attention, which does not provide a way to pass a padding mask. For scaled dot-product attention, I am not sure whether the masks should also be segmented and sparsified. Do you have an idea how to calculate the attention with scaled dot product while taking into account the padding mask for the encoder, and the padding and causal masks together for the decoder?

Thanks for your help!

Coluding commented 11 months ago

Hi @mohamedelbahnasawi,

have you made any progress on this task? I am planning to implement a reversible dilated Encoder model. Let me know if you wanna collaborate!

Kind regards!

mohamedelbahnasawi commented 11 months ago

Hi @Coluding,

Unfortunately, I stopped until I could find a solution. I tried several times, but there were always problems with the masking. It would certainly be nice to collaborate on this implementation and try to solve this problem.

I also hope @fkodom can give us some tips on how to solve the problem with scaled dot product instead of flash attention.

Best regards, Mohamed

fkodom commented 11 months ago

@mohamedelbahnasawi Are you trying to build an encoder-only model (e.g. BERT)? I don't have an immediate solution for the padding masks in that case, but I can look into it further. For causal decoder-only models (e.g. GPT), it likely doesn't matter -- you only need to apply masking to the loss function.
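A minimal sketch of the loss-masking idea for a causal decoder, assuming PyTorch and right-padded batches. The function name `masked_lm_loss` and the `pad_id` value are hypothetical and not part of this repo; the only library feature relied on is the `ignore_index` argument of `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F


def masked_lm_loss(logits: torch.Tensor, targets: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """Cross-entropy that skips padded target positions.

    logits:  (batch, seq_len, vocab_size)
    targets: (batch, seq_len), with `pad_id` marking padding (assumed convention)
    """
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # flatten to (batch * seq_len, vocab)
        targets.reshape(-1),                  # flatten to (batch * seq_len,)
        ignore_index=pad_id,                  # padded positions contribute no loss
    )
```

With right padding and a causal mask, real tokens only attend to earlier positions, so padded positions should never feed into them. That is why masking the loss alone may be enough in the decoder-only case.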

mohamedelbahnasawi commented 11 months ago

Hi @fkodom,

Thank you for replying. I am actually trying to build an encoder-decoder model, just like the vanilla transformer architecture but with dilated attention.

fkodom commented 11 months ago

@mohamedelbahnasawi Got it -- I'll take a look. I believe the fix should be here. xops also allows you to pass a Tensor mask, in place of the LowerTriangularMask I was lazily using. So we'll need to add an attn_mask: Optional[Tensor] = None argument to forward() for each module.

Peeking at the xformers docstring:

[Screenshot of the xformers docstring, 2023-11-29]

So it may be slower, but more helpful for your use case.
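A rough sketch of what passing a Tensor mask could look like, under several assumptions. It swaps in PyTorch's `F.scaled_dot_product_attention` (PyTorch 2.0 or later) instead of the xformers call, it handles a single segment length and dilation rate only, and it assumes `seq_len` is divisible by `segment_length`. The helper names `segment_and_dilate` and `masked_dilated_attention` are hypothetical and are not this repo's API. The idea is that the padding mask is segmented and sparsified exactly like the keys, so each kept key still lines up with its mask entry.

```python
import torch
import torch.nn.functional as F


def segment_and_dilate(x, segment_length, dilation, offset=0):
    # (batch, seq_len, ...) -> (batch, n_segments, kept_per_segment, ...)
    b, n = x.shape[:2]
    x = x.reshape(b, n // segment_length, segment_length, *x.shape[2:])
    return x[:, :, offset::dilation]


def masked_dilated_attention(q, k, v, key_padding_mask,
                             segment_length, dilation, causal=False):
    """q, k, v: (batch, seq_len, heads, head_dim).
    key_padding_mask: (batch, seq_len) bool, True = real token (assumed convention).
    Returns the dilated output, shape (batch, n_segments, kept, heads, head_dim).
    """
    q, k, v = (segment_and_dilate(t, segment_length, dilation) for t in (q, k, v))
    # Apply the same segmenting and dilation to the mask as to the keys.
    keep = segment_and_dilate(key_padding_mask, segment_length, dilation)
    b, g, m, h, d = q.shape
    # Fold segments into the batch dimension: (batch * n_segments, heads, kept, head_dim)
    q, k, v = (t.permute(0, 1, 3, 2, 4).reshape(b * g, h, m, d) for t in (q, k, v))
    mask = keep.reshape(b * g, 1, 1, m)  # broadcasts over heads and queries
    if causal:
        tril = torch.tril(torch.ones(m, m, dtype=torch.bool, device=q.device))
        mask = mask & tril  # (batch * n_segments, 1, kept, kept)
    # Rows with no visible key would give NaN in the softmax. Such rows only
    # occur for padded queries, so let them see every key and discard them later.
    mask = mask | ~mask.any(dim=-1, keepdim=True)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.reshape(b, g, h, m, d).permute(0, 1, 3, 2, 4)
```

This does not scatter the outputs back to their original positions or combine several segment lengths and dilation rates, which the full method would still need. Outputs at padded query positions are not meaningful and should be ignored downstream.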