The imaginary components of the derivatives of the following complex-valued functions are always zero in JAX. In contrast, TensorFlow returns a non-zero imaginary component. I think this is a bug on the JAX side of things.
Did you review the explanation of our complex number differentiation convention at the bottom of our Autodiff Cookbook? See also the description of Autograd's complex number differentiation convention.
I wonder if this could be just a different convention. What do you think?
Caveat: I didn't look at your code! I just wanted to flag the explanation of our complex differentiation. It could be a bug, in which case we'll fix it!
I added Autograd to your script; JAX should match Autograd's result (Autograd is more mature because we've worked on it for longer):
```python
import autograd
import autograd.numpy as np
import numpy as onp

zs = 0.5j * np.arange(5) + np.arange(5)
print("input", zs)

def fn(z):
    return np.cos(np.linalg.norm(z*2))

grad = autograd.jacobian(fn)
print("autograd", fn(zs), grad(zs))
```

```
('input', array([0.+0.j , 1.+0.5j, 2.+1.j , 3.+1.5j, 4.+2.j ]))
('autograd', 0.9495740004388323, array([0.        +0.j        , 0.10240272+0.05120136j,
        0.20480544+0.10240272j, 0.30720815+0.15360408j,
        0.40961087+0.20480544j]))
```
So definitely a bug!
Looks like it's an issue with `jax.jacfwd` and not `jax.grad`:
```python
from __future__ import print_function

import jax
import autograd
import autograd.numpy as anp
import jax.numpy as jnp
import numpy as onp

zs = 0.5j * onp.arange(5) + onp.arange(5)
print("input", zs)

def autograd_fn(z):
    return anp.cos(anp.linalg.norm(z*2))

print("autograd", autograd_fn(zs), autograd.grad(autograd_fn)(zs))

def jax_fn(z):
    return jnp.cos(jnp.linalg.norm(z*2))

print("jax", jax_fn(zs), jax.grad(jax_fn)(zs))
```

```
input [0.+0.j  1.+0.5j 2.+1.j  3.+1.5j 4.+2.j ]
autograd 0.9495740004388323 [0.        +0.j         0.10240272+0.05120136j 0.20480544+0.10240272j
 0.30720815+0.15360408j 0.40961087+0.20480544j]
jax/lib/xla_bridge.py:144: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
jax 0.94957405 [0.        +0.j         0.10240266-0.05120133j 0.20480531-0.10240266j
 0.30720797-0.15360399j 0.40961063-0.20480531j]
```
EDIT: Actually, `jax.grad` and `autograd.grad` differ by a conjugation here, but at least it's not zeroing things out.
We think it's because `jacfwd` (and `jacrev`) are always applied to a standard basis of tangent vectors with real rather than complex dtype. We need to create that `onp.eye` with a dtype that is consistent with the dtype of the input to the function being differentiated. We probably meant to revisit that line... lazy on our parts!
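To illustrate the fix being described, a hedged sketch of the idea (with a made-up helper name, not the actual JAX source):

```python
import numpy as onp

def std_basis_like(x):
    # Build the standard basis used to seed jacfwd/jacrev so that its dtype
    # matches the input being differentiated, instead of defaulting to float.
    flat = onp.ravel(x)
    return onp.eye(flat.size, dtype=flat.dtype)  # one basis vector per input component
```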
Thanks for looking into this. I did get a chance to look at the caveats re: holomorphic functions, but as you've already discovered, that wasn't quite the issue. Just to summarize:
- TensorFlow + Autograd agree with each other
- Autograd + `jax.grad` differ by conjugation
- nothing agrees with `jacfwd`
Thanks for the fast responses!
The quick fix I tried wasn't right, so we'll merge #604 as a stopgap so at least this isn't a silent failure. It raises errors if you try to use `jacfwd` / `jacrev` with complex types.
I just started #610, which might close this issue. It doesn't make the code in the original post work with `jacfwd`; instead, that raises an error! But it does squash a bug to make the above code work with `jacrev` (instead of `jacfwd`).
Why does it work with `jacrev` and not `jacfwd`? My current best explanation is that the complex differentiation convention we use works really nicely for this specific example because of the transposition (adjoint operation) involved, but not so easily with forward mode. In reverse mode, we linearize a C -> R function (like this one) and transpose the result to ultimately produce a linear R -> C function, which we can use to pull back a real-valued standard cotangent basis and produce a complex cotangent result. But in forward mode, we'd have a linear C -> R function, yet we'd want to push forward a whole complex standard tangent basis to get a complex tangent result. The types don't seem to work out.
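Here's a small illustration of that reverse-mode story on the function from the script above (my own sketch using the public `vjp` API, not the `jacrev` internals):

```python
import jax
import jax.numpy as jnp

def f(z):
    return jnp.cos(jnp.linalg.norm(2 * z))  # C^n -> R, non-holomorphic

zs = jnp.arange(5.0) + 0.5j * jnp.arange(5.0)
y, f_vjp = jax.vjp(f, zs)       # linearize at zs
(cotangent,) = f_vjp(1.0)       # pull back the real scalar cotangent 1.0
print(cotangent)                # complex result with the same shape as zs
```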
We might be able to implement our R^2 -> R convention more directly to make `jacfwd` work (EDIT: that proved unnecessary), but we'd need to rewrite primitives like `np.real` and `np.imag` to be "first component" and "second component", respectively. It's feasible to do that with our transformation machinery, but it'd take some work and we'd want to have clear use cases in mind.
Any thoughts?
I don't understand the internals of JAX well enough at this point to comment on the technical difficulties of implementing `jacfwd`. Unfortunately, I do need `jacfwd` for the following reasons:
1) Computing the Hessian of the Ewald energy (E) of periodic systems (http://micro.stanford.edu/mediawiki/images/4/46/Ewald_notes.pdf). The JAX tutorial recommends composing `jacfwd` and `jacrev` to achieve this (a small sketch of that composition follows below).
2) In the case where I already have an expression for the forces (dE/dx) and am trying to implement second-order optimization, I necessarily need `jacfwd` to avoid computing the intermediate values required by `jacrev`. For dynamical systems (unlike quasi-Newton approaches), the number of steps required is extremely large (>1,000,000), so storing intermediate values becomes highly impractical.
Note that the energy functions used fall strictly in the third use case of the JAX convention: E maps from R->R but goes through intermediate holomorphic R->C and C->R functions.
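For reference, here's the kind of composition meant in 1), on a toy stand-in for the Ewald energy (a hedged sketch: a made-up R->R function that passes through complex intermediates, not the real reciprocal-space energy):

```python
import jax
import jax.numpy as jnp

def energy(x):
    s = jnp.sum(jnp.exp(1j * x))   # structure-factor-like complex sum
    return jnp.abs(s) ** 2         # back to a real scalar

x = jnp.array([0.1, 0.2, 0.3])
hess = jax.jacfwd(jax.jacrev(energy))(x)   # forward-over-reverse composition
print(hess.shape)                          # (3, 3)
```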
Wow, cool use case! We'll make JAX work for this :)
I discussed with @dougalm and got some new ideas. We think `jacrev` will naturally handle C->R functions, and `jacfwd` will naturally handle R->C functions. Either can be used to differentiate holomorphic C->C functions (because in that case we only need to pull back or push forward a real standard basis, as that will reveal all the derivative information). By putting these together with `vmap`, we can get `jacfwd` to work with C->R functions and `jacrev` to work with R->C functions, and can even get either to differentiate non-holomorphic C->C functions, though we'll need to explain how. I'm going to revise #610 in accordance with this plan, and then probably explain how to put them together to accomplish the things in the previous sentence. (EDIT: revised this paragraph a bit.)
We might need to unpack a bit more about your use case to work out the right way to compose these things. Can you make a toy example that shows the kind of Hessian you need to compute?
Hopefully this comment unpacks things a bit. If you provide a model use case, I'd be happy to help think through how to compose these pieces together to build exactly what you want.
I switched the label on this issue because #610 fixed the bug, but we're still discussing how to compose the pieces for this application.
Thanks for the detailed explanation and quick response. I have some reference TensorFlow-based Ewald code with fully working Hessians that I can try to port over to JAX; I don't know Autograd (or JAX) well enough yet. Then we can compare the two Hessian implementations as a first step (assuming you're okay with TensorFlow being used as the reference).
Sure, sounds great! We're familiar with TF :)
See attached for a realistic failing use case using the following pip versions:
jax==0.1.22 jaxlib==0.1.12
If you guys feel like master is ready to go, I can switch to master and see if the results are improved.
```python
import time

import numpy as onp
from jax.config import config; config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
import tensorflow as tf

BOLTZMANN = 1.380658e-23
AVOGADRO = 6.0221367e23
RGAS = BOLTZMANN*AVOGADRO
BOLTZ = RGAS/1000
ONE_4PI_EPS0 = 138.935456
VIBRATIONAL_CONSTANT = 1302.79  # http://openmopac.net/manual/Hessian_Matrix.html


class EwaldEnergy():

    def __init__(self, kmax, charges, alphaEwald, box):
        self.kmax = kmax
        self.charges = charges
        self.alphaEwald = alphaEwald
        self.box = box
        self.recipBoxSize = (2*np.pi)/box

        self.mg = []
        lowry = 0
        lowrz = 1
        numRx, numRy, numRz = self.kmax, self.kmax, self.kmax
        for rx in range(numRx):
            for ry in range(lowry, numRy):
                for rz in range(lowrz, numRz):
                    self.mg.append((rx, ry, rz))
                lowrz = 1 - numRz
            lowry = 1 - numRy
        self.mg = onp.array(self.mg)

    def jax_reciprocal_energy(self, conf):
        # lattice vectors
        ki = np.expand_dims(self.recipBoxSize, axis=0) * self.mg  # [nk, 3]
        ri = np.expand_dims(conf, axis=0)  # [1, N, 3]
        rik = np.sum(np.multiply(ri, np.expand_dims(ki, axis=1)), axis=-1)  # [nk, N]
        real = np.cos(rik)
        imag = np.sin(rik)
        # eikr = np.complex(real, imag) # [nk, N]
        eikr = real + 1j*imag
        # qi = np.complex(self.charges, np.float64(0.0))
        qi = self.charges + 1j*0
        Sk = np.sum(qi*eikr, axis=-1)  # [nk]
        n2Sk = np.power(np.abs(Sk), 2)
        k2 = np.sum(np.multiply(ki, ki), axis=-1)  # [nk]
        factorEwald = -1/(4*self.alphaEwald*self.alphaEwald)
        ak = np.exp(k2*factorEwald)/k2  # [nk]
        nrg = np.sum(ak * n2Sk)
        recipCoeff = (ONE_4PI_EPS0*4*np.pi)/(self.box[0]*self.box[1]*self.box[2])
        return recipCoeff * nrg

    def tf_reciprocal_energy(self, conf):
        # lattice vectors
        ki = tf.expand_dims(self.recipBoxSize, axis=0) * self.mg  # [nk, 3]
        ri = tf.expand_dims(conf, axis=0)  # [1, N, 3]
        rik = tf.reduce_sum(tf.multiply(ri, tf.expand_dims(ki, axis=1)), axis=-1)  # [nk, N]
        real = tf.cos(rik)
        imag = tf.sin(rik)
        eikr = tf.complex(real, imag)  # [nk, N]
        qi = tf.complex(self.charges, np.float64(0.0))
        Sk = tf.reduce_sum(qi*eikr, axis=-1)  # [nk]
        n2Sk = tf.pow(tf.abs(Sk), 2)
        k2 = tf.reduce_sum(tf.multiply(ki, ki), axis=-1)  # [nk]
        factorEwald = -1/(4*self.alphaEwald*self.alphaEwald)
        ak = tf.exp(k2*factorEwald)/k2  # [nk]
        nrg = tf.reduce_sum(ak * n2Sk)
        recipCoeff = (ONE_4PI_EPS0*4*np.pi)/(self.box[0]*self.box[1]*self.box[2])
        return recipCoeff * nrg


if __name__ == "__main__":

    charges = onp.array([
        0.1,
        -0.1,
        0.3,
        0.15,
        -0.4
    ], dtype=np.float64)

    ee = EwaldEnergy(
        kmax=4,
        charges=charges,
        alphaEwald=1.0,
        box=onp.array([4.0, 4.0, 4.0], dtype=np.float64))

    x0 = onp.array([
        [ 0.0637,  0.0126,  0.2203],
        [ 1.0573, -0.2011,  1.2864],
        [ 2.3928,  1.2209, -0.2230],
        [-0.6891,  1.6983,  0.0780],
        [-0.6312, -1.6261, -0.2601]
    ], dtype=np.float64)

    xt = tf.convert_to_tensor(x0)
    nrg_op = ee.tf_reciprocal_energy(xt)
    grad_op = tf.gradients(nrg_op, xt)[0]
    hess_op = tf.hessians(nrg_op, xt)

    sess = tf.Session()
    nrg_tf = sess.run([nrg_op])
    nrg_jax = ee.jax_reciprocal_energy(x0)
    onp.testing.assert_almost_equal(nrg_tf, nrg_jax)

    grad_tf = sess.run([grad_op])[0]

    grad_jax_rev_fn = jax.jacrev(ee.jax_reciprocal_energy)
    grad_jax_rev = grad_jax_rev_fn(x0)
    # grad_rev passes
    onp.testing.assert_almost_equal(grad_tf, grad_jax_rev)

    grad_jax_fwd_fn = jax.jacfwd(ee.jax_reciprocal_energy)
    grad_jax_fwd = grad_jax_fwd_fn(x0)
    # grad_fwd passes
    onp.testing.assert_almost_equal(grad_tf, grad_jax_fwd)

    hess_jax_fwd_rev_fn = jax.jacfwd(jax.jacrev(ee.jax_reciprocal_energy))
    hess_jax_fwd_rev = hess_jax_fwd_rev_fn(x0)

    hess_tf = sess.run([hess_op])[0][0]
    # hessian fails
    onp.testing.assert_almost_equal(hess_tf, hess_jax_fwd_rev)
```
Failure:
```
AssertionError:
Arrays are not almost equal to 7 decimals

(mismatch 100.0%)
 x: array([ 3.4778180e-01, -7.8731717e-01, -2.6087100e-01,  1.6491354e-01,
       -1.0239034e-01,  1.6205996e-01, -3.6835986e-01,  4.1750644e-02,
       -2.6607491e-01,  1.3209316e-01,  1.9315429e-01,  2.4039437e-02,...
 y: array([ 9.4108600e-01, -7.4349972e-01, -3.3325226e-01,  6.7568767e-02,
       -5.6032348e-02,  2.6268124e-01, -7.7064368e-01,  2.4751300e-01,
       -1.0244276e-01,  1.0012433e-01,  1.5247021e-01, -2.4254655e-02,...
```
To add a little more: interestingly, both first-order `jacrev` and `jacfwd` gradients pass without issue. In this case it's the second-order composition that fails.
One more thing to add: there's a non-holomorphic function used, `tf.abs(Sk)`. An alternative analytic definition can be found as part of the complex-step derivative trick: http://mdolab.engin.umich.edu/sites/default/files/Martins2003CSD.pdf (equation 13).
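For context, the basic complex-step trick from that paper looks like this (a sketch of the general idea only; the paper's redefinition of abs in equation 13 is not reproduced here):

```python
import numpy as onp

def complex_step_derivative(f, x, h=1e-20):
    # For a real-analytic f, Im(f(x + ih)) / h approximates f'(x)
    # without the subtractive cancellation of finite differences.
    return onp.imag(f(x + 1j * h)) / h

print(complex_step_derivative(onp.sin, 1.0))  # ~= cos(1.0)
print(onp.cos(1.0))
```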
Thanks so much for this test case!
EDIT: I should mention, I got the same numerical behavior as in your preceding messages when I checked on master. That's what led to the further investigation below.
I added this code to the bottom of your script to do a quick check of the respective local quadratic approximations in a random direction:
```python
import matplotlib.pyplot as plt
from jax import vmap

v = onp.random.RandomState(0).randn(*x0.shape)
v = v / onp.sqrt(onp.sum(v**2))
t = onp.linspace(-1, 1, 1000)
displacements = t[:, None, None] * v

def quad_approx(grad, hess):
    return (ee.jax_reciprocal_energy(x0) +
            onp.einsum('ij,nij->n', grad, displacements) +
            0.5 * onp.einsum('ijkl,nij,nkl->n', hess,
                             displacements, displacements))

plt.plot(t, vmap(ee.jax_reciprocal_energy)(x0 + displacements), 'k-')

jax_approx = quad_approx(jax.grad(ee.jax_reciprocal_energy)(x0),
                         jax.hessian(ee.jax_reciprocal_energy)(x0))
plt.plot(t, jax_approx, 'b--', label='jax')

tf_approx = quad_approx(grad_tf, hess_tf)
plt.plot(t, tf_approx, 'r--', label='tf')

plt.legend(loc='best')
plt.savefig('approx.png')
```

EDIT: revised code to factor out the `quad_approx` function, making it clearer that the approximations were formed the same way.
(Please double-check and run this code yourself to check that it's sensible and to see if you can reproduce. This is against the JAX master branch.)
That plot makes me think that, while JAX and TF disagree on the Hessian, the JAX Hessian looks like it's providing a more accurate local approximation to the function.
I tried a few more random directions by incrementing the random seed. (Figures not reproduced here.)
I also checked that our underlying VJP and JVP definitions are working correctly by running this code, which tests against numerical differences in random directions (but which doesn't check the `jacfwd` and `jacrev` wrappers, just the more basic `vjp` and `jvp` functions that underlie them):
```python
from jax.test_util import check_grads
check_grads(ee.jax_reciprocal_energy, (x0,), order=2)
```
Contingent on you checking over the above code and reproducing the results, this suggests to me that we should do some more numerical checks for accuracy (gradients and Hessians are easy to check against finite differences!), but not necessarily assume that the TF calculation is accurate.
What do you think?
Notice that this is comparing against the `ee.jax_reciprocal_energy` function. It could be that the two energy function implementations differ; I didn't check that.
The two implementations are identical; otherwise the energies and forces wouldn't agree. I've been able to reproduce the first figure locally as well. Autograd also seems to agree with JAX.
It also seems that replacing the non-analytic abs function does make JAX and TensorFlow agree:

```python
...
n2Sk = np.power(np.real(Sk), 2)
...
n2Sk = tf.pow(tf.real(Sk), 2)
```
At this point I'm lost. I'm skeptical that this is a purely numerical issue, because everything is done in 64-bit precision. I'm pretty sure we can isolate this down to a much simpler repro using abs and some basic trig functions.
Could it be a bug with the TF hessian? It doesn’t look accurate in those plots.
Looks like it. A much simpler repro comparing against autograd:
```python
import autograd
import autograd.numpy as anp
import numpy as onp
import tensorflow as tf

zs = 2 + 0.5j
print("input", zs)

def fn(z):
    return anp.abs(z)

grad = autograd.jacobian(fn)
print("ag", fn(zs), grad(zs))

def tf_fn(z):
    return tf.abs(z)

tf_zs = tf.convert_to_tensor(2 + 0.5j)
tf_res = tf_fn(tf_zs)

sess = tf.Session()
grad_ys = tf.ones_like(tf_res)
grad_op = tf.gradients(tf_res, tf_zs, grad_ys=grad_ys)
print("tf", sess.run([tf_res, grad_op, grad_ys]))
```
Results differ again, by conjugation:
```
input (2+0.5j)
ag 2.0615528128088303 (0.9701425001453319-0.24253562503633297j)
tf [2.0615528128088303, [(0.9701425001453319+0.24253562503633297j)], 1.0]
```
Edit: Note that this reveals differences even in the first order derivative.
Differing by conjugation in the first-order derivatives may be a separate issue, and it could be that the convention used by Autograd and JAX just differs there. The difference in the Hessian seems like a much clearer bug, since it just doesn't agree with the graphical numerical checks above.
We should make a minimal numerical repro for the Hessian issue, purely in TF, which would just look like evaluating the TF gradient at `x` and at `x + eps * v` and taking the difference, then comparing to the Hessian applied to `eps * v`. I'm a bit rusty with my TF; want to take a stab at that?
Confirmed bug in tensorflow:
```python
import numpy as onp
import autograd as ag
import autograd.numpy as anp
import tensorflow as tf

inp = anp.array(2.0)
print("input", inp)

def ag_fn(x):
    real = anp.cos(x+2)
    imag = anp.sin(x-1)
    return anp.abs(real+1j*imag)

ag_hess = ag.hessian(ag_fn)
print("ag val:", ag_fn(inp))
print("ag hess:", ag_hess(inp))

def tf_fn(x):
    real = tf.cos(x+2)
    imag = tf.sin(x-1)
    return tf.abs(tf.complex(real, imag))

# tf_inp = tf.convert_to_tensor(inp)
tf_inp = tf.placeholder(shape=tuple(), dtype=onp.float64)
out_op = tf_fn(tf_inp)
tf_grad = tf.gradients(out_op, tf_inp)[0]
tf_hess = tf.hessians(out_op, tf_inp)[0]

sess = tf.Session()
delta = 1e-7
_, d0, tf_ad = sess.run([out_op, tf_grad, tf_hess], feed_dict={tf_inp: inp})
_, d1, _ = sess.run([out_op, tf_grad, tf_hess], feed_dict={tf_inp: inp+delta})

print("tf_numerical derivative:", (d1-d0)/delta)
print("tf_autodiff derivative:", tf_ad)
```
```
input 2.0
ag val: 1.0655155566059393
ag hess: -0.25533014019223726
2019-04-14 22:55:43.481283: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
tf_numerical derivative: -0.25533013481293665
tf_autodiff derivative: -1.0655155566059389
```
At least for the second-order case, this doesn't seem to be a problem on your end. So it seems this has actually surfaced a bug wholly independent of what I had initially reported.
Should we close this issue now, or is there more to figure out here?
After discussing with @dougalm earlier this week, I think for my use case of a function F: R->R that goes through a set of non-holomorphic intermediates, the current solution suffices. If I understand correctly, `jacfwd` still needs to be fixed for a certain subset of functions, since your pull request only fixed `jacrev`. We can open a separate issue for that if desired and close this one.
Again, thanks for all the help on this issue!
Happy to help! This is exactly the sort of sophisticated use case we want JAX to support.
Actually, `jacfwd` and `jacrev` are both in a good place now, and we don't have any planned revisions for them. They're pleasingly symmetric: `jacfwd` works for R->C or holomorphic C->C, and `jacrev` works with C->R or holomorphic C->C. In other cases they raise errors (rather than silently failing).

It makes sense to restrict `jacfwd` and `jacrev` to those respective use cases because those are the cases in which they can push forward or pull back (respectively) a real standard basis and produce an answer that has the same type/dimension as the output or input of the function (respectively). The underlying JAX `jvp`/`vjp` + `vmap` machinery can also handle the other cases, like computing the Jacobian of a C->R function (or non-holomorphic C->C function) with forward mode, but we might prefer to put those in separate functions, rather than cramming that behavior into `jacfwd` and `jacrev`, because the bookkeeping with dimensions might require more coordination with the user (at least via a new docstring). In any case we can add functions for those use cases when the need arises, though for now the way to differentiate those kinds of non-holomorphic functions in JAX is to roll your own using `jvp`/`vjp` + `vmap`.
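For example, a roll-your-own forward-mode Jacobian of the non-holomorphic C -> R function from earlier in this thread might look like the following. This is my own sketch of the recipe rather than a blessed API, and the exact dtype behavior may depend on the JAX version: it pushes forward real and imaginary unit tangents with `jvp` (batched via `vmap`) to recover the d/dx and d/dy parts separately.

```python
import jax
import jax.numpy as jnp

def f(z):
    return jnp.cos(jnp.linalg.norm(2 * z))      # C^n -> R, non-holomorphic

zs = jnp.arange(5.0) + 0.5j * jnp.arange(5.0)

def push(tangent):
    # directional derivative of f at zs along the given complex tangent
    return jax.jvp(f, (zs,), (tangent,))[1]

basis = jnp.eye(zs.size, dtype=zs.dtype)        # complex standard basis
d_dx = jax.vmap(push)(basis)                    # derivatives along the real axes
d_dy = jax.vmap(push)(1j * basis)               # derivatives along the imaginary axes
```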
Thanks again for spotting our `jacfwd`/`jacrev` dtype bug, and for bringing up this application! Looking forward to hearing more about it, and further improving JAX to support work like yours.