fishjojo / pyscfad

PySCF with auto-differentiation

OO-CCD #20

Open zww-4855 opened 10 months ago

zww-4855 commented 10 months ago

Hi all,

I am trying to bootstrap the OO-MP2 example to perform OO-CCD. I have the following code, with only minimal changes to the original OO-MP2 code. It runs successfully for the first iteration where rotation x=0, but afterwards I get an error that is very verbose and difficult for me to interpret. Advice would be much appreciated.

from scipy.optimize import minimize
from jax import value_and_grad
from pyscf import numpy as np
from pyscfad import util
from pyscfad.tools import rotate_mo1
from pyscfad import gto, scf, mp
from pyscfad import cc

from pyscf.cc import ccsd as pyscf_cc

@util.pytree_node(['_scf'], num_args=1)
class OOCCD(cc.rccsd.RCCSD):
    def __init__(self, mf, x=None, **kwargs):
        #pyscf_cc.CCSD.__init__(self, mf)
        cc.rccsd.RCCSD.__init__(self,mf)
        self.x = x
        self.__dict__.update(kwargs)
        if self.x is None:
            nao = self.mol.nao
            assert nao == self.nmo
            size = nao*(nao+1)//2
            self.x = np.zeros([size,])
        self.mo_coeff = rotate_mo1(self.mo_coeff, self.x)
        self._scf.converged = False

mol = gto.Mole()
mol.atom ='H 0 0 0; F 0 0 1.1'
mol.basis = 'ccpvdz'
mol.build()
mf = scf.RHF(mol)
mf.kernel()

def func(x0, mf):
    def energy(x0, mf):
        mymp = OOCCD(mf, x=np.asarray(x0))
        old_update_amps = mymp.update_amps
        def update_amps(t1, t2, eris):
            t1, t2 = old_update_amps(t1, t2, eris)
            return t1*0, t2
        mymp.update_amps = update_amps
        mymp.kernel()
        return mymp.e_tot

    def grad(x0, mf):
        f, g = value_and_grad(energy)(x0, mf)
        return f, g

    f, g = grad(x0, mf)
    return (np.array(f), np.array(g))

nao = mol.nao
size = nao*(nao+1)//2
x0 = np.zeros([size,])
options = {"gtol":1e-5}
res = minimize(func, x0, args=(mf,), jac=True, method="BFGS", options = options)
e = func(res.x, mf)[0]
print(e)

Thanks, Zack

zww-4855 commented 10 months ago

I will add that I am suspicious of the validity of the following line:

@util.pytree_node(['_scf'], num_args=1)

fishjojo commented 10 months ago

One workaround is to comment out the following line: https://github.com/fishjojo/pyscfad/blob/3a90edd2abacc2e50d5d9def9eaf74195f0e3284/pyscfad/cc/rccsd.py#L10. With that change, the script below should work:

from scipy.optimize import minimize
import numpy
import jax
from jax import numpy as jnp
from pyscfad import util
from pyscfad.tools import rotate_mo1
from pyscfad import gto, scf
from pyscfad import cc

@util.pytree_node(['_scf', 'x'], num_args=1)
class OOCCD(cc.rccsd.RCCSD):
    def __init__(self, mf, x, **kwargs):
        cc.rccsd.RCCSD.__init__(self, mf)
        self.x = x
        self.__dict__.update(kwargs)
        self.mo_coeff = rotate_mo1(self._scf.mo_coeff, self.x)
        self._scf.converged = False

    def update_amps(self, t1, t2, eris):
        _, t2 = cc.rccsd.update_amps(self, t1, t2, eris)
        return jnp.zeros_like(t1), t2

mol = gto.Mole()
mol.atom ='H 0 0 0; F 0 0 1.1'
mol.basis = 'ccpvdz'
mol.build()
mf = scf.RHF(mol)
mf.kernel()

nao = mol.nao
size = nao*(nao-1)//2
x0 = numpy.zeros([size,])

def energy(x0, mf):
    mymp = OOCCD(mf, x0)
    mymp.kernel()
    return mymp.e_tot

def jac(x0, mf):
    g = jax.grad(energy)(x0, mf)
    return g

options = {"gtol":1e-3}
res = minimize(energy, x0, args=(mf,), jac=jac, method="BFGS", options = options)
e = energy(res.x, mf)
print(e)

Note, however, that this code differentiates by tracing through every CC iteration, which is inefficient. Implicitly differentiating the converged CC equations would be better, but it requires some work to implement for orbital optimization. There is still a lot to do to improve the extensibility of the program, and any suggestions are welcome.
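To illustrate the difference between the two strategies, here is a minimal sketch on a toy scalar fixed-point problem. It is not the CC amplitude equations and does not use pyscfad; the map `g`, its hand-written partial derivatives, and the function names are all hypothetical stand-ins. The first routine carries derivative information through every iteration, which is similar in spirit to tracing through the loop (reverse-mode AD would additionally store each iterate). The second iterates with no derivative bookkeeping and then differentiates the converged condition t = g(t, a) once.

```python
import math

def g(t, a):
    """Toy update t <- g(t, a); a contraction for |a| < 1 (stand-in for an amplitude update)."""
    return math.cos(a * t)

def dg_dt(t, a):
    """Partial derivative of g with respect to t."""
    return -a * math.sin(a * t)

def dg_da(t, a):
    """Partial derivative of g with respect to the parameter a."""
    return -t * math.sin(a * t)

def solve_unrolled(a, n_iter=200):
    """Iterate to the fixed point while propagating dt/da through every step."""
    t, dt = 0.0, 0.0
    for _ in range(n_iter):
        # Chain rule through one iteration; the right-hand side uses the old t and dt.
        t, dt = g(t, a), dg_dt(t, a) * dt + dg_da(t, a)
    return t, dt

def solve_implicit(a, n_iter=200):
    """Iterate without derivative bookkeeping, then apply the implicit
    function theorem at convergence: dt/da = g_a / (1 - g_t)."""
    t = 0.0
    for _ in range(n_iter):
        t = g(t, a)
    return t, dg_da(t, a) / (1.0 - dg_dt(t, a))

a = 0.5
t_unrolled, dt_unrolled = solve_unrolled(a)
t_implicit, dt_implicit = solve_implicit(a)
print(abs(dt_unrolled - dt_implicit) < 1e-10)
```

Both routines give the same derivative at convergence, but the implicit version needs only the converged solution, so its cost does not grow with the number of iterations.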

zww-4855 commented 9 months ago

Thanks @fishjojo for your prompt response. I neglected to mention in the original ticket that I am running this inside your Docker container. After making the changes you suggest above, I get a shape-mismatch error:

Traceback (most recent call last):
  File "oo_ccd.py", line 48, in <module>
    res = minimize(energy, x0, args=(mf,), jac=jac, method="BFGS", options = options)
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_minimize.py", line 694, in minimize
    res = _minimize_bfgs(fun, x0, args, jac, callback, **options)
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_optimize.py", line 1283, in _minimize_bfgs
    sf = _prepare_scalar_function(fun, x0, jac, args=args, epsilon=eps,
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_optimize.py", line 263, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 158, in __init__
    self._update_fun()
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 251, in _update_fun
    self._update_fun_impl()
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 155, in update_fun
    self.f = fun_wrapped(self.x)
  File "/usr/local/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 137, in fun_wrapped
    fx = fun(np.copy(x), *args)
  File "oo_ccd.py", line 39, in energy
    mymp = OOCCD(mf, x0)
  File "oo_ccd.py", line 18, in __init__
    self.mo_coeff = rotate_mo1(self._scf.mo_coeff, self.x)
  File "/usr/local/lib/python3.8/site-packages/pyscfad/tools/util.py", line 20, in rotate_mo1
    mo_coeff1 = rotate_mo(mo_coeff, u)
  File "/usr/local/lib/python3.8/site-packages/pyscfad/tools/util.py", line 15, in rotate_mo
    mo = np.dot(mo_coeff, u)
  File "<__array_function__ internals>", line 180, in dot
ValueError: shapes (19,19) and (18,18) not aligned: 19 (dim 1) != 18 (dim 0)

I tried tracing the error into /tools/util.py, where I see that the shape of dr inside update_rotate_matrix is (18,18). So I followed the unpack_triu function found in /lib/numpy_helper.py. Suffice it to say, I think this problem originates from these lines in your script:

size = nao*(nao-1)//2
x0 = numpy.zeros([size,])

I think you intend x0 to be the rotation matrix, which ultimately has to have the same dimensions as mf.mo_coeff; do I understand this correctly? For your program to run, based on the structure of /lib/numpy_helper.py in this particular example, the size of x0 would have to be 180.5, which conflicts with the int() call on line 28. I don't understand enough about this software to compensate for this intelligently. Can you advise, @fishjojo?

I appreciate your help thus far.

Zack
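One possible reading of the numbers in this comment, offered as an assumption rather than a statement about the installed code: if the unpacking helper recovers the matrix dimension as int(sqrt(2*size)), which is what the 180.5 figure suggests (sqrt(2 * 180.5) = 19), then it is treating the packed vector as a triangle that includes the diagonal. The sketch below only does that arithmetic; `inferred_dim` is a hypothetical function, not part of pyscfad.

```python
import math

def inferred_dim(size):
    # Assumed convention: the packed vector holds a triangle INCLUDING the
    # diagonal (size = n*(n+1)/2), and n is recovered as int(sqrt(2*size)).
    return int(math.sqrt(2 * size))

nao = 19  # the (19,19) mo_coeff shape reported in the traceback

print(inferred_dim(nao * (nao - 1) // 2))  # size 171 -> 18
print(inferred_dim(nao * (nao + 1) // 2))  # size 190 -> 19
```

Under that assumption, a vector of length nao*(nao-1)//2 = 171 unpacks to an 18 x 18 matrix, matching the reported error, while a length of nao*(nao+1)//2 = 190 would unpack to 19 x 19.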

fishjojo commented 9 months ago

The orbital rotation matrix u is equal to exp(x), where x is anti-Hermitian, so only the lower triangular part without the diagonal elements is independent. Thus x0 has the size of N(N-1)//2, where N is the number of all orbitals. Of course one could also only rotate between the occupied and virtual orbitals to further reduce the problem size.
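The parameter count can be checked with a small NumPy/SciPy sketch. This is not pyscfad's `rotate_mo1`; `unpack_antisym` is a hypothetical helper, and the placement of the parameters in the strict lower triangle is an assumed convention. It fills an antisymmetric matrix from N(N-1)/2 real parameters and confirms that its exponential is orthogonal and has the same shape as the orbital coefficient matrix.

```python
import numpy as np
from scipy.linalg import expm

def unpack_antisym(x, n):
    """Hypothetical helper: put x in the strict lower triangle of an
    n x n matrix and antisymmetrize, so that K.T == -K with zero diagonal."""
    assert x.size == n * (n - 1) // 2
    k = np.zeros((n, n))
    k[np.tril_indices(n, k=-1)] = x
    return k - k.T

n = 19  # number of orbitals in the thread's HF/cc-pVDZ example
rng = np.random.default_rng(0)
x = 0.1 * rng.standard_normal(n * (n - 1) // 2)  # 171 independent parameters

kappa = unpack_antisym(x, n)
u = expm(kappa)  # exponential of a real antisymmetric matrix is orthogonal

print(u.shape)
print(np.allclose(u.T @ u, np.eye(n)))
```

With x set to zero the rotation reduces to the identity, which corresponds to starting the optimization from the unrotated Hartree-Fock orbitals.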

zww-4855 commented 9 months ago

When you get a chance, can you check whether the code you posted above runs successfully in the Docker container? I understand what you are saying, but what I am seeing on my end is that mf.mo_coeff and exp(x) do not have the same dimensions.