wwang2 / Coarse-Graining-Auto-encoders


Code for force matching term #2

Open yihengwuKP opened 1 year ago

yihengwuKP commented 1 year ago

Hi, I can't seem to find the code for the average instantaneous mean force regularization term, either in cgae.py or in the Colab example. The loss computed there is:

        # recenter xyz 
        xyz = xyz[0].to(device)

        # encode and decode coordinates 
        xyz, xyz_recon, M, cg_xyz = ae(xyz, tau)

        # lift the cg_xyz back to the FG space 
        X_lift = torch.einsum('bij,ni->bnj', cg_xyz, M)

        # compute regularization to penalize atoms that are assigned too far away 
        loss_reg = (xyz - X_lift).pow(2).sum(-1).mean()

        # comput reconstruction 
        loss_recon = (xyz - xyz_recon).pow(2).mean() 

        # compute bond loss 
        bond_true = (xyz[:, bond_idx[:,0]] - xyz[:, bond_idx[:,1]] + 1e-9).pow(2).sum(-1).sqrt()
        bond_prd = (xyz_recon[:, bond_idx[:,0]] - xyz_recon[:, bond_idx[:,1]] + 1e-9).pow(2).sum(-1).sqrt()
        loss_bond = (bond_true - bond_prd).pow(2).mean()

        # total loss 
        loss = loss_recon + 0.5 * loss_reg + loss_bond

Is it possible for you to share that part of the code at your convenience?
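For reference while waiting for the authors' code, here is a minimal sketch of what such a term could look like. It is not the repository's implementation. It rests on three assumptions: (1) per-frame atomistic forces are available with shape (batch, n_atoms, 3), which the data loader in the snippet above may not provide; (2) `M` is the (n_atoms, n_cg) assignment matrix used in the `X_lift` einsum; and (3) the regularizer is the batch mean of the squared norm of the CG instantaneous force. The names `cg_instantaneous_force` and `mean_force_reg` are hypothetical.

```python
import torch


def cg_instantaneous_force(forces: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Map atomistic forces (batch, n_atoms, 3) to CG forces (batch, n_cg, 3).

    Assumption: M has shape (n_atoms, n_cg), as in the X_lift einsum, and each
    row (ideally one-hot) gives the bead assignment of one atom. Summing the
    forces of the atoms assigned to each bead is the usual force mapping for a
    hard, non-overlapping assignment; for a soft M it is only an approximation.
    """
    # out[b, i, j] = sum_n forces[b, n, j] * M[n, i]
    return torch.einsum('bnj,ni->bij', forces, M)


def mean_force_reg(forces: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Hypothetical regularizer: mean over frames and beads of |F_cg|^2."""
    f_cg = cg_instantaneous_force(forces, M)
    # squared norm per bead, then average, mirroring how loss_reg is reduced
    return f_cg.pow(2).sum(-1).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    n_atoms, n_cg, batch = 6, 2, 4
    # hard assignment: atoms 0-2 -> bead 0, atoms 3-5 -> bead 1
    M = torch.zeros(n_atoms, n_cg)
    M[:3, 0] = 1.0
    M[3:, 1] = 1.0
    forces = torch.randn(batch, n_atoms, 3)
    f_cg = cg_instantaneous_force(forces, M)
    print(f_cg.shape)  # torch.Size([4, 2, 3])
    print(torch.allclose(f_cg[:, 0], forces[:, :3].sum(1)))  # True
    print(mean_force_reg(forces, M).item())
```

In the training loop above it could be added as `loss = loss_recon + 0.5 * loss_reg + loss_bond + rho * mean_force_reg(forces, M)`, where `forces` and the weight `rho` are hypothetical. Forces are summed per bead rather than averaged (as the coordinates are), which is the standard choice in force matching for linear mappings whose per-bead weights sum to 1.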