uw-ipd / RoseTTAFold2

MIT License

Is there a redundant addition of positional encoding in the IterBlock section? #28

Open ackbar03 opened 7 months ago

ackbar03 commented 7 months ago

Hi,

I see that an additional positional encoding is added in the forward method of the IterBlock module.

https://github.com/uw-ipd/RoseTTAFold2/blob/main/network/Track_module.py#L544C1-L550C47

def forward(self, msa, pair, R_in, T_in, xyz, state, idx, symmids, symmsub_in, symmsub, symmRs, symmmeta, use_checkpoint=False, topk=0, crop=-1):
    #rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:])) + self.pos(idx)
    O,L = pair.shape[:2]
    xyzfull = xyz.view(1,O*L,3,3)
    rbf_feat = rbf(
        torch.cdist(xyzfull[:,:,1,:], xyzfull[:,:L,1,:])
    ).reshape(O,L,L,-1) + self.pos(idx, O)

This seems redundant, since rbf_feat is subsequently added to pair, which already contains the positional encoding added in the MSA_emb module.
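Here is a toy sketch of the double counting I mean, in plain Python. It is not RoseTTAFold2 code: pos_enc, rbf, run and positional_part are hypothetical stand-ins, pair features are short float lists, and rbf_feat is assumed to be added to pair directly with no learned layer in between. If rbf_feat passes through learned layers before it reaches pair in the real model, the second positional term would be a separately weighted input rather than an exact duplicate.

```python
import math

N_BINS = 9  # feature width shared by both toy encodings (assumption)

def pos_enc(i, j, max_offset=4):
    """Hypothetical 2D positional encoding: one-hot of the clipped offset j - i."""
    offset = max(-max_offset, min(max_offset, j - i))
    vec = [0.0] * (2 * max_offset + 1)
    vec[offset + max_offset] = 1.0
    return vec

def rbf(d, d_min=2.0, d_step=2.0, sigma=2.0):
    """Hypothetical Gaussian radial basis expansion of distance d into N_BINS bins."""
    centers = [d_min + k * d_step for k in range(N_BINS)]
    return [math.exp(-(((d - c) / sigma) ** 2)) for c in centers]

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def run(embed, coords, pos_in_iter=True):
    """Pair features after one embedding step and one iteration step (toy model)."""
    L = len(coords)
    # Embedding step (stand-in for MSA_emb): positional encoding is added once.
    pair = [[add(embed[i][j], pos_enc(i, j)) for j in range(L)] for i in range(L)]
    # Iteration step (stand-in for IterBlock): distance features, optionally plus
    # the positional encoding again, are added straight into pair.
    for i in range(L):
        for j in range(L):
            rbf_feat = rbf(math.dist(coords[i], coords[j]))
            if pos_in_iter:
                rbf_feat = add(rbf_feat, pos_enc(i, j))
            pair[i][j] = add(pair[i][j], rbf_feat)
    return pair

def positional_part(pair, embed, coords, i, j):
    """Subtract the embedding and distance terms, leaving only the positional terms."""
    d = math.dist(coords[i], coords[j])
    return [p - e - r for p, e, r in zip(pair[i][j], embed[i][j], rbf(d))]

coords = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (7.6, 0.0, 0.0)]  # toy CA trace
embed = [[[0.0] * N_BINS for _ in coords] for _ in coords]
for flag in (True, False):
    pair = run(embed, coords, pos_in_iter=flag)
    print(flag, [round(x, 6) for x in positional_part(pair, embed, coords, 0, 2)])
# True [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0]
# False [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```

With pos_in_iter=True the bin for offset +2 carries weight 2, i.e. the positional encoding is counted twice; with pos_in_iter=False, which is how I read the RFdiffusion version, it is counted once.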

I know the model has already been trained with this in place, so it probably doesn't matter, but I just wanted to check whether my interpretation is correct. I also noticed that the corresponding code in RFdiffusion has been modified so that no additional positional encoding is added there.

Thanks