google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Inner product of PyTree Jacobians #4189

Open shannon63 opened 4 years ago

shannon63 commented 4 years ago

Hello,

I would like to compute the inner product J Sigma J^T, where J is the Jacobian of a neural network and Sigma is a diagonal matrix of size p by p, with p the number of network parameters. The diagonal entries of Sigma are themselves parameters, with the same PyTree structure as the network parameters with respect to which I'm computing the Jacobian.

I know that I can compute J J^T easily (following the implementation here), but I'm having trouble computing J Sigma J^T.

I tried to take advantage of the fact that instead of multiplying J or J^T by the diagonal matrix Sigma, I can vectorize the diagonal of Sigma into a p by 1 vector and multiply each column of J elementwise by it.

Following the trick explained here, I tried the following:

from jax import jvp
from jax import vjp
from jax import jacobian
from jax.tree_util import tree_map

g = lambda arg: Sigma * arg
tree_map(g, Sigma)

def delta_vjp_jvp(delta):
  def delta_vjp(delta):
    # J^T delta, as a PyTree with the structure of the parameters
    _vjp = vjp(predict_fn, params_mean)[1](delta)
    # scale elementwise by Sigma's diagonal
    Sigma_jvp = tree_map(g, _vjp)
    return Sigma_jvp
  # J (Sigma J^T delta)
  return jvp(predict_fn, (params_mean,), delta_vjp(delta))[1]

mat = jacobian(delta_vjp_jvp)(fx_dummy)

Unfortunately, this doesn't seem to work, and I'm getting the error message

TypeError: can't multiply sequence by non-int of type 'JVPTracer'.

What's the right way to do elementwise multiplication of two PyTrees for this setting?

Any help would be greatly appreciated!

shoyer commented 4 years ago

tree_multimap is how you can apply an element-wise operation to multiple pytrees, e.g., tree_multimap(lambda x, y: x + y, first_tree, second_tree)
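For the J Sigma J^T computation above, a minimal self-contained sketch (the toy linear predict_fn, params, and sigma here are made-up stand-ins for the real network and Sigma; note that in current JAX, jax.tree_util.tree_map accepts multiple pytrees and plays the role tree_multimap did):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stand-ins: a linear "network" and per-parameter variances.
params = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([0.5, -0.5])}
sigma = {"w": jnp.full((2, 2), 0.1), "b": jnp.full((2,), 0.2)}
x = jnp.array([1.0, -1.0])

def predict_fn(p):
    return p["w"] @ x + p["b"]

def delta_vjp_jvp(delta):
    # J^T delta, returned as a pytree with the structure of params
    (jt_delta,) = jax.vjp(predict_fn, params)[1](delta)
    # elementwise scaling by Sigma's diagonal: map over both pytrees at once
    scaled = jax.tree_util.tree_map(lambda s, g: s * g, sigma, jt_delta)
    # J (Sigma J^T delta)
    return jax.jvp(predict_fn, (params,), (scaled,))[1]

# delta_vjp_jvp is linear in delta, so its Jacobian is J Sigma J^T,
# with shape (n_out, n_out)
fx_dummy = jnp.zeros_like(predict_fn(params))
mat = jax.jacobian(delta_vjp_jvp)(fx_dummy)
```

Since the two pytrees passed to tree_map have identical structure, the lambda is applied leaf by leaf, which is exactly the elementwise diagonal scaling the vectorization trick calls for.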

(Side note: this is another good use-case for the tree-vectorizing transformation, once we finish it!)