sail-sg / autofd

Automatic Functional Differentiation in JAX

Support for customized transformations #6

Open mavenlin opened 7 months ago

mavenlin commented 7 months ago

Background

In JAX, transformations can be customized for user-defined functions, for example by customizing the JVP rule. This can be achieved in two ways:

  1. Wrap your custom function as a primitive, and register JVP rules for the primitive.
  2. Use the @jax.custom_jvp decorator.

This is very useful: when you need a function with a customized gradient, for example giving a discrete function a continuous gradient relaxation, you can do it through this registration mechanism, and the upper-level constructs, e.g. jax.grad(f), stay the same no matter what the JVP of f is.
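For illustration, here is a minimal sketch of the second approach, giving a discrete round a straight-through gradient relaxation (round_ste is an illustrative name, not part of autofd):

import jax
import jax.numpy as jnp

# Discrete forward pass with a straight-through (identity) JVP.
@jax.custom_jvp
def round_ste(x):
  return jnp.round(x)

@round_ste.defjvp
def round_ste_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  # The primal output stays discrete; the tangent passes through as
  # if round_ste were the identity, i.e. a continuous relaxation.
  return jnp.round(x), x_dot

# Upper-level constructs work unchanged on top of the custom rule.
print(jax.grad(lambda x: round_ste(x) ** 2)(1.3))  # 2 * round(1.3) = 2.0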

Density Functional Theory

To implement DFT elegantly, we would like to keep the high-level program looking like the following:

import jax
import jax.numpy as jnp
import jax_xc

import autofd.operators as o
import dft.energy as e

def wave_ansatz(param, r):
  ...

def energy(crystal, param, occ):
  # close over param so that psi is a function of r only
  psi = o.partial(wave_ansatz, args=(param,), argnums=(0,))

  def psi_to_rho(psi):
    rho = o.compose(
      lambda v: jnp.sum(occ * jnp.real(jnp.conj(v) * v)), psi
    )
    return rho

  def energy_functional(psi):
    # use psi to build rho, in a functionally differentiable way.
    rho = psi_to_rho(psi)
    # compute the total energy using the functionals
    etot = (
      e.kinetic(psi) +
      e.hartree(rho) +
      e.external(rho, crystal) +
      jax_xc.energy.lda_x(rho)
    )
    return etot

  return energy_functional(psi)

# gradient descent to optimize crystal, param, occ, etc.
jax.grad(energy, argnums=(0, 1, 2))(crystal, param, occ)

The problem with the elegant code

One difficulty in achieving this elegance is efficiency. If we were willing to write ugly code, we could optimize all the above energy calculations by hand-deriving each term and implementing the derived formulas. For example, the kinetic energy in a plane-wave basis is simply a weighted summation over the G and k grids, because applying the Laplace operator to a plane wave only multiplies it by a constant factor, and because psi is normalized there is no need to compute <psi|psi>, which is always 1. However, none of these details are available to autofd, so following the above code yields an inefficient DFT implementation.
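For concreteness, a hand-optimized plane-wave kinetic energy might look like the sketch below (the function name and array shapes are illustrative assumptions, not autofd API):

import jax.numpy as jnp

def kinetic_planewave(coeff, g_plus_k, occ):
  # coeff: (nk, nbands, ng) normalized plane-wave coefficients
  # g_plus_k: (nk, ng, 3) reciprocal-space vectors G + k
  # occ: (nk, nbands) occupation numbers
  # -1/2 ∇² applied to e^{i(G+k)·r} only multiplies it by |G+k|²/2,
  # so the kinetic energy is just a weighted sum over the grids.
  ke = 0.5 * jnp.sum(g_plus_k ** 2, axis=-1)  # (nk, ng)
  dens = jnp.abs(coeff) ** 2                  # (nk, nbands, ng)
  return jnp.sum(occ[..., None] * dens * ke[:, None, :])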

Why do we want the elegance?

In physics we often write the math with great simplicity. For example, to get the energy levels, we simply solve the following eigenvalue problem:

$$ \left(-\frac{1}{2}\nabla^2 + \hat{V}_\text{eff}\right) \psi = \epsilon \psi $$

where

$$ \hat{V}_\text{eff} = \frac{\delta E[\rho]}{\delta\rho} $$

It takes many steps to turn this non-implementable math into implementable math: the $\delta$ symbol that represents the functional derivative is hand-derived, the integrals are discretized, and Fourier transforms are applied.
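As a concrete example of such a hand derivation, the Hartree term is a standard textbook case:

$$ E_H[\rho] = \frac{1}{2}\iint \frac{\rho(r)\,\rho(r')}{|r - r'|}\,dr\,dr' \quad\Rightarrow\quad \frac{\delta E_H[\rho]}{\delta\rho(r)} = \int \frac{\rho(r')}{|r - r'|}\,dr' $$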

Then why don't we just implement the derived formulas rather than following the functional form? It is not just for elegance, but also for extensibility. For example, one thing we often do is linearize the energy at the current value of rho and use it as an effective potential for computing energy bands of crystals. Re-deriving the math for constructing the Fock matrix and implementing it with FFTs already gives me a headache; if we could support the above syntax, we would have an easier way around it:

def band_structure(crystal, fixed_rho):

  def potential_energy(rho):
    return (
      e.hartree(rho) +
      e.external(rho, crystal) +
      jax_xc.energy.lda_x(rho)
    )

  # linearize the potential energy at the current density; the
  # linear part acts as the effective potential energy
  _, eff_energy = jax.linearize(potential_energy, fixed_rho)

  def energy_under_veff(param, k):
    # reuses wave_ansatz and psi_to_rho from above (the
    # k-dependence of the ansatz is elided in this sketch)
    psi = o.partial(wave_ansatz, args=(param,), argnums=(0,))
    return e.kinetic(psi) + eff_energy(psi_to_rho(psi))

  # param and k_vectors are assumed to be in scope
  bands = {}
  for k in k_vectors:
    bands[k] = jnp.linalg.eigh(jax.hessian(energy_under_veff)(param, k))[0]
  return bands
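For reference, a minimal self-contained sketch of the jax.linearize semantics relied on above:

import jax
import jax.numpy as jnp

# jax.linearize(f, x) returns f(x) together with the linear map
# v ↦ df(x)·v; above, that linear map plays the role of the
# effective potential energy.
f = lambda x: jnp.sum(x ** 3)
y, f_lin = jax.linearize(f, jnp.array([1.0, 2.0]))
print(y)                             # 1³ + 2³ = 9.0
print(f_lin(jnp.array([1.0, 0.0])))  # 3·1²·1 + 3·2²·0 = 3.0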

How to keep both elegance and efficiency?

A compiler is the way! Let's write elegant code and enjoy the cleanness of various DFT tasks, while relying on a compilation process to convert the code into a high-performance implementation. The optimizations in such a compiler require many rules here and there; therefore, we want autofd to support custom rules whenever the user knows a more efficient implementation.

Custom rules for any operator

In JAX, we can customize the rules for the transpose transformation, the JVP transformation, etc. A straightforward extension is to support customized rules for all operators.
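For reference, a minimal sketch of that existing mechanism, using an illustrative linear primitive scale_p (not part of autofd):

import jax
from jax import core
from jax.interpreters import ad

scale_p = core.Primitive("scale")
scale_p.def_impl(lambda x: 2.0 * x)
scale_p.def_abstract_eval(lambda x: x)  # output aval equals input aval

# scale is linear, so its JVP applies scale to the tangent...
ad.defjvp(scale_p, lambda t, x: scale_p.bind(t))
# ...and its transpose applies scale to the cotangent.
ad.primitive_transposes[scale_p] = lambda ct, x: (scale_p.bind(ct),)

def scale(x):
  return scale_p.bind(x)

print(jax.grad(scale)(3.0))  # 2.0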

Again taking kinetic energy as an example, say we have built a primitive for computing the kinetic energy of wave functions:

from jax import core

# A custom primitive for kinetic energy
kinetic_p = core.Primitive("kinetic")

def kinetic(psi):
  return kinetic_p.bind(psi)

def kinetic_impl(psi):
  ...  # general implementation

kinetic_p.def_impl(kinetic_impl)

We can then customize it for a specific wave function:

@autofd.function
def psi(r):
  ...

# proposed API: register a specialized kinetic rule for psi
@psi.def_operator_rule(kinetic_p)
def psi_kinetic_rule(r, **kwargs):
  ...
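Under this proposal, a call like the following would presumably dispatch to the specialized rule rather than the general implementation (hypothetical semantics, still to be designed):

# hypothetical: uses psi_kinetic_rule instead of kinetic_impl
e_kin = kinetic(psi)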

Mixing different rules

One difficult question is how we can customize many different rules, and how they interfere with each other. Can we retain the custom kinetic rule when we first apply the JVP rule to psi? I need to study this further before having an idea. (To be continued)

mavenlin commented 7 months ago

@sail-sg/ai-for-science