I think the solution is quite simple - just add this one line. Note that I've only considered this for tying MSA self attention.
```python
if exists(tie_attn_dim):
    # q, k, v: ((b r), h, n, d) -> (b, r, h, n, d), where r = number of rows tied together
    q, k, v = map(lambda t: rearrange(t, '(b r) h n d -> b r h n d', r = tie_attn_dim), (q, k, v))

    if exists(mask):
        mask = rearrange(mask, '(b r) n -> b r n', r = tie_attn_dim)
        # THE NEW LINE - zero out the queries of padded rows so they
        # contribute nothing to the tied attention logits
        q *= rearrange(mask, 'b r n -> b r () n ()')
        # number of rows per batch element with at least one valid token
        num_rows = (mask.sum(dim = -1) > 0).sum(dim = -1)
        num_rows = rearrange(num_rows, 'b -> b () () ()')
        # a position is valid if it is valid in any of the tied rows
        mask = mask.sum(dim = 1) > 0
```
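For anyone landing here later, a minimal sketch of how this slots into tied row attention. The function name, the einsum shapes, and the `clamp` guard are my own assumptions for illustration, not the repo's exact code:

```python
import torch
from einops import rearrange

def tied_row_attention(q, k, v, mask, tie_attn_dim):
    # q, k, v: ((b r), h, n, d); mask: ((b r), n), True = valid token
    q, k, v = map(lambda t: rearrange(t, '(b r) h n d -> b r h n d', r = tie_attn_dim), (q, k, v))
    mask = rearrange(mask, '(b r) n -> b r n', r = tie_attn_dim)

    # zero the queries of padded rows so they add nothing to the tied logits
    q = q * rearrange(mask, 'b r n -> b r () n ()')

    # count rows with at least one valid token (clamp guards all-padding elements)
    num_rows = (mask.sum(dim = -1) > 0).sum(dim = -1).clamp(min = 1)
    num_rows = rearrange(num_rows, 'b -> b () () ()')

    # tie attention across rows: sum logits over r, average by the valid row count
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b r h i d, b r h j d -> b h i j', q * scale, k) / num_rows

    # a key position is attendable if it is valid in any of the tied rows
    pooled_mask = mask.any(dim = 1)
    sim = sim.masked_fill(~rearrange(pooled_mask, 'b j -> b () () j'), torch.finfo(sim.dtype).min)

    attn = sim.softmax(dim = -1)
    out = torch.einsum('b h i j, b r h j d -> b r h i d', attn, v)
    return rearrange(out, 'b r h n d -> (b r) h n d')
```

Dividing by `num_rows` rather than `tie_attn_dim` is what keeps padded rows from diluting the averaged logits.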
@aced125 got it! thanks for letting me know that any(dim) works (i had thought, going by the docs, that it didn't support axis-wise reduction)
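For reference, a quick check that `Tensor.any` does reduce axis-wise, matching the `sum(dim=...) > 0` idiom in the snippet above:

```python
import torch

mask = torch.tensor([[True, False],
                     [False, False]])

print(mask.sum(dim = 1) > 0)  # tensor([ True, False])
print(mask.any(dim = 1))      # tensor([ True, False])
```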
also fixed the zeroing of the contributions where rows are padding in a batch 🙏
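A toy check of that fix (shapes and values are made up for illustration): without the new line, a padding row's queries still leak into the tied logits; with it, the logits match those computed over the valid rows only.

```python
import torch
from einops import rearrange

b, r, h, n, d = 1, 3, 1, 4, 8
q, k = torch.randn(b, r, h, n, d), torch.randn(b, r, h, n, d)

# pretend the last row of the alignment is pure padding
mask = torch.ones(b, r, n, dtype = torch.bool)
mask[:, -1] = False

sim_leaky = torch.einsum('b r h i d, b r h j d -> b h i j', q, k)

q_zeroed = q * rearrange(mask, 'b r n -> b r () n ()')
sim_clean = torch.einsum('b r h i d, b r h j d -> b h i j', q_zeroed, k)
sim_valid = torch.einsum('b r h i d, b r h j d -> b h i j', q[:, :-1], k[:, :-1])

print(torch.allclose(sim_leaky, sim_clean))  # False: the padding row leaked in
print(torch.allclose(sim_clean, sim_valid))  # True: zeroing removed it
```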