PattanaikL / GeoMol

about the true angle computation #16

Open QiaolinLu opened 4 months ago

Hi, I have a question about why there is a [6, 6] matrix for true angle computation. For each central atom X, there are a maximum of 6 permutations, such as T1-X-T2, T1-X-T3, T2-X-T1, T2-X-T3, T3-X-T1, and T3-X-T2. I understand the significance of the first '6', but I'm unsure why there is a second '6'. Could you please explain?

`def ground_truth_local_stats(self, pos):
    """
    Compute true one-hop, two-hop, and angle local stats. Note that the second dimension of the local coordinates
    is 6 to account for possible symmetric hydrogens. The max number of symmetric leaf hydrogens is 3, which leads
    to a max of 6 permutations (our model doesn't work for methane). This dimension captures these symmetric
    hydrogen permutations.

    :param pos: coordinates (n_atoms, n_true_confs, 3)
    :return: tuple of true stats (one-hop, two-hop, and angles)
        true_one_hop (n_neighborhoods, 6, 4, n_true_confs)
        true_two_hop (n_neighborhoods, 6, 4, 4, n_true_confs)
        true_angles (n_neighborhoods, 6, 6, n_true_confs)
    """

    n_neighborhoods = len(self.neighbors)
    self.true_local_coords = torch.zeros(n_neighborhoods, 6, 4, self.n_true_confs, 3).to(self.device)

    for i, (a, n) in enumerate(self.neighbors.items()):

        # permutations for symmetric hydrogens
        n_perms = n.unsqueeze(0).repeat(6, 1)
        perms = torch.tensor(list(permutations(n[self.leaf_hydrogens[a]]))).to(self.device)
        if perms.size(1) != 0:
            n_perms[0:len(perms), self.leaf_hydrogens[a]] = perms

        # keep it local
        self.true_local_coords[i, :, 0:len(n)] = pos[n_perms] - pos[a]

    # calculate true local stats
    true_one_hop, true_two_hop, true_angles = batch_local_stats_from_coords(self.true_local_coords, self.neighbor_mask)

    return true_one_hop, true_two_hop, true_angles`
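
One possible reading of the two sizes, sketched below using only the shapes given in the docstring: the first 6 counts the orderings of up to three symmetric leaf hydrogens (3! = 6), while the second 6 may count the unordered pairs among the four neighbor slots (C(4, 2) = 6), giving one T_i-X-T_j angle per pair. The second interpretation is an assumption; the pasted code does not show how `batch_local_stats_from_coords` builds the angle axis. The constant names here are hypothetical and exist only for illustration.

```python
from itertools import combinations, permutations

# Both constants come from the docstring shapes; the names are illustrative only.
MAX_NEIGHBORS = 4     # size of the neighbor axis in true_local_coords
MAX_SYMMETRIC_H = 3   # max number of symmetric leaf hydrogens

# First '6': orderings of up to three interchangeable leaf hydrogens.
hydrogen_orderings = list(permutations(range(MAX_SYMMETRIC_H)))

# Second '6' (assumed): unordered pairs of neighbor slots, one angle per pair.
neighbor_pairs = list(combinations(range(MAX_NEIGHBORS), 2))

print(len(hydrogen_orderings))  # 6
print(len(neighbor_pairs))      # 6
print(neighbor_pairs)           # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```

Under this reading, the six examples listed in the question (T1-X-T2, T2-X-T1, and so on) are ordered pairs of three neighbors, which is a different count from either axis. Since an angle is symmetric in its two end atoms, only unordered pairs would need to be stored.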