lucidrains / local-attention

An implementation of local windowed attention for language modeling
MIT License

Bug in exact_window_size masking for Causal Attention? #5

Closed · xravitejax closed this 3 years ago

xravitejax commented 3 years ago

https://github.com/lucidrains/local-attention/blob/aa7e1b2c94b9f9705f61125ef518956b5621296c/local_attention/local_attention.py#L150

Shouldn't we do an 'or' operation between causal_mask and exact_window_size_mask, instead of an 'and'?
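
For reference, a minimal, hypothetical reconstruction of the logic around that line (variable names and the exact window bound are illustrative, not copied from local_attention.py; True is assumed to mean "masked out"):

```python
import torch

# Illustrative query/key positions within one attention bucket.
q_pos = torch.arange(8)[:, None]  # query positions, shape (8, 1)
k_pos = torch.arange(8)[None, :]  # key positions, shape (1, 8)
max_causal_window = 4             # window_size * look_backward

causal_mask = q_pos < k_pos                         # hide future keys
too_far_back = (q_pos - k_pos) > max_causal_window  # hide keys beyond the window

# The line in question combines the two masks before masked_fill_:
# with `&` a key is masked only if BOTH conditions hold,
# with `|` it is masked if EITHER condition holds.
combined = causal_mask & too_far_back  # the combination as currently written
```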

lucidrains commented 3 years ago

@xravitejax Hello Ravi! I think it is correct, because what we are trying to do is mask out attention to any key that is further than max_causal_window_size into the past.

xravitejax commented 3 years ago

Let's say look_backward = 1, window size = 4, and we are currently processing the first token of the last window.

The causal mask would look similar to this, since we don't want to look into the future:
[False, False, False, False, False, True, True, True]

The exact window size mask would look similar to this:
[True, False, False, False, False, False, False, False]

Causal mask & exact window size mask:
[False, False, False, False, False, False, False, False]

Causal mask | exact window size mask:
[True, False, False, False, False, True, True, True]

Shouldn't we use the 'or' operation, so that everything except the last four tokens (including the input token itself) gets masked? I hope I remember the code correctly.
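
As a sanity check, here is the worked example above in code (a sketch using the boolean lists exactly as written, with True meaning "masked out"):

```python
import torch

# Masks for the first token of the last window, as listed above.
causal = torch.tensor([False, False, False, False, False, True, True, True])
exact_window = torch.tensor([True, False, False, False, False, False, False, False])

# `&` masks nothing at all, so the token could even attend to the future:
print(causal & exact_window)
# tensor([False, False, False, False, False, False, False, False])

# `|` masks both the too-distant past and the future, leaving only the
# intended window of positions visible:
print(causal | exact_window)
# tensor([ True, False, False, False, False,  True,  True,  True])
```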

lucidrains commented 3 years ago

@xravitejax oh crap, you are correct, thank you!

https://github.com/lucidrains/local-attention/releases/tag/1.2.2