google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jacobians and hessians for an application involving complex numbers #603

Closed: proteneer closed this issue 5 years ago

proteneer commented 5 years ago

The imaginary components of the derivatives of the following complex-valued functions are always zero in jax. In contrast, tensorflow returns a non-zero imaginary component. I think this is a bug on the JAX side of things.

from jax.config import config; config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
import numpy as onp

import tensorflow as tf

zs = 0.5j * np.arange(5) + np.arange(5)

print("input", zs)

def fn(z):
    return np.cos(np.linalg.norm(z*2))

grad = jax.jacfwd(fn)
print("jax", fn(zs), grad(zs))

def tf_fn(z):
    return tf.cos(tf.norm(z*2))

tf_zs = tf.convert_to_tensor(0.5j * onp.arange(5) + onp.arange(5))
tf_res = tf_fn(tf_zs)

sess = tf.Session()

grad_ys = tf.ones_like(tf_res)
grad_op = tf.gradients(tf_res, tf_zs, grad_ys=grad_ys)
print("tf", sess.run([tf_res, grad_op, grad_ys]))
input [0.+0.j  1.+0.5j 2.+1.j  3.+1.5j 4.+2.j ]
jax 0.9495740004388323 [0.        +0.j 0.10240272+0.j 0.20480544+0.j 0.30720815+0.j
 0.40961087+0.j]
tf [(0.9495740004388323+0j), [array([0.        +0.j        , 0.10240272+0.05120136j,
       0.20480544+0.10240272j, 0.30720815+0.15360408j,
       0.40961087+0.20480544j])], (1+0j)]
mattjj commented 5 years ago

Did you review the explanation of our complex number differentiation convention at the bottom of our Autodiff Cookbook? See also the description of Autograd's complex number differentiation convention.

I wonder if this could be just a different convention. What do you think?

mattjj commented 5 years ago

Caveat: I didn't look at your code! I just wanted to flag the explanation of our complex differentiation. It could be a bug, in which case we'll fix it!

mattjj commented 5 years ago

I added Autograd to your script; JAX should match Autograd's behavior (Autograd is more mature because we've worked on it for longer):

import autograd
import autograd.numpy as np
import numpy as onp

zs = 0.5j * np.arange(5) + np.arange(5)

print("input", zs)

def fn(z):
    return np.cos(np.linalg.norm(z*2))

grad = autograd.jacobian(fn)
print("autograd", fn(zs), grad(zs))
('input', array([0.+0.j , 1.+0.5j, 2.+1.j , 3.+1.5j, 4.+2.j ]))
('autograd', 0.9495740004388323, array([0.        +0.j        , 0.10240272+0.05120136j,
       0.20480544+0.10240272j, 0.30720815+0.15360408j,
       0.40961087+0.20480544j]))

So definitely a bug!

mattjj commented 5 years ago

Looks like it's an issue with jax.jacfwd and not jax.grad:

from __future__ import print_function

import jax
import autograd
import autograd.numpy as anp
import jax.numpy as jnp
import numpy as onp

zs = 0.5j * onp.arange(5) + onp.arange(5)
print("input", zs)

def autograd_fn(z):
    return anp.cos(anp.linalg.norm(z*2))
print("autograd", autograd_fn(zs), autograd.grad(autograd_fn)(zs))

def jax_fn(z):
    return jnp.cos(jnp.linalg.norm(z*2))
print("jax", jax_fn(zs), jax.grad(jax_fn)(zs))
input [0.+0.j  1.+0.5j 2.+1.j  3.+1.5j 4.+2.j ]
autograd 0.9495740004388323 [0.        +0.j         0.10240272+0.05120136j 0.20480544+0.10240272j
 0.30720815+0.15360408j 0.40961087+0.20480544j]
jax/lib/xla_bridge.py:144: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
jax 0.94957405 [0.        +0.j         0.10240266-0.05120133j 0.20480531-0.10240266j
 0.30720797-0.15360399j 0.40961063-0.20480531j]

EDIT: Actually, jax.grad and autograd.grad differ by a conjugation here, but at least it's not zeroing things out.

mattjj commented 5 years ago

We think it's because jacfwd (and jacrev) are always being applied to a standard basis of tangent vectors with real rather than complex dtype. We need to create that onp.eye with a dtype that is consistent with the dtype of the input to the function being differentiated. We probably meant to revisit that line... lazy on our parts!
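
For concreteness, here is a minimal sketch of how a jacfwd-style Jacobian pushes a standard basis through jvp, and why the basis dtype matters. This is not JAX's implementation (my_jacfwd is a hypothetical helper), and building the basis with the input's dtype is enough for the holomorphic example below, though as the later comments explain it isn't the whole story for non-holomorphic functions:

import jax
import jax.numpy as jnp

def my_jacfwd(f):
    # Hypothetical helper, not a JAX API: push a standard basis of tangent
    # vectors through jvp and stack the results.
    def jacfun(x):
        # Build the basis with the *input's* dtype; a real onp.eye here is
        # exactly what drops the imaginary components for complex inputs.
        basis = jnp.eye(x.size, dtype=x.dtype)
        pushfwd = lambda v: jax.jvp(f, (x,), (v,))[1]
        # Rows index input components (transpose for an out-by-in layout).
        return jax.vmap(pushfwd)(basis)
    return jacfun

zs = 0.5j * jnp.arange(5.0) + jnp.arange(5.0)
print(my_jacfwd(lambda z: z ** 2)(zs))  # holomorphic example: diag(2*z)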

proteneer commented 5 years ago

Thanks for looking into this. I did get a chance to look at the caveats re: holomorphic functions, but as you've already discovered, that wasn't quite the issue. Just to summarize:

- tensorflow and autograd agree with each other
- autograd and jax.grad differ by conjugation
- nothing agrees with jacfwd

Thanks for the fast responses!

mattjj commented 5 years ago

The quick fix I tried wasn't right, so we'll merge #604 as a stopgap so at least this isn't a silent failure. It raises errors if you try to use jacfwd / jacrev with complex types.

mattjj commented 5 years ago

I just started #610 which might close this issue. It doesn't make the code in the original post work with jacfwd; instead, that raises an error! But it does squash a bug to make the above code work with jacrev (instead of jacfwd).

Why does it work with jacrev and not jacfwd? My current best explanation is that the complex differentiation convention we use works nicely for this example in reverse mode because of the transposition (adjoint operation) involved, but not so easily in forward mode. In reverse mode, we linearize a C -> R function (like this one) and transpose the result to ultimately produce a linear R -> C function, which we can use to pull back a real-valued standard cotangent basis and produce a complex cotangent result. But in forward mode, we'd have a linear C -> R function, yet we'd want to push forward a whole complex standard tangent basis to get a complex tangent result. The types don't seem to work out.
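
As a tiny illustration of that reverse-mode story (a sketch using the function from the original post, not the code in #610): pulling back a real-valued seed through vjp still yields a complex cotangent.

import jax
import jax.numpy as jnp

def f(z):  # C -> R
    return jnp.cos(jnp.linalg.norm(2 * z))

zs = 0.5j * jnp.arange(5.0) + jnp.arange(5.0)
val, f_vjp = jax.vjp(f, zs)
cotangent, = f_vjp(jnp.ones_like(val))  # real seed in...
print(cotangent)                        # ...complex cotangent out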

We might be able to implement our R^2 -> R convention more directly to make jacfwd work, but we'd need to rewrite primitives like np.real and np.imag to be "first component" and "second component", respectively. It's feasible to do that with our transformation machinery, but it'd take some work and we'd want to have clear use cases in mind. That proved unnecessary.

Any thoughts?

proteneer commented 5 years ago

I don't understand the internals of JAX well enough at this point to comment on the technical difficulties of implementing jacfwd. Unfortunately I do need jacfwd for the following reasons:

1) Computing the Hessian of the Ewald energy (E) of periodic systems (http://micro.stanford.edu/mediawiki/images/4/46/Ewald_notes.pdf). The JAX tutorial recommends composing jacfwd and jacrev to achieve this (a small sketch of that composition follows below).

2) In the case where I already have an expression for the forces (dE/dx) and am implementing second-order optimization, I necessarily need jacfwd to avoid computing the intermediate values required by jacrev. For dynamical systems (unlike quasi-Newton-based approaches), the number of steps required is extremely large (>1,000,000), so storing intermediate values becomes highly impractical.

Note that the energy functions used fall strictly into the 3rd use case of the JAX convention: E maps from R->R but goes through intermediate holomorphic R->C and C->R functions.
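
A small sketch of the composition mentioned in 1), using a hypothetical stand-in energy (not the Ewald energy) just to show the shapes:

import jax
import jax.numpy as jnp

def energy(x):  # stand-in for E: R^(N,3) -> R
    return jnp.sum(jnp.cos(jnp.linalg.norm(x, axis=-1)))

x = jnp.ones((5, 3))
hess = jax.jacfwd(jax.jacrev(energy))(x)  # forward-over-reverse Hessian
print(hess.shape)                         # (5, 3, 5, 3)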

mattjj commented 5 years ago

Wow, cool use case! We'll make JAX work for this :)

I discussed with @dougalm and got some new ideas. We think jacrev will naturally handle C->R functions, and jacfwd will naturally handle R->C functions. Either can be used to differentiate holomorphic C->C functions (because in that case we only need to pull back or push forward a real standard basis, as that will reveal all the derivative information), and by putting these together with vmap we can get jacfwd to work with C->R functions, jacrev to work with R->C functions, and can even get either to differentiate non-holomorphic C->C functions, though we'll need to explain how. I'm going to revise #610 in accordance with this plan, and then probably explain how to put them together to accomplish the things in the previous sentence. (EDIT: revised this paragraph a bit.)

We might need to unpack a bit more about your use case to work out the right way to compose these things. Can you make a toy example that shows the kind of Hessian you need to compute?

mattjj commented 5 years ago

Hopefully this comment unpacks things a bit. If you provide a model use case, I'd be happy to help think through how to compose these pieces together to build exactly what you want.

mattjj commented 5 years ago

I switched the label on this issue because #610 fixed the bug, but we're still discussing how to compose the pieces for this application.

proteneer commented 5 years ago

Thanks for the detailed explanation and quick response. I have some reference tensorflow-based Ewald code with fully working Hessians that I can try to port over to jax (I don't know autograd or jax well enough yet). Then we can compare the two Hessian implementations as a first step (assuming you're okay with tensorflow being used as the reference).

mattjj commented 5 years ago

Sure, sounds great! We're familiar with TF :)

proteneer commented 5 years ago

See the script below for a realistic failing use case, using the following pip versions:

jax==0.1.22 jaxlib==0.1.12

If you guys feel like master is ready to go, I can switch to master and see if the results are improved.

import time
import numpy as onp
from jax.config import config; config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
import tensorflow as tf

BOLTZMANN = 1.380658e-23
AVOGADRO = 6.0221367e23
RGAS = BOLTZMANN*AVOGADRO
BOLTZ = RGAS/1000
ONE_4PI_EPS0 = 138.935456
VIBRATIONAL_CONSTANT = 1302.79 # http://openmopac.net/manual/Hessian_Matrix.html

class EwaldEnergy():

    def __init__(self, kmax, charges, alphaEwald, box):
        self.kmax = kmax
        self.charges = charges
        self.alphaEwald = alphaEwald
        self.box = box

        self.recipBoxSize = (2*np.pi)/box

        self.mg = []
        lowry = 0
        lowrz = 1

        numRx, numRy, numRz = self.kmax, self.kmax, self.kmax

        for rx in range(numRx):
            for ry in range(lowry, numRy):
                for rz in range(lowrz, numRz):
                    self.mg.append((rx, ry, rz))
                    lowrz = 1 - numRz
                lowry = 1 - numRy

        self.mg = onp.array(self.mg)

    def jax_reciprocal_energy(self, conf):

        # lattice vectors
        ki = np.expand_dims(self.recipBoxSize, axis=0) * self.mg # [nk, 3]
        ri = np.expand_dims(conf, axis=0) # [1, N, 3]
        rik = np.sum(np.multiply(ri, np.expand_dims(ki, axis=1)), axis=-1) # [nk, N]
        real = np.cos(rik)
        imag = np.sin(rik)
        # eikr = np.complex(real, imag) # [nk, N]
        eikr = real + 1j*imag
        # qi = np.complex(self.charges, np.float64(0.0))
        qi = self.charges + 1j*0
        Sk = np.sum(qi*eikr, axis=-1)  # [nk]
        n2Sk = np.power(np.abs(Sk), 2)
        k2 = np.sum(np.multiply(ki, ki), axis=-1) # [nk]
        factorEwald = -1/(4*self.alphaEwald*self.alphaEwald)
        ak = np.exp(k2*factorEwald)/k2 # [nk]
        nrg = np.sum(ak * n2Sk)
        recipCoeff = (ONE_4PI_EPS0*4*np.pi)/(self.box[0]*self.box[1]*self.box[2])

        return recipCoeff * nrg

    def tf_reciprocal_energy(self, conf):

        # lattice vectors
        ki = tf.expand_dims(self.recipBoxSize, axis=0) * self.mg # [nk, 3]
        ri = tf.expand_dims(conf, axis=0) # [1, N, 3]
        rik = tf.reduce_sum(tf.multiply(ri, tf.expand_dims(ki, axis=1)), axis=-1) # [nk, N]
        real = tf.cos(rik)
        imag = tf.sin(rik)
        eikr = tf.complex(real, imag) # [nk, N]
        qi = tf.complex(self.charges, np.float64(0.0))
        Sk = tf.reduce_sum(qi*eikr, axis=-1)  # [nk]
        n2Sk = tf.pow(tf.abs(Sk), 2)
        k2 = tf.reduce_sum(tf.multiply(ki, ki), axis=-1) # [nk]
        factorEwald = -1/(4*self.alphaEwald*self.alphaEwald)
        ak = tf.exp(k2*factorEwald)/k2 # [nk]
        nrg = tf.reduce_sum(ak * n2Sk)
        recipCoeff = (ONE_4PI_EPS0*4*np.pi)/(self.box[0]*self.box[1]*self.box[2])

        return recipCoeff * nrg

if __name__ == "__main__":

    charges = onp.array([
        0.1,
        -0.1,
        0.3,
        0.15,
        -0.4
    ], dtype=np.float64)

    ee = EwaldEnergy(
        kmax=4, 
        charges=charges, 
        alphaEwald=1.0,
        box=onp.array([4.0, 4.0, 4.0], dtype=np.float64))

    x0 = onp.array([
        [ 0.0637,   0.0126,   0.2203],
        [ 1.0573,  -0.2011,   1.2864],
        [ 2.3928,   1.2209,  -0.2230],
        [-0.6891,   1.6983,   0.0780],
        [-0.6312,  -1.6261,  -0.2601]
    ], dtype=np.float64)

    xt = tf.convert_to_tensor(x0)

    nrg_op = ee.tf_reciprocal_energy(xt)
    grad_op = tf.gradients(nrg_op, xt)[0]
    hess_op = tf.hessians(nrg_op, xt)

    sess = tf.Session()
    nrg_tf = sess.run([nrg_op])

    nrg_jax = ee.jax_reciprocal_energy(x0)
    onp.testing.assert_almost_equal(nrg_tf, nrg_jax)

    grad_tf = sess.run([grad_op])[0]

    grad_jax_rev_fn = jax.jacrev(ee.jax_reciprocal_energy)
    grad_jax_rev = grad_jax_rev_fn(x0)

    # grad_rev passes
    onp.testing.assert_almost_equal(grad_tf, grad_jax_rev)

    grad_jax_fwd_fn = jax.jacfwd(ee.jax_reciprocal_energy)
    grad_jax_fwd = grad_jax_fwd_fn(x0)

    # grad_fwd passes
    onp.testing.assert_almost_equal(grad_tf, grad_jax_fwd)

    hess_jax_fwd_rev_fn = jax.jacfwd(jax.jacrev(ee.jax_reciprocal_energy))
    hess_jax_fwd_rev = hess_jax_fwd_rev_fn(x0)

    hess_tf = sess.run([hess_op])[0][0]

    # hessian fails
    onp.testing.assert_almost_equal(hess_tf, hess_jax_fwd_rev)

Failure:

AssertionError: 
Arrays are not almost equal to 7 decimals

(mismatch 100.0%)
 x: array([ 3.4778180e-01, -7.8731717e-01, -2.6087100e-01,  1.6491354e-01,
       -1.0239034e-01,  1.6205996e-01, -3.6835986e-01,  4.1750644e-02,
       -2.6607491e-01,  1.3209316e-01,  1.9315429e-01,  2.4039437e-02,...
 y: array([ 9.4108600e-01, -7.4349972e-01, -3.3325226e-01,  6.7568767e-02,
       -5.6032348e-02,  2.6268124e-01, -7.7064368e-01,  2.4751300e-01,
       -1.0244276e-01,  1.0012433e-01,  1.5247021e-01, -2.4254655e-02,...
proteneer commented 5 years ago

To add a little more: interestingly, both the first-order jacrev and jacfwd gradients pass without issue; it's the second-order composition that fails.

proteneer commented 5 years ago

One more thing to add: there's a non-holomorphic function used: tf.abs(Sk)

An alternative analytic definition can be found here as part of the complex step derivative trick:

http://mdolab.engin.umich.edu/sites/default/files/Martins2003CSD.pdf (equation 13)
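
For reference, a quick hedged sketch of the complex-step trick from that paper, on a hypothetical analytically implemented function (not from this thread). The key requirement is that non-analytic pieces like abs be replaced by analytic definitions such as the paper's equation 13:

import numpy as onp

f = lambda x: onp.cos(2.0 * x) * onp.exp(x)  # analytic through complex inputs

x, h = 1.3, 1e-20
print(onp.imag(f(x + 1j * h)) / h)  # complex-step estimate of f'(x)
print(-2.0 * onp.sin(2.0 * x) * onp.exp(x) + onp.cos(2.0 * x) * onp.exp(x))  # analytic f'(x)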

mattjj commented 5 years ago

Thanks so much for this test case!

EDIT: I should mention, I got the same numerical behavior as in your preceding messages when I checked on master. That's what led to the further investigation below.

I added this code to the bottom of your script to do a quick check of the respective local quadratic approximations in a random direction:

import matplotlib.pyplot as plt
from jax import vmap

v = onp.random.RandomState(0).randn(*x0.shape)
v = v / onp.sqrt(onp.sum(v**2))
t = onp.linspace(-1, 1, 1000)
displacements = t[:, None, None] * v

def quad_approx(grad, hess):
  return (ee.jax_reciprocal_energy(x0) +
          onp.einsum('ij,nij->n', grad, displacements) +
          0.5 * onp.einsum('ijkl,nij,nkl->n', hess,
                            displacements, displacements))

plt.plot(t, vmap(ee.jax_reciprocal_energy)(x0 + displacements), 'k-')
jax_approx = quad_approx(jax.grad(ee.jax_reciprocal_energy)(x0),
                          jax.hessian(ee.jax_reciprocal_energy)(x0))
plt.plot(t, jax_approx, 'b--', label='jax')
tf_approx = quad_approx(grad_tf, hess_tf)
plt.plot(t, tf_approx, 'r--', label='tf')
plt.legend(loc='best')
plt.savefig('approx.png')

EDIT: revised code to factor out the quad_approx function, making it clearer the approximations were formed the same way.

[plot approx.png: the energy along the random direction (black) with the jax (blue dashed) and tf (red dashed) quadratic approximations overlaid]

(Please double-check and run this code yourself to check that it's sensible and to see if you can reproduce. This is against the JAX master branch.)

That plot makes me think that, while JAX and TF disagree on the Hessian, the JAX Hessian looks like it's providing a more accurate local approximation to the function.

I tried a few more random directions by incrementing the random seed:

[plots approx1, approx2: the same comparison for two more random seeds]

I also checked that our underlying VJP and JVP definitions are working correctly by running this code, which tests against numerical differences in random directions (but which doesn't check jacfwd and jacrev wrappers, just the more basic vjp and jvp functions that underlie them):

from jax.test_util import check_grads
check_grads(ee.jax_reciprocal_energy, (x0,), order=2)

Contingent on you checking over the above code and reproducing the results, this suggests to me that we should do some more numerical checks for accuracy (gradients and Hessians are easy to check against finite differences!), but not necessarily assume that the TF calculation is accurate.
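
For example, here is a hedged sketch of that finite-difference check, appended to the script above (it reuses ee and x0): compare a finite difference of the gradient along a random direction v against the Hessian applied to v.

import jax
import numpy as onp

eps = 1e-6
v = onp.random.RandomState(0).randn(*x0.shape)
g = jax.grad(ee.jax_reciprocal_energy)
hess = jax.hessian(ee.jax_reciprocal_energy)(x0)
fd = (g(x0 + eps * v) - g(x0)) / eps     # finite-difference change in the gradient
hv = onp.einsum('ijkl,kl->ij', hess, v)  # Hessian applied to the direction v
print(onp.max(onp.abs(fd - hv)))         # should be small if the Hessian is accurate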

What do you think?

mattjj commented 5 years ago

Notice that this is comparing against the ee.jax_reciprocal_energy function. It could be that the two energy function implementations differ; I didn't check that.

proteneer commented 5 years ago

The two implementations are identical; otherwise the energies and forces wouldn't agree. I've been able to reproduce the first figure locally as well. autograd also seems to agree with jax.

It also seems that replacing the non-analytic abs function does make jax and tensorflow agree:

...
        n2Sk = np.power(np.real(Sk), 2)
...
        n2Sk = tf.pow(tf.real(Sk), 2)

At this point I'm lost. I'm skeptical that this is a purely numerical issue, since everything is done in 64-bit precision. I'm pretty sure we can isolate this down to a much simpler repro using abs and some basic trig functions.

mattjj commented 5 years ago

Could it be a bug with the TF hessian? It doesn’t look accurate in those plots.

proteneer commented 5 years ago

Looks like it. A much simpler repro comparing against autograd:

import autograd
import autograd.numpy as anp
import numpy as onp
import tensorflow as tf

zs = 2 + 0.5j  

print("input", zs)

def fn(z):
    return anp.abs(z)

grad = autograd.jacobian(fn)
print("ag", fn(zs), grad(zs))

def tf_fn(z):
    return tf.abs(z)

tf_zs = tf.convert_to_tensor(2 + 0.5j)
tf_res = tf_fn(tf_zs)

sess = tf.Session()

grad_ys = tf.ones_like(tf_res)
grad_op = tf.gradients(tf_res, tf_zs, grad_ys=grad_ys)
print("tf", sess.run([tf_res, grad_op, grad_ys]))

Results differ again, by conjugation:

input (2+0.5j)
ag 2.0615528128088303 (0.9701425001453319-0.24253562503633297j)
tf [2.0615528128088303, [(0.9701425001453319+0.24253562503633297j)], 1.0]

Edit: Note that this reveals differences even in the first order derivative.

mattjj commented 5 years ago

Differing by conjugation in the first-order derivatives may be a separate issue, and it could be that the convention used by Autograd and JAX just differs there. The difference in the Hessian seems like a much clearer bug, since it just doesn't agree with the graphical numerical checks above.

We should make a minimal numerical repro for the Hessian issue, purely in TF, which would just look like evaluating the TF gradient at x and at x + eps * v and taking the difference, then comparing to the Hessian applied to eps * v. I'm a bit rusty with my TF; want to take a stab at that?

proteneer commented 5 years ago

Confirmed bug in tensorflow:

import numpy as onp
import autograd as ag
import autograd.numpy as anp
import tensorflow as tf

inp = anp.array(2.0)

print("input", inp)

def ag_fn(x):
    real = anp.cos(x+2)
    imag = anp.sin(x-1)
    return anp.abs(real+1j*imag)

ag_hess = ag.hessian(ag_fn)

print("ag val:", ag_fn(inp))
print("ag hess:", ag_hess(inp))

def tf_fn(x):
    real = tf.cos(x+2)
    imag = tf.sin(x-1)
    return tf.abs(tf.complex(real, imag))

# tf_inp = tf.convert_to_tensor(inp)
tf_inp = tf.placeholder(shape=tuple(), dtype=onp.float64)

out_op = tf_fn(tf_inp)

tf_grad = tf.gradients(out_op, tf_inp)[0]
tf_hess = tf.hessians(out_op, tf_inp)[0]

sess = tf.Session()
delta = 1e-7

_, d0, tf_ad = sess.run([out_op, tf_grad, tf_hess], feed_dict={tf_inp: inp})
_, d1, _ = sess.run([out_op, tf_grad, tf_hess], feed_dict={tf_inp: inp+delta})

print("tf_numerical derivative:", (d1-d0)/delta)
print("tf_autodiff derivative:", tf_ad)
input 2.0
ag val: 1.0655155566059393
ag hess: -0.25533014019223726
2019-04-14 22:55:43.481283: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
tf_numerical derivative: -0.25533013481293665
tf_autodiff derivative: -1.0655155566059389
proteneer commented 5 years ago

At least for the second-order case, this doesn't seem to be a problem on your end. So it seems like this has actually triggered a wholly independent bug from what I had initially reported.

mattjj commented 5 years ago

Should we close this issue now, or is there more to figure out here?

proteneer commented 5 years ago

After discussing with @dougalm earlier this week, I think that for my use case of a function F: R->R that goes through a set of non-holomorphic intermediates, the current solution suffices. If I understand correctly, jacfwd still needs to be fixed for a certain subset of functions, since your pull request only fixed jacrev. We can open a separate issue for that if desired and close this one.

proteneer commented 5 years ago

Again, thanks for all the help on this issue!

mattjj commented 5 years ago

Happy to help! This is exactly the sort of sophisticated use case we want JAX to support.

Actually jacfwd and jacrev are both in a good place now, and we don't have any planned revisions for them. They're pleasingly symmetric: jacfwd works for R->C or holomorphic C->C, and jacrev works with C->R or holomorphic C->C. In other cases they raise errors (rather than silently failing).
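
A quick hedged illustration of that symmetry (behavior as described in this thread; exact errors and flags may differ across JAX versions):

import jax
import jax.numpy as jnp

r_to_c = lambda x: jnp.exp(1j * x)                  # R -> C
c_to_r = lambda z: jnp.cos(jnp.linalg.norm(2 * z))  # C -> R

x = jnp.arange(3.0)
z = 0.5j * jnp.arange(3.0) + jnp.arange(3.0)

print(jax.jacfwd(r_to_c)(x))  # forward mode handles R -> C
print(jax.jacrev(c_to_r)(z))  # reverse mode handles C -> R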

It makes sense to restrict these jacfwd and jacrev functions to those respective use cases because those are the cases when they can push forward or pull back (respectively) a real standard basis and produce an answer that has the same type/dimension as the output or input of the function (respectively). The underlying JAX jvp/vjp+vmap machinery can also handle the other cases, like computing the Jacobian of a C->R function (or non-holomorphic C->C function) with forward-mode, but we might prefer to put those in separate functions, rather than cramming that behavior into jacfwd and jacrev, because the bookkeeping with dimensions might require more coordination with the user (at least via a new docstring). In any case we can add functions for those use cases when the need arises, though for now the way to differentiate those kinds of non-holomorphic functions in JAX is to roll your own using jvp/vjp+vmap.
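
Here's a hedged sketch of that roll-your-own approach for a C -> R function: push forward the 1 and 1j directions separately with jvp + vmap to recover all the real partial derivatives. How they get packed back into a single complex number is the convention question covered in the Autodiff Cookbook.

import jax
import jax.numpy as jnp

def f(z):  # C^n -> R
    return jnp.cos(jnp.linalg.norm(2 * z))

z = 0.5j * jnp.arange(5.0) + jnp.arange(5.0)
basis = jnp.eye(z.size, dtype=z.dtype)  # complex standard basis

push = lambda v: jax.jvp(f, (z,), (v,))[1]
d_real = jax.vmap(push)(basis)          # directional derivatives along 1
d_imag = jax.vmap(push)(1j * basis)     # directional derivatives along 1j
print(d_real, d_imag)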

Thanks again for spotting our jacfwd/jacrev dtype bug, and for bringing up this application! Looking forward to hearing more about it, and further improving JAX to support work like yours.