Closed CiaoHe closed 3 years ago
Hi, Phil
I checked this part of the code, and I think it adds the residual more often than needed.
Original: https://github.com/lucidrains/alphafold2/blob/a78352805204d47851f3b972e3d5c142ec7a3b49/alphafold2_pytorch/alphafold2.py#L437-L448
I thought it should be:
m = msa_attn(m, mask = msa_mask, pairwise_repr = x)
m = msa_ff(m) + m
x = attn(x, mask = mask, msa_repr = m)
x = ff(x) + x
The reason is that the attention blocks already add the residual to the tensors they return, so adding it again outside applies it twice.
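To make the effect concrete, here is a minimal sketch with plain numbers instead of tensors. `attn_block`, `single_residual`, `double_residual` and the toy function `f` are hypothetical stand-ins, not the repository's modules; the only assumption carried over is that the block adds its own residual before returning.

```python
# Hypothetical stand-in for an attention block that already
# applies the residual connection internally.
def attn_block(x, fn):
    return fn(x) + x  # residual added inside the block


def single_residual(x, fn):
    # Proposed form: rely on the block's internal residual only.
    return attn_block(x, fn)


def double_residual(x, fn):
    # Form being questioned: the caller adds the input again.
    return attn_block(x, fn) + x


def f(v):
    # Toy stand-in for the attention computation.
    return 0.5 * v


x = 2.0
single = single_residual(x, f)  # f(x) + x
double = double_residual(x, f)  # f(x) + 2 * x
```

In this toy case the two forms differ by exactly one extra copy of the input, which is the extra skip-path weight the outer `+ m` / `+ x` would introduce.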
Please take a look.
Best
@CiaoHe indeed! thank you! :pray: