lucidrains / alphafold2

To eventually become an unofficial Pytorch implementation / replication of Alphafold2, as details of the architecture get released

allow for tied row attention with uneven number of MSAs per sequence … #62

Closed: lucidrains closed this issue 3 years ago

lucidrains commented 3 years ago

…in a batch
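For concreteness, a minimal sketch of the setup being asked for (all names and sizes below are illustrative, not part of the repo's API): each sequence in a batch can align to a different number of MSA rows, so the MSAs are padded to a common row count and a boolean mask records which rows and positions are real.

    import torch

    # hypothetical example: batch of 2 sequences, padded to max_msa = 4 rows of length 7
    b, max_msa, seq_len, dim = 2, 4, 7, 32
    num_msas = [4, 2]  # the second sequence only has 2 real MSA rows

    msa = torch.zeros(b, max_msa, seq_len, dim)
    mask = torch.zeros(b, max_msa, seq_len, dtype=torch.bool)
    for i, n in enumerate(num_msas):
        msa[i, :n] = torch.randn(n, seq_len, dim)
        mask[i, :n] = True

    # tied row attention flattens (batch, rows) into one batch dimension,
    # which is why the snippet further down regroups with r = tie_attn_dim
    msa_flat = msa.reshape(b * max_msa, seq_len, dim)   # (b*r, n, d)
    mask_flat = mask.reshape(b * max_msa, seq_len)      # (b*r, n)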

aced125 commented 3 years ago

I think the solution is quite simple: just add one extra line (marked THE NEW LINE below). Note that I've only considered this for tied MSA self-attention.

if exists(tie_attn_dim):
    # regroup the flattened (batch * row) dimension so attention can be tied across MSA rows
    q, k, v = map(lambda t: rearrange(t, '(b r) h n d -> b r h n d', r=tie_attn_dim), (q, k, v))
    if exists(mask):
        mask = rearrange(mask, '(b r) n -> b r n', r=tie_attn_dim)

        q *= rearrange(mask, 'b r n -> b r () n ()')  # THE NEW LINE - zeroes queries at padded rows/positions so they contribute nothing to the tied logits

        # number of real (non-padded) rows per batch element, used to average the tied logits
        num_rows = (mask.sum(dim=-1) > 0).sum(dim=-1)
        num_rows = rearrange(num_rows, 'b -> b () () ()')
        # pool the per-row mask into a single key mask per batch element
        mask = mask.sum(dim=1) > 0
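Downstream, a tied-attention step would sum the logits over the row dimension and normalize by the number of real rows rather than by tie_attn_dim. The sketch below is only a guess at that step under the shapes used above; the function name and the exact normalization are assumptions, not the repo's code.

    import torch
    from torch import einsum
    from einops import rearrange

    def tied_row_attention(q, k, v, mask, num_rows):
        # q, k, v: (b, r, h, n, d); mask: (b, n) pooled key mask; num_rows: (b, 1, 1, 1)
        scale = q.shape[-1] ** -0.5

        # sum logits over the tied rows, then average by the number of real rows
        dots = einsum('b r h i d, b r h j d -> b h i j', q, k) * scale
        dots = dots / num_rows.clamp(min=1)

        # mask out padded key positions
        mask_value = -torch.finfo(dots.dtype).max
        dots = dots.masked_fill(~rearrange(mask, 'b j -> b () () j'), mask_value)

        attn = dots.softmax(dim=-1)

        # one shared attention map per batch element, applied to every row's values
        out = einsum('b h i j, b r h j d -> b r h i d', attn, v)
        return rearrange(out, 'b r h n d -> (b r) h n d')

Dividing by num_rows instead of the padded row count is the point of the issue: padded rows should not dilute the tied attention.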
lucidrains commented 3 years ago

@aced125 got it! Thanks for letting me know that any(dim) works (I thought from their docs that it didn't support axis-wise reduction).

Also fixed the zeroing of contributions from rows that are padding 🙏
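For reference, a tiny check of the any(dim) point; the mask shape mirrors the (b, r, n) mask in the snippet above and is purely illustrative.

    import torch

    mask = torch.tensor([[[True, True, False],
                          [False, False, False]]])  # (b=1, r=2, n=3); the second row is padding

    pooled_sum = mask.sum(dim=1) > 0   # as written in the snippet above
    pooled_any = mask.any(dim=1)       # equivalent, using the dim argument to any
    assert torch.equal(pooled_sum, pooled_any)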