valence-labs / mess

MESS: Modern Electronic Structure Simulations
MIT License
20 stars 2 forks source link

Add SCAN functional #21

Open ESEberhard opened 3 weeks ago

ESEberhard commented 3 weeks ago

Hi guys,

I thought you might find our implementation of SCAN for EG-XC useful to integrate into the code base. I attached it below :)

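"""SCAN meta-GGA exchange-correlation functional.

This module contains the implementation of the SCAN (Strongly Constrained and
Appropriately Normed Semilocal) meta-GGA functional by Sun et al.
(Phys. Rev. Lett. 115, 036402 (2015)).

The SCAN functional is a meta-GGA functional, which means that it depends on
the density n, here through the Wigner-Seitz radius:
    r_s = (3 / (4 * pi * n))**(1 / 3)
the spin polarization:
    xi  = (n_up - n_down) / n
the reduced gradient:
    s = |grad(n)| / (2 * (3 * pi**2)**(1 / 3) * n**(4 / 3))
and, through the kinetic energy density tau, the iso-orbital indicator:
    alpha = (tau - tau_W) / tau_unif
with tau_W = |grad(n)|**2 / (8 * n) and
tau_unif = (3 / 10) * (3 * pi**2)**(2 / 3) * n**(5 / 3).
"""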
import jax
import jax.numpy as jnp
from functools import partial

from mldft.xc_energy.ueg_xc import e_x_ueg, ec_LSDA1_fn

# Short alias used by the formulas below (jnp.pi is a plain Python float).
pi = float(jnp.pi)

def e_x_ueg(n: jax.Array) -> jax.Array:
    """The exchange energy per particle of the uniform electron gas.

    Spin-unpolarized Dirac/Slater exchange in atomic units:
        e_x^unif(n) = -(3 / 4) * (3 / pi)**(1 / 3) * n**(1 / 3)

    NOTE(review): this definition shadows the ``e_x_ueg`` imported from
    ``mldft.xc_energy.ueg_xc`` at the top of the module.

    Args:
        n: electron density.

    Returns:
        The exchange energy per particle, elementwise in ``n``.
    """
    return - (3 / 4) * (3 / pi)**(1 / 3) * n**(1 / 3)

def ec_LSDA1_fn(n: jax.Array,
                xi: jax.Array,
                modified: bool = False,
                use_RPA: bool = False) -> jax.Array:
    """The correlation energy per particle of the uniform electron gas by
    Perdew and Wang (1992) (PW92).

    libxc reference implementation: ``lda_c_pw`` (functionals ``LDA_C_PW``,
    ``LDA_C_PW_MOD`` and ``LDA_C_PW_RPA``).

    NOTE(review): this definition shadows the ``ec_LSDA1_fn`` imported from
    ``mldft.xc_energy.ueg_xc`` at the top of the module.

    Args:
        n: electron density.
        xi: spin polarization (n_up - n_down) / n.
        modified: use the more precise values of A and f''(0) (as in libxc's
            ``LDA_C_PW_MOD``) instead of the rounded PW92 values.
        use_RPA: use the PW92 parameter set fitted to random-phase
            approximation (RPA) correlation energies instead of the standard
            PW92 fit.

    Returns:
        eps_c(r_s, xi) = eps_c(r_s, 0) + alpha_c(r_s) f(xi) / f''(0) (1 - xi**4)
                         + [eps_c(r_s, 1) - eps_c(r_s, 0)] f(xi) xi**4
    """
    r_s = _calc_r_s(n)

    def analytic_base_from(p, A, a1, b1, b2, b3, b4) -> jax.Array:
        """PW92 fit function G(r_s; p, A, a1, b1, b2, b3, b4).

        p and A are constrained; the remaining parameters were fitted (see ref).
        """
        beta_sum = b1 * r_s**(1 / 2) \
                 + b2 * r_s          \
                 + b3 * r_s**(3 / 2) \
                 + b4 * r_s**(p + 1)
        log_term = jnp.log1p(1 / (2 * A * beta_sum))
        return - 2 * A * (1 + a1 * r_s) * log_term

    A_unpolarized = 0.031091 if not modified else 0.0310907
    A_polarized   = 0.015545 if not modified else 0.01554535
    A_alpha_c     = 0.016887 if not modified else 0.0168869
    # In each parameter set the third fit is for -alpha_c (the spin
    # stiffness), hence the leading minus sign.
    if use_RPA:
        # random-phase approximation (RPA); b3 = 0.23647 as in PW92 Table I
        ec_unpolarized =analytic_base_from(p=0.75, A=A_unpolarized,
                            a1=0.082477, b1=5.1486, b2=1.6483, b3=0.23647, b4=0.20614)
        ec_polarized =  analytic_base_from(p=0.75, A=A_polarized,
                            a1=0.035374, b1=6.4869, b2=1.3083, b3=0.15180, b4=0.082349)
        alpha_c =     - analytic_base_from(p=1,    A=A_alpha_c,
                            a1=0.028829, b1=10.357, b2=3.6231, b3=0.47990, b4=0.12279)
    else:
        # standard PW92 parameter set
        ec_unpolarized =analytic_base_from(p=1, A=A_unpolarized,
                            a1=0.21370, b1=7.5957,  b2=3.5876, b3=1.6382,  b4=0.49294)
        ec_polarized =  analytic_base_from(p=1, A=A_polarized,
                            a1=0.20548, b1=14.1189, b2=6.1977, b3=3.3662,  b4=0.62517)
        alpha_c =     - analytic_base_from(p=1, A=A_alpha_c,
                            a1=0.11125, b1=10.357,  b2=3.6231, b3=0.88026, b4=0.49671)

    # f''(0) of the spin-interpolation function f(xi) below.
    dd_f_zero = 1.709921 if not modified else 1.709920934161365617563962776245
    # alpha_c = dd_f_zero * (ec_polarized - ec_unpolarized)  # from another reference
    f_xi = ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3) - 2) / (2**(4 / 3) - 2)

    return ec_unpolarized + alpha_c * f_xi / dd_f_zero * (1 - xi**4) \
          + (ec_polarized - ec_unpolarized) * f_xi * xi**4

def _calc_r_s(n: jax.Array, epsilon: float = 0) -> jax.Array:
    """The Wigner-Seitz radius r_s = (3 / (4 * pi * n))**(1 / 3).

    Args:
        n: electron density.
        epsilon: added to ``4 * pi * n`` in the denominator; presumably a
            regulariser against division by zero at n = 0 (default: none).

    Returns:
        r_s, elementwise in ``n``.
    """
    return (3 / (4 * pi * n + epsilon))**(1 / 3)

def _f_interp(alpha: jax.Array, c1: float, c2: float, d: float) -> jax.Array:
    TODO: add nondiff_argnums=(1,2,3)
    term1 = jnp.exp(-c1 * alpha / (1 - alpha))
    term2 = -d * jnp.exp(c2 / (1 - alpha))
    return jnp.where(alpha < 1, term1, term2)

def jvp_f_interp(primals, tangents):
    """Hand-written JVP rule for ``_f_interp`` with respect to ``alpha``.

    Intended for use with ``jax.custom_jvp``; it is not registered in this
    module. Derivatives with respect to the constants c1, c2 and d are not
    accounted for, and their tangents are ignored.

    Args:
        primals: tuple ``(alpha, c1, c2, d)``.
        tangents: tuple of four tangents matching ``primals``; only the first
            (``alpha_dot``) is used.

    Returns:
        ``(f, f_dot)`` with ``f = _f_interp(alpha, c1, c2, d)`` and
        ``f_dot = df/dalpha * alpha_dot``.
    """
    alpha, c1, c2, d = primals
    alpha_dot, _, _, _ = tangents
    f = _f_interp(alpha, c1, c2, d)
    # d/dalpha of the exponent: -c1 / (1 - alpha)**2 below 1 and
    # c2 / (1 - alpha)**2 above 1. At alpha == 1 both one-sided derivatives
    # vanish, because f -> 0 faster than any power. Use a dummy denominator
    # there instead of dividing by zero (which gave 0 * inf = NaN).
    at_one = alpha == 1
    one_minus_alpha = jnp.where(at_one, 1.0, 1 - alpha)
    factor = jnp.where(alpha < 1, -c1, c2) / one_minus_alpha**2
    f_dot = jnp.where(at_one, 0.0, factor * f * alpha_dot)
    return f, f_dot

def e_x_scan(n: jax.Array, s: jax.Array, alpha: jax.Array) -> jax.Array:
    """The exchange energy per particle of SCAN.

    e_x = e_x_ueg(n) * F_x(s, alpha), with F_x the SCAN exchange enhancement
    factor. No spin polarization is taken as input, so no spin scaling is
    applied here.

    Args:
        n: electron density.
        s: reduced density gradient.
        alpha: iso-orbital indicator.

    Returns:
        The exchange energy per particle, elementwise in the inputs.
    """

    def F_x(s: jax.Array, alpha: jax.Array) -> jax.Array:
        """The exchange enhancement factor F_x(s, alpha)."""
        # fit parameters
        k1 = 0.065
        c1x = 0.667
        c2x = 0.8
        dx = 1.24

        def h1x_fn(s: jax.Array, alpha: jax.Array) -> jax.Array:
            """Enhancement factor h_1^x used around alpha = 1."""
            mu_ak = 10 / 81
            b2 = jnp.sqrt(5913 / 405000)
            b1 = (511 / 13500) / (2 * b2)
            b3 = 0.5
            b4 = mu_ak**2 / k1 - 1606 / 18225 - b1**2

            exp1 = jnp.exp(- jnp.abs(b4) * s**2 / mu_ak)
            exp2 = jnp.exp(- b3 * (1 - alpha)**2)
            x = mu_ak * s**2 * (1 + (b4 * s**2 / mu_ak) * exp1) \
                + (b1 * s**2 + b2 * (1 - alpha) * exp2)**2
            return 1 + k1 - k1 / (1 + x / k1)

        a1 = 4.9479
        h0x = 1.174

        def gx(s: jax.Array) -> jax.Array:
            """g_x(s) = 1 - exp(-a1 / sqrt(s)), with the limit g_x(0) = 1."""
            # Evaluating -a1 / sqrt(0) gives the right value but a NaN
            # gradient (0 * inf). Feed a harmless s to the masked branch.
            positive = s > 0
            s_safe = jnp.where(positive, s, 1.0)
            return jnp.where(positive, - jnp.expm1(-a1 / jnp.sqrt(s_safe)), 1.0)

        h1x = h1x_fn(s, alpha)
        fx_alpha = _f_interp(alpha, c1x, c2x, dx)
        return (h1x + fx_alpha * (h0x - h1x)) * gx(s)

    return e_x_ueg(n) * F_x(s, alpha)

# xi stays a traced argument: a jax.Array is unhashable, so it cannot be static.
@jax.jit
def e_c_scan(n: jax.Array,
             s: jax.Array,
             xi: jax.Array,
             alpha: jax.Array) -> jax.Array:
    """The correlation energy per particle of SCAN.

    e_c = e_c^1 + f_c(alpha) * (e_c^0 - e_c^1), where e_c^1 is the PBE-like
    (alpha = 1) correlation and e_c^0 the alpha = 0 correlation.

    TODO: verify xi != 0 correctness if polarized systems are added

    Args:
        n: electron density.
        s: reduced density gradient.
        xi: spin polarization (n_up - n_down) / n.
        alpha: iso-orbital indicator.

    Returns:
        The correlation energy per particle, elementwise in the inputs.
    """
    # fit parameters
    c1c = 0.64
    c2c = 1.5
    dc = 0.7

    # fixed constants
    b1c = 0.0285764
    b2c = 0.0889
    b3c = 0.125541

    def Psi_fn(xi: jax.Array) -> jax.Array:
        """Spin-scaling factor ((1 + xi)**(2/3) + (1 - xi)**(2/3)) / 2."""
        return ((1 + xi) ** (2 / 3) + (1 - xi) ** (2 / 3)) / 2

    def ec1_fn(n: jax.Array, r_s: jax.Array, s: jax.Array,
               xi: jax.Array) -> jax.Array:
        """Perdew-Ernzerhof-Wang 1996 (PEW96)-like correlation energy."""
        ec_LSDA1 = ec_LSDA1_fn(n, xi)

        # gamma = (1 - ln 2) / pi**2 ~= 0.031091
        gamma = (1 - jnp.log(2)) / jnp.pi**2
        # SCAN's r_s-dependent beta; PBE instead uses the constant
        # beta = 0.066725 (0.06672455060314922 unrounded).
        beta = 0.066725 * (1 + 0.1 * r_s) / (1 + 0.1778 * r_s)
        Psi = Psi_fn(xi)
        w1 = jnp.expm1(- ec_LSDA1 / (gamma * Psi**3))
        A = beta / (gamma * w1)
        t = ((3 * pi**2 / 16) ** (1 / 3) * s) / (Psi * jnp.sqrt(r_s))
        # PBE is recovered with g = 1 / (1 + A * t**2 + A**2 * t**4).
        g = (1 + 4 * A * t**2)**(- 1 / 4)
        H1 = gamma * Psi**3 * jnp.log1p(w1 * (1 - g))
        return ec_LSDA1 + H1

    def ec0_fn(r_s: jax.Array, s: jax.Array, xi: jax.Array) -> jax.Array:
        """Correlation energy per particle in the alpha = 0 limit."""
        ec_LDA0 = - b1c / (1 + b2c * jnp.sqrt(r_s) + b3c * r_s)

        d_x = ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3)) / 2
        G_c = (1 - 2.3631 * (d_x - 1)) * (1 - xi**12)

        w0 = jnp.expm1(- ec_LDA0 / b1c)
        chi_unpolarized = 0.12802585262625815  # 0.128026
        g = 1 / (1 + 4 * chi_unpolarized * s**2)**(1 / 4)
        H0 = b1c * jnp.log1p(w0 * (1 - g))
        return (ec_LDA0 + H0) * G_c

    r_s = _calc_r_s(n)
    ec1 = ec1_fn(n, r_s, s, xi)
    ec0 = ec0_fn(r_s, s, xi)
    fc_alpha = _f_interp(alpha, c1c, c2c, dc)
    return ec1 + fc_alpha * (ec0 - ec1)

# xi is a jax.Array and therefore unhashable, so it cannot be a static
# argument of jax.jit (static_argnames would raise on array input).
@jax.jit
def e_xc_scan(n: jax.Array,
              s: jax.Array,
              xi: jax.Array,
              alpha: jax.Array) -> jax.Array:
    """The exchange-correlation energy per particle of SCAN.

    Sum of ``e_x_scan(n, s, alpha)`` and ``e_c_scan(n, s, xi, alpha)``.
    NOTE(review): the exchange part takes no spin polarization, so for
    xi != 0 it is presumably not spin-scaled; confirm before using polarized
    densities.

    Args:
        n: electron density.
        s: reduced density gradient.
        xi: spin polarization (n_up - n_down) / n.
        alpha: iso-orbital indicator.

    Returns:
        The exchange-correlation energy per particle, elementwise in the inputs.
    """
    return e_x_scan(n, s, alpha) + e_c_scan(n, s, xi, alpha)
hatemhelal commented 3 weeks ago

Thank you @ESEberhard! SCAN is a functional I have wanted to learn more about and integrate into MESS, so this is a great help 🤗