Open ESEberhard opened 3 weeks ago
Hi guys,
I thought you might find our implementation of SCAN for EG-XC useful to integrate into the code base. I attached it below :)
""" This module contains the implementation of the SCAN (Strongly Constrained and Appropriately Normed Semilocal) meta-GGA functional by Sun et al. https://doi.org/10.1103/PhysRevLett.115.036402 https://journals.aps.org/prl/supplemental/10.1103/PhysRevLett.115.036402 The SCAN functional is a meta-GGA functional, which means that it depends on: r_s = (3 / (4 * pi * n))**(1 / 3) the spin polarization: xi = (n_up - n_down) / n the reduced gradient: s = |grad(n)| / (2 * (3 * pi**2)**(1 / 3) * n**(4 / 3)) """ import jax import jax.numpy as jnp from functools import partial from mldft.xc_energy.ueg_xc import e_x_ueg, ec_LSDA1_fn pi = jnp.pi def e_x_ueg(n: jax.Array) -> jax.Array: """ The exchange energy per particle of the uniform electron gas. """ return - (3 / 4) * (3 / pi)**(1 / 3) * n**(1 / 3) def ec_LSDA1_fn(n: jax.Array, xi: jax.Array, use_RPA=False, modified=False) -> jax.Array: """ The correlation energy of the uniform electron gas by Perdew and Wang (1992) (PW92). https://doi.org/10.1103/PhysRevB.45.13244, libxc reference implementation: https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/lda_c_pw.c https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/maple2c/lda_exc/lda_c_pw.c#L14 https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/lda_exc/lda_c_pw.mpl https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/util.mpl """ r_s = _calc_r_s(n) def analytic_base_from(p, A, a1, b1, b2, b3, b4) -> jax.Array: """ p and A are constrained the remaining parameters were fitted (see ref) """ beta_sum = b1 * r_s**(1 / 2) \ + b2 * r_s \ + b3 * r_s**(3 / 2) \ + b4 * r_s**(p + 1) log_term = jnp.log1p(1 / (2 * A * beta_sum)) return - 2 * A * (1 + a1 * r_s) * log_term A_unpolarized = 0.031091 if not modified else 0.0310907 A_polarized = 0.015545 if not modified else 0.01554535 A_alpha_c = 0.016887 if not modified else 0.0168869 if use_RPA: # random-phase approximation (RPA) ec_unpolarized =analytic_base_from(p=0.75, A=A_unpolarized, a1=0.082477, b1=5.1486, b2=1.6483, b3=0.2347, b4=0.20614) ec_polarized = analytic_base_from(p=0.75, A=A_polarized, a1=0.035374, b1=6.4869, b2=1.3083, b3=0.15180, b4=0.082349) alpha_c = - analytic_base_from(p=1, A=A_alpha_c, a1=0.028829, b1=10.357, b2=3.6231, b3=0.47990, b4=0.12279) else: ec_unpolarized =analytic_base_from(p=1, A=A_unpolarized, a1=0.21370, b1=7.5957, b2=3.5876, b3=1.6382, b4=0.49294) ec_polarized = analytic_base_from(p=1, A=A_polarized, a1=0.20548, b1=14.1189, b2=6.1977, b3=3.3662, b4=0.62517) alpha_c = - analytic_base_from(p=1, A=A_alpha_c, a1=0.11125, b1=10.357, b2=3.6231, b3=0.88026, b4=0.49671) dd_f_zero = 1.709921 if not modified else 1.709920934161365617563962776245 # alpha_c = dd_f_zero * (ec_polarized - ec_unpolarized) # from another reference f_xi = ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3) - 2) / (2**(4 / 3) - 2) return ec_unpolarized + alpha_c * f_xi / dd_f_zero * (1 - xi**4) \ + (ec_polarized - ec_unpolarized) * f_xi * xi**4 def _calc_r_s(n: jax.Array, epsilon: float = 0) -> jax.Array: """ The Wigner-Seitz radius """ return (3 / (4 * pi * n + epsilon))**(1 / 3) @jax.custom_jvp def _f_interp(alpha: jax.Array, c1: float, c2: float, d: float) -> jax.Array: """ TODO: add nondiff_argnums=(1,2,3) """ term1 = jnp.exp(-c1 * alpha / (1 - alpha)) term2 = -d * jnp.exp(c2 / (1 - alpha)) return jnp.where(alpha < 1, term1, term2) @_f_interp.defjvp def jvp_f_interp(primals, tangents): """ does not account for derivative w.r.t. constants c1, c2, d """ alpha, c1, c2, d = primals alpha_dot, _, _, _ = tangents df = _f_interp(alpha, c1, c2, d) dterm1_factor = -c1 / (1 - alpha)**2 dterm2_factor = c2 / (1 - alpha)**2 df_dot = jnp.where(alpha < 1, dterm1_factor, dterm2_factor) * df * alpha_dot return df, df_dot @jax.jit def e_x_scan(n: jax.Array, s:jax.Array, alpha: jax.Array) -> jax.Array: """ The exchange energy per particle of SCAN """ def F_x(s: jax.Array, alpha: jax.Array) -> jax.Array: """ The exchange enhancement factor """ # fit parameters k1 = 0.065 c1x = 0.667 c2x = 0.8 dx = 1.24 def h1x_fn(s: jax.Array, alpha: jax.Array) -> jax.Array: mu_ak = 10 / 81 b2 = jnp.sqrt(5913 / 405000) b1 = (511 / 13500) / (2 * b2) b3 = 0.5 b4 = mu_ak**2 / k1 - 1606 / 18225 - b1**2 exp1 = jnp.exp(- jnp.abs(b4) * s**2 / mu_ak) exp2 = jnp.exp(- b3 * (1 - alpha)**2) x = mu_ak * s**2 * (1 + (b4 * s**2 / mu_ak) * exp1) \ + (b1 * s**2 + b2 * (1 - alpha) * exp2)**2 return 1 + k1 - k1 / (1 + x / k1) a1 = 4.9479 h0x = 1.174 def gx(s: jax.Array) -> jax.Array: return - jnp.expm1(-a1 / jnp.sqrt(s)) h1x = h1x_fn(s, alpha) fx_alpha = _f_interp(alpha, c1x, c2x, dx) return (h1x + fx_alpha * (h0x - h1x)) * gx(s) return e_x_ueg(n) * F_x(s, alpha) @jax.jit # TODO: add static_argnames=("xi") ? def e_c_scan(n: jax.Array, s:jax.Array, xi: jax.Array, alpha: jax.Array) -> jax.Array: """ The correlation energy per particle of SCAN TODO: verify xi != 0 correctness if polarized systems are added https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/mgga_exc/mgga_c_scan.mpl """ # fit parameters c1c = 0.64 c2c = 1.5 dc = 0.7 #fixed constants b1c = 0.0285764 b2c = 0.0889 b3c = 0.125541 def Psi_fn(xi: jax.Array) -> jax.Array: return ((1 + xi) ** (2 / 3) + (1 - xi) ** (2 / 3)) / 2 def ec1_fn(n: jax.Array, s:jax.Array, xi: jax.Array) -> jax.Array: """ Perdew-Ernzerhof-Wang 1996 (PEW96)-like correlation energy https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_c_pbe.mpl """ ec_LSDA1 = ec_LSDA1_fn(n, xi) def H1(r_s: jax.Array, s:jax.Array, xi: jax.Array) -> jax.Array: # gamma = 0.031091 gamma = (1 - jnp.log(2)) / jnp.pi**2 # beta = 0.06672455060314922 beta = 0.066725 * (1 + 0.1 * r_s) / (1 + 0.1778 * r_s) # SCAN # beta = 0.066725 # PBE0 Psi = Psi_fn(xi) w1 = jnp.expm1(- ec_LSDA1 / (gamma * Psi**3)) A = beta / (gamma * w1) t = ((3 * pi**2 / 16) ** (1 / 3) * s) / (Psi * jnp.sqrt(r_s)) g = (1 + 4 * A * t**2)**(- 1 / 4) # g = 1 / (1 + A * t**2 + A**2 * t**4) # PBE0 return gamma * Psi**3 * jnp.log1p(w1 * (1 - g)) return ec_LSDA1 + H1(r_s, s, xi) def ec0_fn(r_s: jax.Array, s:jax.Array, xi: jax.Array) -> jax.Array: ec_LDA0 = - b1c / (1 + b2c * jnp.sqrt(r_s) + b3c * r_s) def dx(xi: jax.Array) -> jax.Array: return ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3)) / 2 def Gc(xi: jax.Array) -> jax.Array: return (1 - 2.3631 * (dx(xi) - 1)) * (1 - xi**12) w0 = jnp.expm1(- ec_LDA0 / b1c) chi_unpolarized = 0.12802585262625815 # 0.128026 g = 1 / (1 + 4 * chi_unpolarized * s**2)**(1 / 4) H0 = b1c * jnp.log1p(w0 * (1 - g)) return (ec_LDA0 + H0) * Gc(xi) r_s = _calc_r_s(n) ec1 = ec1_fn(n, s, xi) fc_alpha = _f_interp(alpha, c1c, c2c, dc) return ec1 + fc_alpha * (ec0_fn(r_s, s, xi) - ec1) @partial(jax.jit, static_argnames=("xi")) def e_xc_scan(n: jax.Array, s:jax.Array, xi: jax.Array, alpha: jax.Array) -> jax.Array: """ The exchange-correlation energy per particle of SCAN """ return e_x_scan(n, s, alpha) + e_c_scan(n, s, xi, alpha)
Thank you @ESEberhard, SCAN is a functional I have wanted to learn more about and integrate into MESS so this is a great help 🤗
Hi guys,
I thought you might find our implementation of SCAN for EG-XC useful to integrate into the code base. I attached it below :)