Open yihengwuKP opened 1 year ago
Hi, it seems that I can't find the code for the average instantaneous mean force regularization term in the cgae.py or in the colab example:
cgae.py
    # recenter xyz
    xyz = xyz[0].to(device)

    # encode and decode coordinates
    xyz, xyz_recon, M, cg_xyz = ae(xyz, tau)

    # lift the cg_xyz back to the FG space
    X_lift = torch.einsum('bij,ni->bnj', cg_xyz, M)

    # compute regularization to penalize atoms that are assigned too far away
    loss_reg = (xyz - X_lift).pow(2).sum(-1).mean()

    # comput reconstruction
    loss_recon = (xyz - xyz_recon).pow(2).mean()

    # compute bond loss
    bond_true = (xyz[:, bond_idx[:,0]] - xyz[:, bond_idx[:,1]] + 1e-9).pow(2).sum(-1).sqrt()
    bond_prd = (xyz_recon[:, bond_idx[:,0]] - xyz_recon[:, bond_idx[:,1]]+ 1e-9).pow(2).sum(-1).sqrt()
    loss_bond = (bond_true - bond_prd).pow(2).mean()

    # total loss
    loss = loss_recon + 0.5 * loss_reg + loss_bond
Is it possible for you to share that part of the code at your convenience?
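
For reference, here is a self-contained version of the snippet above, run on random tensors. It is only a minimal sketch: the function name `cgae_loss_terms`, the dummy data, and the tensor shapes are my assumptions (the shapes are inferred from the einsum string `'bij,ni->bnj'`), and the encoder/decoder call is replaced by placeholder tensors. As far as I can tell, it yields only the reconstruction, assignment-distance, and bond terms, with nothing that depends on forces.

```python
import torch


def cgae_loss_terms(xyz, xyz_recon, cg_xyz, M, bond_idx, eps=1e-9):
    """Return (loss_recon, loss_reg, loss_bond) as in the quoted snippet.

    Assumed shapes (inferred from the einsum string, not confirmed by the repo):
      xyz, xyz_recon: (batch, n_atoms, 3)
      cg_xyz:         (batch, n_cg, 3)
      M:              (n_atoms, n_cg) assignment weights
      bond_idx:       (n_bonds, 2) integer atom indices
    """
    # lift the CG coordinates back to the fine-grained space
    X_lift = torch.einsum('bij,ni->bnj', cg_xyz, M)

    # penalize atoms that sit far from their lifted CG position
    loss_reg = (xyz - X_lift).pow(2).sum(-1).mean()

    # reconstruction loss
    loss_recon = (xyz - xyz_recon).pow(2).mean()

    # bond-length loss between true and reconstructed structures
    bond_true = (xyz[:, bond_idx[:, 0]] - xyz[:, bond_idx[:, 1]] + eps).pow(2).sum(-1).sqrt()
    bond_prd = (xyz_recon[:, bond_idx[:, 0]] - xyz_recon[:, bond_idx[:, 1]] + eps).pow(2).sum(-1).sqrt()
    loss_bond = (bond_true - bond_prd).pow(2).mean()

    return loss_recon, loss_reg, loss_bond


# Dummy data standing in for a real trajectory and a trained autoencoder.
torch.manual_seed(0)
batch, n_atoms, n_cg = 4, 6, 2
xyz = torch.randn(batch, n_atoms, 3)
xyz_recon = xyz + 0.1 * torch.randn_like(xyz)            # placeholder decoder output
cg_xyz = torch.randn(batch, n_cg, 3)                     # placeholder encoder output
M = torch.softmax(torch.randn(n_atoms, n_cg), dim=-1)    # placeholder assignment matrix
bond_idx = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]])

loss_recon, loss_reg, loss_bond = cgae_loss_terms(xyz, xyz_recon, cg_xyz, M, bond_idx)

# total loss, same weighting as in the quoted snippet
loss = loss_recon + 0.5 * loss_reg + loss_bond
```

None of the inputs here is a force array or a potential-energy gradient, which is why I suspect the mean force term lives somewhere else.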