pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Get Hessian as matrix #787

Open pop-cultural-reference opened 2 years ago

pop-cultural-reference commented 2 years ago

hessian gives me a tuple of tuples of tensors of varying dimensions, but I just want a single matrix. I bet this is or will be an extremely common use case, so it would be great if, say, the docs showed a general way to aggregate the output into a matrix (it's not immediately clear to me whether the obvious aggregations are valid).

samdow commented 2 years ago

@pop-cultural-reference could you share your code?

pop-cultural-reference commented 2 years ago

Maybe it's enough to just note I'm getting that shape because I'm using make_functional on a ModuleDict? (So params, which I'm taking the hessian w.r.t., is already a tuple of tensors of varying dimensions.)

samdow commented 2 years ago

@pop-cultural-reference Thanks for the context! Could you also help me understand what sort of matrix you're trying to get in the end? It's a little confusing to me since the shapes are all going to be different (and the dimensionality will depend on which combination of parameters we're using for the second derivative), but I'm probably missing something.

As a side note, in case it's unclear: the returned nested tuple holds the second derivative of the output with respect to each pair of inputs. In other words, hessian(f)(params, x)[0][1], where params is the tuple of parameters for your model, is the same as jacfwd(jacrev(f, argnums=0), argnums=1)(params0, params1, ..., x). Note that I've had to change the signature here so that jacfwd and jacrev can distinguish between the individual parameters, since argnums doesn't accept pytrees. Hopefully this conveys the idea, but let me know if it doesn't!
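
For illustration only, here is one way the nested output could be assembled into a single block matrix. This is a sketch under assumptions, not something confirmed in this thread: it uses torch.func (the PyTorch 2.x home of these transforms) instead of the functorch package, it assumes f returns a scalar so that each block H[i][j] has shape params[i].shape + params[j].shape, and the function f, the helper hessian_as_matrix, and the shapes are all hypothetical.

```python
import torch
from torch.func import hessian, jacfwd, jacrev


def f(params, x):
    # Hypothetical scalar-valued function of a tuple of parameters.
    w, b = params
    return (torch.tanh(x @ w + b) ** 2).sum()


def hessian_as_matrix(H, params):
    # Hypothetical helper: flatten block (i, j) to (numel_i, numel_j),
    # then concatenate the blocks into one square matrix.
    # Assumes a scalar output, so H[i][j] has no leading output dims.
    sizes = [p.numel() for p in params]
    rows = []
    for i, n_i in enumerate(sizes):
        blocks = [H[i][j].reshape(n_i, n_j) for j, n_j in enumerate(sizes)]
        rows.append(torch.cat(blocks, dim=1))
    return torch.cat(rows, dim=0)


torch.manual_seed(0)
params = (torch.randn(3, 2), torch.randn(2))
x = torch.randn(5, 3)

H = hessian(f)(params, x)             # tuple of tuples of tensors
H_mat = hessian_as_matrix(H, params)  # square matrix over all parameters


def g(w, b, x):
    # Same function with the parameters unpacked, so argnums can select them.
    return f((w, b), x)


# Cross block: d/db of (dg/dw), which should correspond to H[0][1].
cross = jacfwd(jacrev(g, argnums=0), argnums=1)(*params, x)
```

The rows and columns of H_mat follow the order of the parameters in the tuple, each flattened in row-major order, so the matrix matches the Hessian with respect to torch.cat([p.reshape(-1) for p in params]). For a non-scalar output the blocks gain leading output dimensions and this reshape would need adjusting.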