Closed: usamec closed this issue 3 years ago
```python
exact_mask = torch.abs(bq_t[:, :, :, None] - bq_k[:, :, None, :]) > self.window_size
dots.masked_fill_(exact_mask, mask_value)
```
This could be hidden under a flag similarly to shared_qk.
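For context, here is a minimal, self-contained sketch of the idea, not the library's actual implementation: the exact-window mask from the snippet above is factored into a helper and applied to the attention logits. The function name, tensor shapes, and the `mask_value` default are assumptions for illustration only.

```python
import torch

def exact_window_mask_(dots, q_pos, k_pos, window_size, mask_value=-1e9):
    """In-place: mask attention logits wherever the key position is more than
    `window_size` away from the query position, making the attention window
    exact rather than bucket-aligned."""
    # q_pos: (..., q_len) absolute positions of the queries
    # k_pos: (..., k_len) absolute positions of the gathered keys
    # dots:  (..., q_len, k_len) attention logits
    mask = torch.abs(q_pos[..., :, None] - k_pos[..., None, :]) > window_size
    dots.masked_fill_(mask, mask_value)
    return dots

# toy usage: one bucket of 4 queries attending over 8 keys, window_size = 2
q_pos = torch.arange(4, 8).view(1, 4)
k_pos = torch.arange(0, 8).view(1, 8)
dots = torch.zeros(1, 4, 8)
exact_window_mask_(dots, q_pos, k_pos, window_size=2)
```

In the library this would presumably be gated behind a constructor flag (analogous to `shared_qk`), so the default bucket-aligned behavior stays unchanged.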
@usamec Hi Vlado! This is a great suggestion! Building now
https://github.com/lucidrains/local-attention/releases/tag/1.2.0