johnsalmon opened this issue 4 years ago (status: Open)
Thanks for the report. This is confusing, but I think it's understandable given JAX's model. When you differentiate with respect to a list, it will return a list. When you differentiate with respect to an array, it will return an array.
When you differentiate twice, once with respect to a list and once with respect to an array, it will return a list of arrays, no matter which order you take the derivatives in. This really is the only possibility, because JAX only supports simple array data types (returning an object-dtype array containing lists of values is not an option).
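A minimal sketch of that structure rule, assuming a toy scalar function (`g` and its inputs are illustrative, not from the issue):

```python
import jax
import jax.numpy as jnp

def g(x):
    # Toy scalar function; accepts either a list or an array.
    x = jnp.asarray(x)
    return jnp.sum(x ** 2)

# Differentiating with respect to a list returns a list of scalars...
grad_wrt_list = jax.grad(g)([1.0, 2.0])

# ...while differentiating with respect to an array returns an array.
grad_wrt_array = jax.grad(g)(jnp.array([1.0, 2.0]))
```

The gradient always mirrors the pytree structure of the argument it was taken with respect to.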
I think the moral is that if you are mixing lists and arrays in differentiation and you want to treat the result as a tensor, you have to put some thought into how you construct that tensor from the outputs.
Maybe not very satisfying, but I hope that helps.
I'm not sure whether this is a bug or a "feature". But even if it's a "feature", I think it deserves a mention on the "Sharp Bits" page.
I'm interested in the off-diagonal part of the Hessian of a function that takes two arguments. Here's some code.
It produces the following output:
The confusing thing is that when one (but not both!) of the arguments to the partial Hessian (d2fdxdy or d2fdydx) is a jax.numpy array, the "shape" of the returned value is transposed from what's expected.
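The behavior described can be reproduced with a sketch along these lines (the bilinear `f` and the matrix `M` are illustrative assumptions, not the original code):

```python
import jax
import jax.numpy as jnp

# With f(x, y) = x @ M @ y, the mixed second derivative is
# d2f/dx_i dy_j = M[i, j], which makes the layout easy to check.
M = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

def f(x, y):
    return jnp.asarray(x) @ M @ jnp.asarray(y)

d2fdxdy = jax.jacfwd(jax.jacrev(f, argnums=0), argnums=1)

x = jnp.array([1.0, 1.0])

# Both arguments as arrays: a (2, 2) array with H[i, j] == M[i, j].
H_arrays = d2fdxdy(x, jnp.array([1.0, 1.0]))

# y as a Python list: a list of arrays whose naive stacking gives M.T,
# i.e. "transposed" relative to the all-array result.
H_mixed = d2fdxdy(x, [1.0, 1.0])
```

The list index of the mixed result follows y's pytree structure, while the array axis inside each element follows x, so stacking the list puts the axes in the opposite order from the all-array case.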
Wild guess: arrays are "leaf nodes" in the PyTree representation, and something confusing happens when a PyTree consisting of an array and a tuple/list is flattened and then unflattened internally during differentiation.
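The leaf-node part of that guess can be checked directly: in a pytree, an array is a single leaf, while a list contributes one leaf per element (the example values here are arbitrary):

```python
import jax
import jax.numpy as jnp

# An array counts as one leaf; the list contributes one leaf per element,
# so this tuple flattens to three leaves in total.
args = (jnp.ones(2), [1.0, 2.0])
leaves, treedef = jax.tree_util.tree_flatten(args)

# Round-tripping through flatten/unflatten preserves the structure.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```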
Further clarification would be greatly appreciated. Is there, perhaps, a better way to write d2fdxdy so that it's not susceptible to this? Or should I just be very careful not to mix and match arrays and Python containers?