fishjojo / pyscfad

PySCF with auto-differentiation

Ensure real solution of logm for localized orbitals #44

Closed: jonas-greiner closed this issue 1 week ago

jonas-greiner commented 3 weeks ago

Hi Xing,

As you know, scipy.linalg.logm sometimes returns a complex matrix when applied to the orbital rotation matrix (see https://github.com/fishjojo/pyscfad/blob/08b57daa3279f915591a89810b04535682aaba32/pyscfad/lo/boys.py#L137). I have written some code based on a real Schur decomposition that ensures the calculated logarithm is real for every normal matrix that has a real logarithm:

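import numpy as np
import scipy.linalg
from pyscf.lib import logger

# u is the real orbital rotation matrix, loc the localization object (used for logging)
t, q = scipy.linalg.schur(u)  # real Schur form: u = q @ t @ q.T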

norb = u.shape[0]
idx = 0
normalmatrix = True
while idx < norb:
    # final block reached for an odd number of orbitals
    if (idx == norb - 1):
        # single positive block
        if t[idx,idx] > 0.:
            t[idx,idx] = np.log(t[idx,idx])
        # single unpaired negative block: no real logarithm found
        else:
            normalmatrix = False
            break
    else:
        diag = np.isclose(t[idx,idx+1], 0.0) and np.isclose(t[idx+1,idx], 0.0)
        # single positive block
        if t[idx,idx] > 0. and diag:
            t[idx,idx] = np.log(t[idx,idx])
        # pair of two negative blocks
        elif t[idx,idx] < 0. and diag and np.isclose(t[idx,idx],t[idx + 1,idx + 1]):
            log_lambda = np.log(-t[idx,idx])
            t[idx:idx+2,idx:idx+2] = np.array(
                [[log_lambda, np.pi], [-np.pi, log_lambda]]
            )
            idx += 1
        # antisymmetric 2x2 block
        elif np.isclose(t[idx,idx], t[idx + 1,idx + 1]) and np.isclose(t[idx + 1,idx],-t[idx,idx + 1]):
            log_comp = np.log(complex(t[idx,idx], t[idx,idx + 1]))
            t[idx:idx+2,idx:idx+2] = np.array(
                [
                    [np.real(log_comp), np.imag(log_comp)], 
                    [-np.imag(log_comp), np.real(log_comp)],
                ],
            )
            idx += 1
        # unpaired negative block or non-normal 2x2 block: no real logarithm found
        else:
            normalmatrix = False 
            break
    idx += 1

if normalmatrix:
    # transform the block-wise logarithm back with the Schur vectors
    kappa = q @ t @ q.T
else:
    logger.warn(
        loc,
        'No real matrix logarithm exists for the orbital rotation matrix.',
    )
    kappa = scipy.linalg.logm(u)

For a normal matrix, the real Schur form is block-diagonal with blocks of size 1 and 2. For a real logarithm to exist, the blocks of size 1 must either be positive or, if negative, come in equal pairs; the blocks of size 2 always have equal diagonal entries and off-diagonal entries of equal magnitude and opposite sign. Because this block-diagonal matrix is orthogonally similar to the orbital rotation matrix, the logarithm can be calculated block by block, choosing a real logarithm for each block, and then transformed back with the Schur vectors.
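As a quick numerical check of these block rules (a sketch, not code from the thread; the helper name log_rotation_block and the test values are illustrative), the logarithm of a 2x2 block [[a, b], [-b, a]] is taken from the complex number a + ib, and a pair of equal negative 1x1 blocks diag(-lam, -lam) gets log(lam) on the diagonal and plus/minus pi off it. scipy.linalg.expm recovers the original blocks in both cases.

```python
import numpy as np
import scipy.linalg


def log_rotation_block(a, b):
    """Real logarithm of the 2x2 block [[a, b], [-b, a]] (illustrative helper).

    The block behaves like the complex number a + ib, so its logarithm has
    log|a + ib| on the diagonal and arg(a + ib) on the off-diagonal.
    """
    z = np.log(complex(a, b))
    return np.array([[z.real, z.imag], [-z.imag, z.real]])


# 2x2 block with equal diagonal entries and opposite off-diagonal entries
r, theta = 0.8, 0.7
block = r * np.array([[np.cos(theta), np.sin(theta)],
                      [-np.sin(theta), np.cos(theta)]])
log_block = log_rotation_block(block[0, 0], block[0, 1])
print(np.allclose(scipy.linalg.expm(log_block), block))  # True

# pair of equal negative 1x1 blocks diag(-lam, -lam): log(lam) on the diagonal, +/- pi off it
lam = 2.0
log_pair = np.array([[np.log(lam), np.pi], [-np.pi, np.log(lam)]])
print(np.allclose(scipy.linalg.expm(log_pair), -lam * np.eye(2)))  # True
```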

From my initial tests, this appears to fix issues even when Jacobi sweeps are used to ensure that an extremum of the localization function has been found. I can create a PR if you agree that this should fix issues with differentiating calculations involving localized orbitals.

fishjojo commented 3 weeks ago

Thank you, Jonas. I'll take a look. I vaguely remember that scipy also uses a Schur decomposition to generate real solutions if possible, but I'll check.

jonas-greiner commented 1 week ago

As far as I can tell, the SciPy implementation only ensures that the principal logarithm is calculated if it exists (i.e., if there are no real negative eigenvalues). The orbital rotation matrix can have negative eigenvalues, but for an orthogonal matrix with determinant +1 these come in pairs, so a real matrix logarithm still exists; SciPy does not appear to enforce a real result in this case. The above implementation is based on this paper if you want to compare it to SciPy: morsyIJA1-4-2008.pdf
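A small illustration of this point, assuming the 2x2 rotation by pi (u = -I, orthogonal with determinant +1 and the eigenvalue -1 appearing twice): scipy.linalg.logm returns a complex logarithm, while the real matrix obtained from the pair rule is an equally valid, antisymmetric logarithm.

```python
import numpy as np
import scipy.linalg

# rotation by pi: orthogonal, det = +1, eigenvalue -1 with multiplicity two
u = -np.eye(2)

# SciPy's result is complex because u has real negative eigenvalues
kappa_scipy = scipy.linalg.logm(u)
print(np.iscomplexobj(kappa_scipy))  # True

# real, antisymmetric logarithm from the pair rule for negative 1x1 blocks
kappa_real = np.array([[0.0, np.pi], [-np.pi, 0.0]])
print(np.allclose(scipy.linalg.expm(kappa_real), u))  # True
print(np.allclose(kappa_real, -kappa_real.T))  # True
```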

fishjojo commented 1 week ago

Thanks, @jonas-greiner. Do you have an example where your code gives different results than scipy?

jonas-greiner commented 1 week ago

Here you go:

import numpy as np

import jax
from jax import jacrev
from jax import numpy as jnp

from pyscfad import gto, scf
from pyscfad.lo import pipek

mol = gto.Mole()
mol.atom = """
    H
    F 1 0.91
"""
mol.basis = 'aug-pcseg-1'
mol.build(trace_coords=False, trace_exp=False, trace_ctr_coeff=False)

mf = scf.RHF(mol)
mf.kernel()
ao_dip = mol.intor_symmetric('int1e_r', comp=3)
h1 = mf.get_hcore()

def apply_E(E):
    mf.get_hcore = lambda *args, **kwargs: h1 + jnp.einsum('x,xij->ij', E, ao_dip)
    mf.kernel()
    return mf.dip_moment(mol, mf.make_rdm1(), unit='AU', verbose=0)

E0 = np.zeros((3))
polar = jax.jacrev(apply_E)(E0)
print(polar)

# finite-difference polarizability (central differences, field step 1e-4 a.u.)
step = 1e-4
for i in range(3):
    dE = np.zeros(3)  # pass arrays rather than lists to the jnp.einsum call
    dE[i] = step
    e1 = apply_E(dE)
    e2 = apply_E(-dE)
    print((e1 - e2) / (2 * step))

def apply_E_loc(E):
    mf.get_hcore = lambda *args, **kwargs: h1 + jnp.einsum('x,xij->ij', E, ao_dip)
    mf.kernel()
    mo_occ = mf.mo_occ[mf.mo_occ>0]
    orbocc = mf.mo_coeff[:, mf.mo_occ>0]
    orbloc = pipek.pm(mol, orbocc, init_guess="atomic")
    dm = (orbloc*mo_occ).dot(orbloc.conj().T)
    return mf.dip_moment(mol, dm, unit='AU', verbose=0)

E0 = np.zeros((3))
polar = jacrev(apply_E_loc)(E0)
print(polar)

Currently, pyscfad crashes with the error 'Complex solutions are not supported for differentiating the Boys localization.' With the logm call replaced, the localized-orbital calculation reproduces the polarizabilities obtained from analytical and numerical differentiation of the dipole moment computed with canonical orbitals.

fishjojo commented 1 week ago

@jonas-greiner, would you be willing to submit a pull request for this? It can be a custom logm in pyscfad/_src/scipy/linalg.py, such as

import scipy.linalg

def logm(A, real=False, **kwargs):
    if real:
        # your_version: placeholder for the real-Schur-based logarithm above
        return your_version(A)
    else:
        return scipy.linalg.logm(A, **kwargs)

jonas-greiner commented 1 week ago

Not sure if that is what you wanted, since pyscfad/_src/scipy/linalg.py has otherwise been empty since your last PR.
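As a sketch of what such a file could contain (an assumption, not necessarily the code that ended up in pyscfad), the wrapper below follows the suggested logm(A, real=False, **kwargs) signature and the block rules from the first comment. The helper name _real_logm_schur is hypothetical, Python's warnings module stands in for pyscf's logger, and NumPy's default isclose tolerances are kept. The helper returns None when no real logarithm is found so that the wrapper decides on the fallback. Like the original snippet, it does not reorder the Schur blocks, so equal negative eigenvalues that are not adjacent in the Schur form are also treated as having no real logarithm.

```python
import warnings

import numpy as np
import scipy.linalg


def _real_logm_schur(u):
    """Real logarithm of a real normal matrix from its real Schur form.

    Hypothetical helper following the block rules described in this issue.
    Returns None if the Schur blocks, in the order returned, do not admit
    a real logarithm (e.g. an unpaired negative eigenvalue).
    """
    t, q = scipy.linalg.schur(u, output="real")  # u = q @ t @ q.T
    n = u.shape[0]
    idx = 0
    while idx < n:
        if idx == n - 1:
            # final 1x1 block: only a positive value has a real logarithm
            if t[idx, idx] <= 0.0:
                return None
            t[idx, idx] = np.log(t[idx, idx])
            idx += 1
            continue
        a, b = t[idx, idx], t[idx, idx + 1]
        c, d = t[idx + 1, idx], t[idx + 1, idx + 1]
        decoupled = np.isclose(b, 0.0) and np.isclose(c, 0.0)
        if decoupled and a > 0.0:
            # positive 1x1 block
            t[idx, idx] = np.log(a)
            idx += 1
        elif decoupled and a < 0.0 and np.isclose(a, d):
            # pair of equal negative 1x1 blocks: log(-a) on the diagonal, +/- pi off it
            log_lam = np.log(-a)
            t[idx:idx + 2, idx:idx + 2] = [[log_lam, np.pi], [-np.pi, log_lam]]
            idx += 2
        elif not decoupled and np.isclose(a, d) and np.isclose(c, -b):
            # 2x2 block [[a, b], [-b, a]]: use the logarithm of a + ib
            z = np.log(complex(a, b))
            t[idx:idx + 2, idx:idx + 2] = [[z.real, z.imag], [-z.imag, z.real]]
            idx += 2
        else:
            return None
    return q @ t @ q.T


def logm(A, real=False, **kwargs):
    """Matrix logarithm; with real=True, try the real-Schur route first."""
    if real:
        kappa = _real_logm_schur(np.asarray(A, dtype=float))
        if kappa is not None:
            return kappa
        warnings.warn("No real matrix logarithm found; using scipy.linalg.logm.")
    return scipy.linalg.logm(A, **kwargs)


# usage: a random proper rotation u = expm(kappa0) with antisymmetric kappa0
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
u = scipy.linalg.expm(0.5 * (x - x.T))
kappa = logm(u, real=True)
print(np.isrealobj(kappa), np.allclose(scipy.linalg.expm(kappa), u))
```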