XanaduAI / GradDFT

GradDFT is a JAX-based library enabling the differentiable design and experimentation of exchange-correlation functionals using machine learning techniques.
Apache License 2.0

Correctness checks #17

Closed: jackbaker1001 closed this issue 1 year ago

jackbaker1001 commented 1 year ago

Grad-DFT achieves something fairly complex. That, alongside the fact that (i) development occurred without any unit testing and (ii) it has not been exposed to many users, means that the probability that things (perhaps even important things) are wrong is quite high.

In this issue, I will post comments/questions about things I either don't understand or believe are incorrect. If we agree that something is incorrect, it will be raised in a separate issue and corrected.

jackbaker1001 commented 1 year ago

The first red flag for me is the nonXC function in molecule.py.

This function is used to calculate the terms of the DFT functional not involved in the XC term. It reads:

```python
@partial(jax.jit, static_argnames=["precision"])
def nonXC(
    rdm1: Array,
    h1e: Array,
    rep_tensor: Array,
    nuclear_repulsion: Scalar,
    precision=Precision.HIGHEST,
) -> Scalar:
    r"""
    A function that computes the non-XC part of a DFT functional.

    Parameters
    ----------
    rdm1 : Array
        The 1-Reduced Density Matrix.
        Equivalent to mf.make_rdm1() in pyscf.
        Expected shape: (n_spin, n_orb, n_orb)
    h1e : Array
        The 1-electron Hamiltonian.
        Equivalent to mf.get_hcore(mf.mol) in pyscf.
        Expected shape: (n_orb, n_orb)
    rep_tensor : Array
        The repulsion tensor.
        Equivalent to mf.mol.intor('int2e') in pyscf.
        Expected shape: (n_orb, n_orb, n_orb, n_orb)
    nuclear_repulsion : Scalar
        Equivalent to mf.mol.energy_nuc() in pyscf.
        The nuclear repulsion energy.
    precision : Precision, optional
        The precision to use for the computation, by default Precision.HIGHEST

    Returns
    -------
    Scalar
        The non-XC energy of the DFT functional.
    """
    rdm1 = symmetrize_rdm1(rdm1)
    h1e_energy = one_body_energy(rdm1, h1e, precision)
    coulomb2e_energy = two_body_energy(rdm1, rep_tensor, precision)

    return nuclear_repulsion + h1e_energy + coulomb2e_energy
```

Some points from me here are:

  1. It is not clear why the nuclear repulsion energy is here. It does not involve electrons and is not a functional of their density. You would typically calculate it only once, before the start of an SCF loop.

  2. This may not be an issue, but I hope h1e contains (i) the kinetic energy of KS electrons, (ii) the attraction with the nuclei and (iii) the interaction of each electron with the density of all other electrons (Hartree term).

  3. There are no explicit two-electron interactions in DFT (unless using hybrid functionals, of course). It is therefore confusing that we have the two-electron integrals from Hartree-Fock theory here. I am fairly sure this is wrong, unless it is simply mislabeled as a two-electron interaction when really it is the Hartree energy, which was not included in h1e.
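Point 3 hinges on a notational subtlety: in a finite basis, the classical Hartree energy is itself written with the two-electron integrals, $E_{Ha} = \frac{1}{2}\sum_{pqrs}(pq|rs)P_{pq}P_{rs}$ with $P$ the spin-summed density matrix, whereas Hartree-Fock exchange contracts the same tensor with a different index pattern. The NumPy sketch below uses random numbers in place of real integrals and density matrices (assumptions, not Grad-DFT data and not Grad-DFT's `two_body_energy`) to show that the two contractions are distinct, so a term that uses `rep_tensor` only in the Coulomb pattern is the Hartree energy.

```python
import numpy as np

rng = np.random.default_rng(0)
n_orb = 4

# Hypothetical stand-in for rep_tensor in chemists' notation (pq|rs): random
# numbers symmetrized to have the 8-fold permutational symmetry of real ERIs.
a = rng.normal(size=(n_orb,) * 4)
eri = a + a.transpose(1, 0, 2, 3) + a.transpose(0, 1, 3, 2) + a.transpose(1, 0, 3, 2)
eri = eri + eri.transpose(2, 3, 0, 1)

# Hypothetical spin-resolved density matrices, shape (n_spin, n_orb, n_orb),
# built as C C^T so each spin block is symmetric.
c = rng.normal(size=(2, n_orb, 2))
rdm1 = np.einsum("sik,sjk->sij", c, c)
p_total = rdm1.sum(axis=0)  # spin-summed density matrix


def hartree_energy(eri, p_total):
    """Classical Coulomb (Hartree) energy: 1/2 sum_pqrs (pq|rs) P_pq P_rs."""
    j = np.einsum("pqrs,rs->pq", eri, p_total)  # Coulomb matrix J[P]
    return 0.5 * np.einsum("pq,pq->", p_total, j)


def exchange_energy(eri, rdm1):
    """Hartree-Fock exchange: -1/2 sum_spin sum_pq P_pq K_pq, absent in pure DFT."""
    k = np.einsum("prsq,xrs->xpq", eri, rdm1)  # K_pq = sum_rs (pr|sq) P_rs, per spin
    return -0.5 * np.einsum("xpq,xpq->", rdm1, k)


print("Hartree:", hartree_energy(eri, p_total))
print("Exchange:", exchange_energy(eri, rdm1))
```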

jackbaker1001 commented 1 year ago

One major concern of mine is the prediction of energies after updating the parameters of a neural functional. Take the example in ~examples/example_neural_functional_03.py:

```python
from optax import adam, apply_updates
from tqdm import tqdm

# params, neuralfunctional, HH_molecule, ground_truth_energy, molecule_predictor
# and default_loss come from earlier in the example script.

learning_rate = 1e-5
momentum = 0.9
tx = adam(learning_rate=learning_rate, b1=momentum)
opt_state = tx.init(params)

# and implement the optimization loop
n_epochs = 20
molecule_predict = molecule_predictor(neuralfunctional)
for iteration in tqdm(range(n_epochs), desc="Training epoch"):
    (cost_value, predicted_energy), grads = default_loss(
        params, molecule_predict, HH_molecule, ground_truth_energy
    )
    print("Iteration", iteration, "Predicted energy:", predicted_energy, "Cost value:", cost_value)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = apply_updates(params, updates)

neuralfunctional.save_checkpoints(params, tx, step=n_epochs)
```
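The `(cost_value, predicted_energy), grads = default_loss(...)` unpacking in this loop has the same shape as the output of `jax.value_and_grad(..., has_aux=True)`. A minimal sketch of that pattern, assuming a hypothetical linear toy predictor in place of `molecule_predict` (this is not Grad-DFT's `default_loss`):

```python
import jax
import jax.numpy as jnp


def toy_predict_energy(params, features):
    # Hypothetical stand-in for molecule_predict: a fixed non-XC energy plus a
    # parameter-dependent "XC" term. Not Grad-DFT's API.
    e_non_xc = -1.0
    e_xc = jnp.dot(params["w"], features)
    return e_non_xc + e_xc


def mse_loss(params, features, ground_truth_energy):
    predicted_energy = toy_predict_energy(params, features)
    cost = (predicted_energy - ground_truth_energy) ** 2
    # Return the prediction as auxiliary data so it can be logged.
    return cost, predicted_energy


# has_aux=True returns ((cost, aux), grads).
loss_and_grad = jax.value_and_grad(mse_loss, has_aux=True)

params = {"w": jnp.array([0.1, -0.2])}
features = jnp.array([1.0, 2.0])
(cost_value, predicted_energy), grads = loss_and_grad(params, features, -1.5)
print(cost_value, predicted_energy, grads["w"])
```

The returned `grads` pytree has the same structure as `params`, which is what `tx.update(grads, opt_state, params)` expects.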

We are looking to minimize a mean square loss between the predicted energy and ground truth energy. Taking a look at how we calculate a predicted energy, we basically have:

```python
densities = LSDA.compute_densities(molecule=HF_molecule)
# Then we compute the coefficient inputs
cinputs = LSDA.compute_coefficient_inputs(molecule=HF_molecule)
# Finally we compute the exchange-correlation energy
predicted_energy = LSDA.apply_and_integrate(params, HF_molecule.grid, cinputs, densities)
```

Then we add on the other non-XC components, which amounts to holding the non-XC energy functionals of the density static while only the XC part is updated.
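A minimal sketch of the non-self-consistent prediction described above, assuming a spin-unpolarized LDA exchange energy density, a single coefficient per grid point, and hypothetical function names (this is not Grad-DFT's `apply_and_integrate`): the XC energy is a weighted sum over grid points, and the non-XC energy is added as a fixed number.

```python
import jax.numpy as jnp

# Spin-unpolarized LDA (Slater/Dirac) exchange energy per unit volume:
# e_x(rho) = -(3/4) (3/pi)^(1/3) rho^(4/3)
C_X = 0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0)


def lda_exchange_density(rho):
    return -C_X * rho ** (4.0 / 3.0)


def integrate_xc(coefficients, energy_densities, weights):
    # E_xc = sum_g w_g sum_i c_i(r_g) e_i(r_g) over grid points g.
    return jnp.einsum("g,gi,gi->", weights, coefficients, energy_densities)


def predicted_total_energy(coefficients, rho, weights, e_non_xc):
    # e_non_xc is treated as a fixed number: only the XC part changes with params.
    energy_densities = lda_exchange_density(rho)[:, None]  # shape (n_grid, 1)
    return e_non_xc + integrate_xc(coefficients, energy_densities, weights)


n_grid = 10
rho = jnp.full(n_grid, 0.5)           # uniform toy density
weights = jnp.full(n_grid, 0.1)       # quadrature weights summing to a unit volume
coefficients = jnp.ones((n_grid, 1))  # c = 1 recovers plain LDA exchange
print(predicted_total_energy(coefficients, rho, weights, e_non_xc=-2.0))
```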

The problem is that, after this update, we lose self-consistency. It then seems to me that each predicted_energy call must be a self-consistent calculation of the energy given the new parameters of the functional.

The above can be rationalized very simply. If I have completed a calculation with, say, the LDA and I use the output density as a starting point for, say, PBE, would you expect the calculation to be self-consistent after only one iteration? You would not. The same thing is happening here: the neural functional with one set of parameters is a different functional from the neural functional with another set of parameters. It could be that we only need to run the self-consistent field loop every N steps, say, but this is an approximation we need to set and test.
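The argument can be seen in a scalar caricature. Assume a hypothetical SCF map $\rho \leftarrow \gamma\cos\rho$, where $\gamma$ plays the role of the functional's parameters: converge it for one $\gamma$, change $\gamma$, and take a single step from the old density. The result is not a fixed point of the new map, while re-running the loop to convergence is.

```python
import math


def scf_map(rho, gamma):
    # Hypothetical one-step "SCF" update for a scalar density, with gamma
    # standing in for the functional's parameters.
    return gamma * math.cos(rho)


def run_scf(gamma, rho0=0.0, tol=1e-12, max_iter=1000):
    rho = rho0
    for _ in range(max_iter):
        rho_new = scf_map(rho, gamma)
        if abs(rho_new - rho) < tol:
            return rho_new
        rho = rho_new
    raise RuntimeError("SCF did not converge")


rho_old = run_scf(0.5)                # converged with the "old" parameters
rho_one_step = scf_map(rho_old, 0.8)  # one step with the "new" parameters
rho_new = run_scf(0.8, rho0=rho_old)  # converged with the "new" parameters

# Self-consistency residual |F(rho) - rho| for the new parameters.
residual_one_step = abs(scf_map(rho_one_step, 0.8) - rho_one_step)
residual_converged = abs(scf_map(rho_new, 0.8) - rho_new)
print(residual_one_step, residual_converged)
```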

I can see that we have a self-consistent field loop implemented in evaluate.py (I will be looking through this shortly), but I cannot see any example where it has been used in training.

jackbaker1001 commented 1 year ago

Furthermore, following on from the comments above about needing an SCF loop, we need to rethink how gradients are calculated. On the math side, what we have is essentially:

$$ E_{KS}[\rho](\gamma) = T[\rho] + E_{ext}[\rho] + E_{Ha}[\rho] + E_{XC}[\rho](\gamma) $$

Where we have:

$T[\rho]$: the total kinetic energy of the Kohn-Sham electrons.

$E_{ext} [\rho]$: The energy contribution of electrons interacting with an external potential (which is most of the time the interaction with point coulomb charges representing the nuclei).

$E_{Ha}[\rho]$: The electrostatic Hartree energy. Each electron interacts with the charge density as generated by all electrons including itself.

$E_{XC}[\rho](\gamma)$: The exchange and correlation energy, parameterized by a vector of parameters $\gamma$ (the neural network parameters, which in turn generate the coefficients; don't worry about this here though).

If we are working with just one step in an SCF cycle (i.e. non-self-consistently), then the gradient $\partial E_{KS}[\rho](\gamma)/\partial \gamma$ depends only on the $E_{XC}[\rho](\gamma)$ term (the other terms are clearly independent of $\gamma$). Because of this, we need not consider gradients flowing through any other terms in the Kohn-Sham Hamiltonian, and you will see terms in the code like:

```python
energy += stop_gradient(molecule.nonXC())
```

which stops the code from differentiating through the non-XC parts of the KS energy.
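A toy check of what `stop_gradient` does and does not change, assuming a hypothetical scalar energy $E(\gamma, \rho) = \tfrac{1}{2}\rho^2 - \gamma\rho^{4/3}$ whose first term stands in for the non-XC part: at fixed $\rho$ the derivative with respect to $\gamma$ is unaffected, but the derivative with respect to $\rho$ loses the non-XC contribution, which is exactly the piece that matters once $\rho$ depends on $\gamma$.

```python
import jax
from jax.lax import stop_gradient


def energy(gamma, rho):
    # Hypothetical toy energy: a gamma-independent "non-XC" part plus a
    # gamma-dependent "XC" part, both functions of a scalar density rho.
    return 0.5 * rho ** 2 - gamma * rho ** (4.0 / 3.0)


def energy_stopped(gamma, rho):
    # Same forward value, but the non-XC part is wrapped in stop_gradient.
    return stop_gradient(0.5 * rho ** 2) - gamma * rho ** (4.0 / 3.0)


gamma, rho = 1.0, 2.0
# d/dgamma at fixed rho: identical, because the non-XC part does not depend on gamma.
print(jax.grad(energy, argnums=0)(gamma, rho), jax.grad(energy_stopped, argnums=0)(gamma, rho))
# d/drho: the stopped version drops the non-XC contribution (which equals rho).
print(jax.grad(energy, argnums=1)(gamma, rho), jax.grad(energy_stopped, argnums=1)(gamma, rho))
```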

However, once self-consistency is imposed, the density at SCF step $i$, $\rho_i$ for $i>1$, must depend on $\gamma$: the new density was calculated from the Kohn-Sham eigenstates, which were themselves found by solving the Kohn-Sham equations with functional parameters $\gamma$. In short, for $i>1$, $\rho_i = \rho_i(\mathbf{r}, \gamma)$, which, upon substituting into $E_{KS}[\rho](\gamma)$ above, means that the entire Kohn-Sham energy now depends on the parameters of the functional. We therefore have no choice but to differentiate through the entire Kohn-Sham Hamiltonian, which, to my knowledge, we are presently not doing.
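The chain rule behind this argument is $\frac{dE_{KS}}{d\gamma} = \frac{\partial E_{KS}}{\partial \gamma}\Big|_{\rho} + \int \frac{\delta E_{KS}}{\delta \rho(\mathbf{r})}\,\frac{\partial \rho(\mathbf{r})}{\partial \gamma}\,d\mathbf{r}$, and JAX can compute the $\partial\rho/\partial\gamma$ factor by differentiating through an unrolled loop. A minimal sketch, assuming a hypothetical scalar toy map $\rho \leftarrow \gamma\cos\rho$ in place of a real Kohn-Sham step: it differentiates through a fixed number of `jax.lax.fori_loop` iterations and checks the result against implicit differentiation of the fixed point.

```python
import jax
import jax.numpy as jnp


def scf_density(gamma, n_iter=100):
    # Hypothetical toy "SCF": iterate rho <- gamma * cos(rho) a fixed number of
    # times. With a static trip count, fori_loop is reverse-mode differentiable.
    def scf_step(_, rho):
        return gamma * jnp.cos(rho)

    return jax.lax.fori_loop(0, n_iter, scf_step, jnp.zeros_like(gamma))


gamma = 0.8
rho_star = scf_density(gamma)
drho_dgamma = jax.grad(scf_density)(gamma)  # gradient through all iterations

# Check against implicit differentiation of rho* = gamma * cos(rho*):
# d rho*/d gamma = cos(rho*) / (1 + gamma * sin(rho*))
expected = jnp.cos(rho_star) / (1.0 + gamma * jnp.sin(rho_star))
print(drho_dgamma, expected)
```

Unrolling keeps every iteration for the backward pass, so memory grows with the number of SCF steps; implicit differentiation of the converged fixed point is a common alternative and is used here only as a check.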

PabloAMC commented 1 year ago

Addressing some of the issues:

> It is not clear why the nuclear repulsion energy is here. It does not involve electrons and is not a functional of their density. You would typically calculate this only once before the start of an SCF loop.

Here I am replicating what PySCF does (https://github.com/pyscf/pyscf/blob/4c0acff4a74a9b25c16c4c0714b60d29876d51f5/pyscf/scf/hf.py#L300). In our case, the nuclear repulsion is likewise computed once, stored in Molecule, and not modified during the loop. That is, it is a property, not a method.
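For reference, the nuclear repulsion energy is the classical point-charge sum below, which is what PySCF's `mol.energy_nuc()` evaluates for an all-electron calculation. A minimal sketch in plain Python with a hypothetical helper name, showing that the quantity depends only on the geometry:

```python
import itertools
import math


def nuclear_repulsion(charges, coords):
    """Point-charge nuclear repulsion, sum_{A<B} Z_A Z_B / |R_A - R_B| (atomic units).

    It depends only on nuclear charges and positions, never on the electron
    density, so it can be computed once per geometry and stored.
    """
    energy = 0.0
    for (z_a, r_a), (z_b, r_b) in itertools.combinations(zip(charges, coords), 2):
        energy += z_a * z_b / math.dist(r_a, r_b)
    return energy


# H2 with a bond length of 1.4 bohr: E_nuc = 1 / 1.4
print(nuclear_repulsion([1, 1], [(0.0, 0.0, 0.0), (0.0, 0.0, 1.4)]))
```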

PabloAMC commented 1 year ago

With respect to

> 2. This may not be an issue, but I hope h1e contains (i) the kinetic energy of KS electrons, (ii) the attraction with the nuclei and (iii) the interaction of each electron with the density of all other electrons (Hartree term).
> 3. There are no explicit two-electron interactions in DFT (unless using hybrid functionals, of course). It is therefore confusing that we have the two-electron integrals from Hartree-Fock theory here. I am fairly sure this is wrong, unless it is simply mislabeled as a two-electron interaction when really it is the Hartree energy, which was not included in h1e.

h1e contains the nucleus-electron attraction and the kinetic energy. This is similar to https://github.com/pyscf/pyscf/blob/4c0acff4a74a9b25c16c4c0714b60d29876d51f5/pyscf/scf/hf.py#L305 and specifically lines https://github.com/pyscf/pyscf/blob/4c0acff4a74a9b25c16c4c0714b60d29876d51f5/pyscf/scf/hf.py#L316 and https://github.com/pyscf/pyscf/blob/4c0acff4a74a9b25c16c4c0714b60d29876d51f5/pyscf/scf/hf.py#L324. We do a similar thing in https://github.com/XanaduAI/DiffDFT/blob/bbd692db37890583e6ec3820bfd2d975c538857b/grad_dft/interface/pyscf.py#L562.

The Coulomb interaction, even if it is not formally an electron-electron interaction, is computed separately in https://github.com/XanaduAI/DiffDFT/blob/bbd692db37890583e6ec3820bfd2d975c538857b/grad_dft/interface/pyscf.py#L595-L597, and more generally in https://github.com/XanaduAI/DiffDFT/blob/bbd692db37890583e6ec3820bfd2d975c538857b/grad_dft/interface/pyscf.py#L599-L600. I think generating the repulsion tensor would be unnecessary if we found a way to implement the Coulomb kernel correctly, but so far we have not had much success. Thus, we are going for the slightly more expensive method of computing it via contraction with the repulsion tensor, as done in the Hartree-Fock tutorial in PennyLane: https://github.com/PennyLaneAI/pennylane/blob/e93fc1b58ae325cf38a9e3296ae9e340e6ac0fa0/pennylane/qchem/hartree_fock.py#L149
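To put rough numbers on the cost of the repulsion-tensor route: a dense `rep_tensor` holds $n_{orb}^4$ values, and even exploiting the 8-fold permutational symmetry only reduces that by about a factor of 8. The counts below are generic arithmetic assuming float64 storage, with hypothetical helper names; they are not measurements of Grad-DFT.

```python
def dense_rep_tensor_bytes(n_orb, bytes_per_element=8):
    # Full (n_orb, n_orb, n_orb, n_orb) array, float64 by default.
    return n_orb ** 4 * bytes_per_element


def unique_rep_tensor_elements(n_orb):
    # Number of symmetry-unique (pq|rs) values under 8-fold permutational symmetry.
    n_pairs = n_orb * (n_orb + 1) // 2
    return n_pairs * (n_pairs + 1) // 2


for n_orb in (50, 100, 200):
    print(n_orb, dense_rep_tensor_bytes(n_orb) / 1e9, "GB dense,",
          unique_rep_tensor_elements(n_orb), "unique elements")
```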

PabloAMC commented 1 year ago

In comments https://github.com/XanaduAI/DiffDFT/issues/17#issuecomment-1677944006 and https://github.com/XanaduAI/DiffDFT/issues/17#issuecomment-1678032962, you are completely right that the ideal approach would be to differentiate through the SCF loop; that way we would actually keep self-consistency. You are also right that the stop_gradients in the non-XC part should probably not be there. I will test this and remove them. However, the main issue is that JAX returns NaNs if we try to differentiate through the SCF loop, and I have been unable to make it work. This was in fact the main motivation for the jitted SCF loop.
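The thread does not say where the NaNs come from. As one possibility only: a common, documented source of NaN gradients in JAX (see the JAX FAQ on `jnp.where`) is a branch that is masked out in the forward pass but still has an infinite derivative, such as `sqrt` or fractional powers evaluated where a density vanishes. Eigenvector derivatives at degenerate eigenvalues are another frequently cited source of non-finite gradients. The sketch below is a generic illustration with hypothetical functions `naive` and `safe`, not a diagnosis of Grad-DFT's SCF loop.

```python
import jax
import jax.numpy as jnp


def naive(x):
    # At x = 0 the sqrt branch is not selected, but its derivative 0.5 / sqrt(0)
    # is infinite; the zero cotangent times inf gives nan in the backward pass.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def safe(x):
    # "Double where": give the masked branch a harmless input so its
    # derivative stays finite where it is not used.
    x_safe = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(x_safe), 0.0)


print(jax.grad(naive)(0.0))
print(jax.grad(safe)(0.0))
```

Both functions return the same values everywhere, but at `x = 0` the first gradient is NaN while the second is zero.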

jackbaker1001 commented 1 year ago

Regarding my first comment (the nonXC function in molecule.py, points 1 to 3):

This was resolved. The non-XC energy was tested and found to be correct; the confusion came only from misleading names inherited from PySCF.

PabloAMC commented 1 year ago

We may want to close the issue if the checks pass.

jackbaker1001 commented 1 year ago

After checking the energy obtained by our self-consistent methods vs PySCF, I will close this.

jackbaker1001 commented 1 year ago

OK. The integration tests compare the SCF results against PySCF, so I think we can close.