pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

math term missed in dimenet_utils.associated_legendre_polynomials #8160

Closed: PoloWitty closed this issue 1 year ago.

PoloWitty commented 1 year ago

🛠 Proposed Refactor

To compute the spherical harmonics basis used in DimeNet, we need the associated Legendre polynomials. However, I found a missing term in torch_geometric.nn.models.dimenet_utils.associated_legendre_polynomials().

The DimeNet code uses zero_m_only=True by default, so no error occurs when running it with the default settings.

However, as a warning for anyone who reads or uses this code: the implementation (PyG version 2.3.1) omits the factor $\sqrt{1-z^2}$ when computing $P_l^l$, which makes every $P_l^m$ with $m \ge 1$ wrong; only the $P_l^0$ results are correct (original code in torch_geometric/nn/models/dimenet_utils.py).
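The diagonal recurrence, and the effect of dropping its factor, can be written out explicitly. A short derivation, assuming the Condon-Shortley phase convention used on MathWorld, with $P_0^0 = 1$, and writing $\tilde P_l^m$ (notation introduced here) for what the unfixed code returns:

$$
\begin{aligned}
P_l^l(x) &= (1-2l)\,\sqrt{1-x^2}\,P_{l-1}^{l-1}(x) = (-1)^l\,(2l-1)!!\,(1-x^2)^{l/2}, \\
\tilde P_l^l(x) &= (1-2l)\,\tilde P_{l-1}^{l-1}(x) = (-1)^l\,(2l-1)!!, \\
\tilde P_l^m(x) &= \frac{P_l^m(x)}{(1-x^2)^{m/2}}, \qquad 0 \le m \le l.
\end{aligned}
$$

The last line holds because, for fixed $m$, every other entry in column $m$ is a polynomial multiple of $P_m^m$ (via Eqs. 11 and 7), so the missing factor $(1-x^2)^{m/2}$ carries through the whole column; only $m = 0$ is unaffected.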

Take the result of $P_1^1$ as an example:

The original code outputs $-1$, but according to Eq. 14 at https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html, the correct result is $-(1-x^2)^{1/2}$.
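The $P_1^1$ discrepancy, and the same pattern for the higher diagonal terms, can be reproduced in a few lines of SymPy. This is a minimal sketch: diagonal_terms is a hypothetical helper written only for this comparison, and sympy.assoc_legendre (which includes the same $(-1)^m$ Condon-Shortley sign) serves as the reference.

```python
import sympy as sym

z = sym.symbols('z')


def diagonal_terms(k, with_sqrt_factor):
    """Hypothetical helper: return [P_0^0, ..., P_{k-1}^{k-1}] built from
    P_l^l = (1 - 2l) * P_{l-1}^{l-1}, optionally times sqrt(1 - z^2)."""
    terms = [sym.Integer(1)]  # P_0^0 = 1
    factor = sym.sqrt(1 - z**2) if with_sqrt_factor else 1
    for l in range(1, k):
        terms.append(sym.simplify((1 - 2 * l) * terms[-1] * factor))
    return terms


buggy = diagonal_terms(4, with_sqrt_factor=False)  # recurrence as in PyG 2.3.1
fixed = diagonal_terms(4, with_sqrt_factor=True)   # recurrence with the missing factor

for l in range(4):
    # Columns: degree, unfixed result, fixed result, SymPy reference
    print(l, buggy[l], fixed[l], sym.assoc_legendre(l, l, z))
```

The unfixed column collapses to the constants $(-1)^l (2l-1)!!$, while the fixed column agrees with the reference.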

Suggest a potential alternative/fix

Besides adding the missing term, I also added some comments to aid understanding.

```python
import sympy as sym


def associated_legendre_polynomials(k, zero_m_only=True):
    '''
    Helper function to compute the associated Legendre polynomials
    P_l^m(z) for 0 <= m <= l < k, used to build Y_l^m.
    '''
    z = sym.symbols('z')
    P_l_m = [[0] * (j + 1) for j in range(k)]

    P_l_m[0][0] = 1
    if k > 1:  # row P_l_m[1] only exists when k >= 2
        P_l_m[1][0] = z

        for j in range(2, k):
            # m = 0 column, property of eq.7: https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html
            P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
                                        (j - 1) * P_l_m[j - 2][0]) / j)
        if not zero_m_only:
            for i in range(1, k):
                # Diagonal P_i^i: the missing term sqrt(1 - z^2) is added here.
                # sym.sqrt keeps the expression exact (an exponent of 0.5 would be a Float).
                P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1] *
                                           sym.sqrt(1 - z**2))
                if i + 1 < k:
                    # Property of eq.11: https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html
                    P_l_m[i + 1][i] = sym.simplify(
                        (2 * i + 1) * z * P_l_m[i][i])
                for j in range(i + 2, k):
                    # Property of eq.7: https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html
                    P_l_m[j][i] = sym.simplify(
                        ((2 * j - 1) * z * P_l_m[j - 1][i] -
                         (i + j - 1) * P_l_m[j - 2][i]) / (j - i))

    return P_l_m
```
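For readers following the loops, here are the recurrences the code implements, transcribed in its own indices ($j$ is the degree $l$, $i$ is the order $m$), with the diagonal step including the $\sqrt{1-z^2}$ factor. The first line is the $i = 0$ case of the last one.

$$
\begin{aligned}
P_j^0(z) &= \frac{(2j-1)\,z\,P_{j-1}^0(z) - (j-1)\,P_{j-2}^0(z)}{j}, & j &\ge 2, \\
P_i^i(z) &= (1-2i)\,\sqrt{1-z^2}\,P_{i-1}^{i-1}(z), & i &\ge 1, \\
P_{i+1}^i(z) &= (2i+1)\,z\,P_i^i(z), & i &\ge 1, \\
P_j^i(z) &= \frac{(2j-1)\,z\,P_{j-1}^i(z) - (i+j-1)\,P_{j-2}^i(z)}{j-i}, & j &\ge i+2.
\end{aligned}
$$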

After fixing this small bug, all results are correct.
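One way to double-check a claim like this without relying on any of the recurrences is the derivative formula $P_l^m(z) = (-1)^m (1-z^2)^{m/2} \frac{d^m}{dz^m} P_l(z)$. The sketch below builds a table with the same [l][m] layout as P_l_m from that formula; reference_table is a hypothetical helper, not part of PyG, and it assumes the Condon-Shortley sign convention.

```python
import sympy as sym

z = sym.symbols('z')


def reference_table(k):
    """Hypothetical helper: P_l^m(z) for 0 <= m <= l < k, computed from
    P_l^m(z) = (-1)^m (1 - z^2)^(m/2) d^m/dz^m P_l(z)."""
    table = []
    for l in range(k):
        p_l = sym.legendre(l, z)  # ordinary Legendre polynomial P_l(z)
        row = []
        for m in range(l + 1):
            deriv = sym.diff(p_l, z, m) if m > 0 else p_l
            row.append(sym.simplify((-1)**m * sym.sqrt(1 - z**2)**m * deriv))
        table.append(row)
    return table


table = reference_table(4)
print(table[1][1], table[2][1], table[3][3])
```

To check the proposed function, one could compare each P_l_m[l][m] from associated_legendre_polynomials(k, zero_m_only=False) with reference_table(k)[l][m], either symbolically (sym.simplify(a - b) == 0) or at a few numeric points in $(-1, 1)$.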

rusty1s commented 1 year ago

Thanks for letting us know. This code was actually directly taken from github.com/klicperajo/dimenet. Will be fixed in https://github.com/pyg-team/pytorch_geometric/pull/8164.