jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Confusing/inconsistent 'shape' of off-diagonal hessian: jacfwd(jacrev(f, argnums=0), argnums=1) #4146

Open johnsalmon opened 4 years ago

johnsalmon commented 4 years ago

I'm not sure if this is a bug or a "feature". But even if it's a "feature", I think it deserves a mention on the "Sharp Bits" page.

I'm interested in the off-diagonal part of the Hessian of a function that takes two arguments. Here's some code.

import jax.numpy as jnp
from jax import jacfwd, jacrev

# f is a function with two parameters, the first of which, x,
# comprises 2 scalars and the second of which, y, comprises 3
# scalars.  It does just enough arithmetic to give it non-trivial
# second derivatives.
def f(x, y):
    xx, yy = x
    a, b, c = y
    return a * xx * xx + b * xx * yy + c * yy * yy

# We want the "off-diagonal" parts of the Hessian of f, i.e.,
# d2f/dxdy and/or d2f/dydx.  We expect these to return a 2x3
# "matrix" and a 3x2 "matrix", respectively.

d2fdydx = jacfwd(jacrev(f, argnums=1), argnums=0) # should return a 3x2 "matrix"
d2fdxdy = jacfwd(jacrev(f, argnums=0), argnums=1) # should return a 2x3 "matrix"

# Let's check those dimensions:
def check_dimens(m, expected, what):
    if len(m) != expected[0] or len(m[0]) != expected[1]:
        print(f"Uh oh. {what}: Expected something with shape {expected}, but got {m}")
    else:
        print("ok", what)

check_dimens(d2fdydx((1., 2.), (3., 4., 5.)), (3,2), "two tuples")
check_dimens(d2fdxdy((1., 2.), (3., 4., 5.)), (2,3), "two tuples")

check_dimens(d2fdydx((1., 2.), [3., 4., 5.]), (3,2), "tuple and list")
check_dimens(d2fdxdy((1., 2.), [3., 4., 5.]), (2,3), "tuple and list")

check_dimens(d2fdydx(jnp.array([1., 2.]), jnp.array([3., 4., 5.])), (3,2), "two arrays")
check_dimens(d2fdxdy(jnp.array([1., 2.]), jnp.array([3., 4., 5.])), (2,3), "two arrays")

check_dimens(d2fdydx(jnp.array([1., 2.]), [3., 4., 5.]), (3,2), "dydx array and list")
check_dimens(d2fdxdy(jnp.array([1., 2.]), [3., 4., 5.]), (2,3), "dxdy array and list") ### ???

check_dimens(d2fdydx([1., 2.], jnp.array([3., 4., 5.])), (3,2), "dydx list and array") ### ???
check_dimens(d2fdxdy([1., 2.], jnp.array([3., 4., 5.])), (2,3), "dxdy list and array")

It produces the following output:

drdws0134$ ./foo.py
/u/nyc/salmonj/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py:125: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
ok two tuples
ok two tuples
ok tuple and list
ok tuple and list
ok two arrays
ok two arrays
ok dydx array and list
Uh oh. dxdy array and list: Expected something with shape (2, 3), but got [DeviceArray([2., 0.], dtype=float32), DeviceArray([2., 1.], dtype=float32), DeviceArray([0., 4.], dtype=float32)]
Uh oh. dydx list and array: Expected something with shape (3, 2), but got [DeviceArray([2., 2., 0.], dtype=float32), DeviceArray([0., 1., 4.], dtype=float32)]
ok dxdy list and array
drdws0134$ 

The confusing thing is that when one (but not both!) of the arguments to the partial Hessian (d2fdxdy or d2fdydx) is a jax.numpy array, the "shape" of the returned value is transposed from what's expected.

Wild guess: arrays are "leaf nodes" in the PyTree representation, and something confusing happens when a PyTree consisting of an array and a tuple/list is flattened and then unflattened internally in the derivative.
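One can check part of this guess directly: JAX does treat an array as a single pytree leaf, while a list of scalars is a container with one leaf per element. A minimal sketch (not from the original report) using `jax.tree_util`:

```python
import jax
import jax.numpy as jnp

# An array is a single pytree leaf; a list of scalars is a container
# holding one leaf per element, so the two spell different tree
# structures even though they hold the same numbers.
n_array = len(jax.tree_util.tree_leaves(jnp.array([1., 2.])))
n_list = len(jax.tree_util.tree_leaves([1., 2.]))
print(n_array)  # 1
print(n_list)   # 2
```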

Further clarification would be greatly appreciated. Is there, perhaps, a better way to write d2fdxdy so it's not susceptible to this? Or should I just be very careful not to mix and match arrays and Python containers?

jakevdp commented 4 years ago

Thanks for the report. This is confusing, but I think it's understandable given JAX's model. When you differentiate with respect to a list, it will return a list. When you differentiate with respect to an array, it will return an array.

When you doubly differentiate with respect to a list and an array, it will return a list of arrays... no matter which order you put them in. This really is the only possibility, because JAX only supports simple data types (returning an array of dtype object containing lists of values is not an option).
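For example (a sketch reusing the f from the report): with an array x and a list y, the output is a Python list of arrays, regardless of which argument slot held the list.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
    xx, yy = x
    a, b, c = y
    return a * xx * xx + b * xx * yy + c * yy * yy

d2fdxdy = jacfwd(jacrev(f, argnums=0), argnums=1)

# Mixing an array with a list yields a list of arrays: the outer
# structure is the list's pytree, and each leaf is a (2,) array.
out = d2fdxdy(jnp.array([1., 2.]), [3., 4., 5.])
print(type(out))                 # <class 'list'>
print([o.shape for o in out])    # [(2,), (2,), (2,)]
```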

I think the moral is that if you are mixing lists and arrays in differentiation and you want to treat the result as a tensor, you have to put some thought into how you construct that tensor from the outputs.
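Two possible ways to construct that tensor, sketched with the example from the report (the `jnp.stack` axis choice depends on which argument was the list):

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
    xx, yy = x
    a, b, c = y
    return a * xx * xx + b * xx * yy + c * yy * yy

d2fdxdy = jacfwd(jacrev(f, argnums=0), argnums=1)

# The mixed call returns a list of three (2,)-arrays; stacking them
# along axis 1 recovers the expected (2, 3) tensor.
out = d2fdxdy(jnp.array([1., 2.]), [3., 4., 5.])
H = jnp.stack(out, axis=1)

# Alternatively, normalize both arguments to arrays up front so the
# pytree structure is uniform and the result is already a (2, 3) array.
H2 = d2fdxdy(jnp.asarray([1., 2.]), jnp.asarray([3., 4., 5.]))
print(H.shape, H2.shape)  # (2, 3) (2, 3)
```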

Maybe not very satisfying, but I hope that helps.