Closed xravitejax closed 3 years ago
@xravitejax Hello Ravi! I think it is correct, because what we are trying to do is mask out attention to any keys that are more than max_causal_window_size into the past.
Let's say look_backward = 1, window size = 4, and we are currently processing the first token of the last window.
The causal mask would look similar to this, since we don't want to look into the future:
[False, False, False, False, False, True, True, True]
The exact window size mask would look similar to this:
[True, False, False, False, False, False, False, False]
Causal mask & Exact Window Size mask: [False, False, False, False, False, False, False, False]
Causal mask | Exact Window Size mask: [True, False, False, False, False, True, True, True]
Shouldn't we use the or operation, so that only the last four tokens (including the input) are left unmasked? I hope I remember the code correctly.
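To make the example above easy to check, here is a minimal plain-Python sketch that rebuilds both masks for that single query position and combines them both ways. It is not the library's code: the variable names, the use of lists instead of tensors, and the `>=` boundary on the window test are assumptions chosen only to reproduce the masks written out above (True means the key is masked out).

```python
# Hypothetical reconstruction of the example; not taken from local_attention.py.
window_size = 4
look_backward = 1

# The query's own window plus look_backward windows of keys before it.
num_keys = window_size * (look_backward + 1)   # 8 keys
query_pos = window_size * look_backward        # first token of the last window

# True = masked out. Keys after the query are in the future.
causal_mask = [k > query_pos for k in range(num_keys)]

# True = masked out. Keys window_size or more positions behind the query
# are too far into the past (assumed boundary, matches the example above).
exact_window_mask = [query_pos - k >= window_size for k in range(num_keys)]

and_mask = [c and w for c, w in zip(causal_mask, exact_window_mask)]
or_mask = [c or w for c, w in zip(causal_mask, exact_window_mask)]

# Keys that remain visible under each combination.
visible_with_and = [k for k, m in enumerate(and_mask) if not m]
visible_with_or = [k for k, m in enumerate(or_mask) if not m]

print("and:", and_mask, "visible keys:", visible_with_and)
print("or: ", or_mask, "visible keys:", visible_with_or)
```

With `and`, nothing is masked and all eight keys stay visible, including the future ones. With `or`, only keys 1 through 4 stay visible, which is the query plus the three tokens before it.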
@xravitejax oh crap, you are correct, thank you! https://github.com/lucidrains/local-attention/releases/tag/1.2.2
https://github.com/lucidrains/local-attention/blob/aa7e1b2c94b9f9705f61125ef518956b5621296c/local_attention/local_attention.py#L150
Shouldn't we do an 'or' operation between causal_mask and exact_window_size_mask, instead of an 'and'?