Open pop-cultural-reference opened 2 years ago
@pop-cultural-reference could you share your code?
Maybe it's enough to just note I'm getting that shape because I'm using make_functional on a ModuleDict? (So params, which I'm taking the hessian w.r.t., is already a tuple of tensors of varying dimensions.)
@pop-cultural-reference Thanks for the context! Could you also help me understand what sort of matrix you're trying to get in the end? It's a little confusing to me since the shapes are all going to be different (and the dimensionality will depend on which combination of parameters we're using for the second derivative), but I'm probably missing something.
As a side note, in case it's unclear, the returned tuple gives the second derivative of the outputs with respect to each combination of the inputs. In other words, hessian(f)(params, x)[0][1], where params is the tuple of parameters for your model, is the same as jacfwd(jacrev(f, argnums=0), argnums=1)(params0, params1, ..., x)
Note that I've had to change the signature here in order for jacfwd and jacrev to distinguish between each of the parameters since argnums doesn't accept pytrees. Hopefully this conveys the idea but let me know if it doesn't!
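To make the equivalence above concrete, here is a minimal sketch. It assumes torch.func (the successor to functorch in PyTorch 2.x) and a toy scalar-valued function; f, f_split, w, and b are hypothetical names used only for illustration, not anything from the code in this issue.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

# Hypothetical toy model: params is a tuple (w, b), output is a scalar.
def f(params, x):
    w, b = params
    return torch.tanh(x @ w + b).sum()

# Same function with each parameter as its own positional argument,
# so that argnums can pick them out individually.
def f_split(w, b, x):
    return f((w, b), x)

torch.manual_seed(0)
w = torch.randn(3, 2, dtype=torch.float64)
b = torch.randn(2, dtype=torch.float64)
x = torch.randn(5, 3, dtype=torch.float64)

# Block [0][1]: second derivative w.r.t. w first, then b.
block_from_hessian = hessian(f)((w, b), x)[0][1]
block_from_jac = jacfwd(jacrev(f_split, argnums=0), argnums=1)(w, b, x)

# For a scalar output, the block shape is w.shape + b.shape.
print(block_from_hessian.shape)  # torch.Size([3, 2, 2])
print(torch.allclose(block_from_hessian, block_from_jac))
```

Each entry [i][j] of the nested tuple is one such block, which is why the shapes differ from entry to entry.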
hessian gives me a tuple of tuples of tensors of varying dimensions, but I just want a matrix. I bet this is or will be an extremely common use case, so it would be great if, say, the docs showed a generalized way to aggregate the output into a matrix (it's not immediately clear to me whether obvious aggregations are valid).
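One possible aggregation, sketched here under assumptions rather than taken from the docs: for a scalar-valued f, block [i][j] has shape params[i].shape + params[j].shape, so reshaping it to (params[i].numel(), params[j].numel()) and concatenating the blocks gives the Hessian with respect to the flattened, concatenated parameters. The helper blocks_to_matrix and the toy f are hypothetical names, and the sketch uses torch.func from PyTorch 2.x. For non-scalar outputs the blocks carry extra leading output dimensions and would need separate handling.

```python
import torch
from torch.func import hessian

# Hypothetical toy model: params is a tuple (w, b), output is a scalar.
def f(params, x):
    w, b = params
    return torch.tanh(x @ w + b).sum()

def blocks_to_matrix(blocks, params):
    """Assemble the nested tuple returned by hessian into one 2-D matrix.

    Assumes a scalar-valued function, so that blocks[i][j] has shape
    params[i].shape + params[j].shape.
    """
    sizes = [p.numel() for p in params]
    rows = []
    for i, row in enumerate(blocks):
        # Flatten each block to (n_i, n_j) and lay the row out left to right.
        flat_row = [blk.reshape(sizes[i], sizes[j]) for j, blk in enumerate(row)]
        rows.append(torch.cat(flat_row, dim=1))
    return torch.cat(rows, dim=0)

torch.manual_seed(0)
params = (
    torch.randn(3, 2, dtype=torch.float64),
    torch.randn(2, dtype=torch.float64),
)
x = torch.randn(5, 3, dtype=torch.float64)

blocks = hessian(f)(params, x)
H = blocks_to_matrix(blocks, params)
print(H.shape)  # torch.Size([8, 8])
```

Row and column order follow the order of the params tuple, with each tensor flattened in row-major order, so H lines up with torch.cat([p.reshape(-1) for p in params]).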