shannon63 opened 4 years ago
`tree_multimap` is how you can apply an element-wise operation to multiple pytrees, e.g., `tree_multimap(lambda x, y: x + y, first_tree, second_tree)`.
(Side note: this is another good use-case for the tree-vectorizing transformation, once we finish it!)
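In current JAX releases, `tree_multimap` has been deprecated and removed, and `jax.tree_util.tree_map` accepts any number of pytrees with the same structure, passing one leaf from each tree to the function. A minimal sketch of the same element-wise sum; the two example trees are made up for illustration:

```python
import jax.numpy as jnp
from jax import tree_util

# Two illustrative pytrees with identical structure (a dict holding a tuple of arrays).
first_tree = {"w": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0), jnp.array([4.0]))}
second_tree = {"w": jnp.array([10.0, 20.0]), "b": (jnp.array(30.0), jnp.array([40.0]))}

# tree_map takes extra trees after the first; the lambda receives
# the matching leaf from each tree and the result keeps the same structure.
summed = tree_util.tree_map(lambda x, y: x + y, first_tree, second_tree)
```

Swapping `x + y` for `x * y` gives an element-wise product of two pytrees in the same way.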
Hello,
I would like to compute the product J Sigma J^T, where J is the Jacobian of a neural network with respect to its parameters and Sigma is a p-by-p diagonal matrix, p being the number of network parameters. The diagonal entries of Sigma are themselves stored as a PyTree with the same structure as the network parameters with respect to which I'm computing the Jacobian.
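Writing sigma_1, ..., sigma_p for the diagonal entries of Sigma, every entry of the product depends on Sigma only through those diagonal values:

```latex
\left(J \Sigma J^\top\right)_{ij} = \sum_{k=1}^{p} J_{ik}\,\sigma_k\,J_{jk},
\qquad \Sigma = \operatorname{diag}(\sigma_1, \dots, \sigma_p)
```

So J Sigma J^T = (J diag(sigma)) J^T, where J diag(sigma) is simply J with its k-th column scaled by sigma_k; the p-by-p matrix never needs to be formed.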
I know how to compute J J^T easily (following the implementation here), but I'm having trouble computing J Sigma J^T.
I tried to take advantage of the fact that, instead of multiplying J or J^T by the diagonal matrix Sigma, I can represent Sigma by its diagonal, a p-by-1 vector, and multiply each row of J (equivalently, each column of J^T) elementwise by this vector.
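A small check of this trick on plain arrays, with made-up sizes (n = 3 outputs, p = 5 parameters) and random values chosen only for illustration. It compares the dense formula against the broadcasting version:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes: n network outputs, p parameters.
n, p = 3, 5
J = jax.random.normal(jax.random.PRNGKey(0), (n, p))      # Jacobian, shape (n, p)
sigma = jax.random.uniform(jax.random.PRNGKey(1), (p,))   # diagonal of Sigma, shape (p,)

# Reference: materialize the dense p-by-p diagonal matrix.
dense = J @ jnp.diag(sigma) @ J.T

# Trick: broadcasting J * sigma multiplies every row of J elementwise by sigma,
# which equals J @ diag(sigma) without building the p-by-p matrix.
scaled = (J * sigma) @ J.T
```

Broadcasting aligns `sigma` (shape `(p,)`) with the last axis of `J`, which is why the rows, not the columns, of J are the length-p vectors being scaled.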
Following the trick explained here, I tried the following:
Unfortunately, this doesn't work; I get the error message:

TypeError: can't multiply sequence by non-int of type 'JVPTracer'

What's the right way to multiply two PyTrees elementwise in this setting?
Any help would be greatly appreciated!
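The TypeError most likely means Python's `*` was applied directly to a pytree container (a list or tuple) and a traced JAX value: for a Python sequence, `*` means repetition, which only accepts an int. Mapping the multiplication over the leaves with `jax.tree_util.tree_map` avoids this. The sketch below assumes a hypothetical `apply_fn(params, x)` returning one scalar per input (shape `(n,)`), a dict of parameters, and a `sigma` pytree with the same structure as `params`; `apply_fn` and `weighted_ntk` are illustrative names, not part of JAX or of the original code.

```python
import jax
import jax.numpy as jnp
from jax import tree_util
from jax.flatten_util import ravel_pytree


def apply_fn(params, x):
    # Hypothetical tiny network: one tanh layer followed by a linear read-out.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"]  # shape (n,) for a batch of n inputs


def weighted_ntk(params, sigma, x):
    """Return J Sigma J^T, with J = d apply_fn / d params and diag(Sigma) = sigma."""
    # Pytree Jacobian: each leaf has shape (n, *leaf.shape).
    jac = jax.jacrev(apply_fn)(params, x)
    # Elementwise product of two pytrees: broadcasting lines up each sigma leaf
    # (shape leaf.shape) with the trailing axes of the matching Jacobian leaf.
    jac_sigma = tree_util.tree_map(lambda j, s: j * s, jac, sigma)

    def contract(js, j):
        # Flatten the parameter axes of one leaf and contract over them.
        n = j.shape[0]
        return js.reshape(n, -1) @ j.reshape(n, -1).T

    # Each leaf contributes an (n, n) block; J Sigma J^T is their sum.
    blocks = tree_util.tree_map(contract, jac_sigma, jac)
    return sum(tree_util.tree_leaves(blocks))


# Usage with random, illustrative parameters and inputs.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w1": jax.random.normal(k1, (2, 4)),
    "b1": jnp.zeros(4),
    "w2": jax.random.normal(k2, (4,)),
}
sigma = tree_util.tree_map(lambda p: jax.random.uniform(k3, p.shape), params)
x = jax.random.normal(k4, (3, 2))
K = weighted_ntk(params, sigma, x)  # shape (3, 3)

# Sanity check against the dense formula on flattened parameters.
flat, unravel = ravel_pytree(params)
J_dense = jax.jacrev(lambda fp: apply_fn(unravel(fp), x))(flat)  # shape (n, p)
sigma_flat, _ = ravel_pytree(sigma)
K_dense = (J_dense * sigma_flat) @ J_dense.T
```

Contracting leaf by leaf keeps memory proportional to the Jacobian itself; the dense check at the end is only for verification, since `ravel_pytree` flattens `params` and `sigma` in the same leaf order when they share a structure.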