lucidrains / local-attention

An implementation of local windowed attention for language modeling
MIT License

More control over attention masking #9

Open · Mindful opened this issue 3 years ago

Mindful commented 3 years ago

First of all, thanks for creating this - as far as I can tell, it's the only PyTorch local attention implementation that isn't a one-off created as part of a specific model.

My issue is that I'm using local attention in a model where I'd like to do XLNet-style pre-training, i.e. masking random characters when predicting a specific character. It seems like the attention mask is intended to be of shape [batch_size, sequence_length] and used for excluding padding, while masking for autoregressive decoding is handled by the causal flag. I'm wondering if it would be possible to support arbitrary per-token attention masks the way PyTorch's Transformer model does with src_mask, which has shape [sequence_length, sequence_length]. Alternatively, if there's a way to do this with the current implementation and I've just missed it, please let me know.

TL;DR: as far as I can tell, the current attention mask parameter is equivalent to the PyTorch Transformer's src_key_padding_mask. Would it be possible to get one that is equivalent to src_mask?
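For anyone hitting the same confusion, here's a minimal sketch of the distinction using PyTorch's nn.MultiheadAttention (not this library's API): key_padding_mask is per batch element and per key, while attn_mask (what nn.Transformer calls src_mask) is a full [sequence_length, sequence_length] pattern. The boolean convention (True = masked out) follows the PyTorch docs; the sizes are just placeholder values.

```python
import torch
import torch.nn as nn

batch, seq, dim, heads = 2, 8, 32, 4
mha = nn.MultiheadAttention(dim, heads, batch_first=True)
x = torch.randn(batch, seq, dim)

# key_padding_mask: shape [batch, seq]; True marks padding keys to be ignored
key_padding_mask = torch.zeros(batch, seq, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # pretend the last two tokens are padding

# attn_mask (the src_mask equivalent): shape [seq, seq]; True blocks that query->key pair
attn_mask = torch.rand(seq, seq) < 0.15  # arbitrary per-position masking, XLNet-style

out, _ = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 32])
```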

This is the kind of thing I'd like to just open a PR for, but in this case I'm not confident I know how best to implement this.
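For reference, here's a rough, hypothetical sketch of the kind of adaptation I have in mind: slicing a full [n, n] mask into per-bucket blocks that line up with windowed attention scores. This is not how local-attention is implemented internally; window_full_mask, the look_backward parameter, and the True-means-attend convention are all assumptions for illustration.

```python
import torch

def window_full_mask(full_mask, window_size, look_backward=1):
    # full_mask: bool [n, n], True = query (row) may attend to key (column) -- assumed convention
    n = full_mask.size(0)
    assert n % window_size == 0, 'sequence length must be divisible by the window size'
    num_buckets = n // window_size
    k_len = (look_backward + 1) * window_size  # keys visible to each query bucket
    blocks = []
    for b in range(num_buckets):
        q_start = b * window_size
        k_start = max(0, (b - look_backward) * window_size)
        k_end = (b + 1) * window_size
        block = full_mask[q_start:q_start + window_size, k_start:k_end]
        if block.size(1) < k_len:  # left-pad early buckets that have fewer look-back keys
            pad = torch.zeros(window_size, k_len - block.size(1), dtype=torch.bool)
            block = torch.cat((pad, block), dim=-1)
        blocks.append(block)
    # [num_buckets, window_size, k_len], aligned with bucketed attention logits
    return torch.stack(blocks)

# e.g. a random XLNet-style mask over a length-16 sequence, window size 4
full_mask = torch.rand(16, 16) > 0.15
windowed = window_full_mask(full_mask, window_size=4)  # shape [4, 4, 8]
```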

Mindful commented 3 years ago

So I got hold of the authors of the paper I was trying to replicate, and it turns out they didn't actually do attention masking the way XLNet does, which means I don't need this after all (sorry!). Feel free to close the issue if you'd like; I'll leave it open for now in case someone else ends up needing it.