google / objax


objax.Jacobian and objax.Hessian similar to objax.Grad #234

Open gkaissis opened 2 years ago

gkaissis commented 2 years ago

It would be amazing to have direct ops to compute Jacobians and Hessians w.r.t. model parameters, like we have with objax.Grad. I suppose these would require an unreduced loss value (i.e. raise an exception if the loss value is scalar). The Jacobian would then essentially be a stand-in for per-sample gradients. Understandably, the full Hessian is probably not tractable from a memory point of view for NNs. However, it would still make sense to add it if the op can be jitted into some function that, e.g., computes the condition number of the Hessian without ever materialising the matrix itself (see the sketch below). What's the developers' opinion on this?
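A minimal JAX-level sketch of that matrix-free idea, assuming a toy least-squares loss (the names `loss`, `hvp`, `w`, `x`, `y` are all illustrative, not Objax API): a Hessian-vector product via forward-over-reverse differentiation computes `H @ v` without ever forming `H`, and power iteration on it recovers the extreme eigenvalues needed for a condition number.

```python
import jax
import jax.numpy as jnp

# Illustrative least-squares loss; loss/w/x/y are toy names, not Objax API.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

# Hessian-vector product via forward-over-reverse differentiation:
# only H @ v is ever computed, the Hessian itself is never materialised.
def hvp(w, v, x, y):
    return jax.jvp(jax.grad(lambda p: loss(p, x, y)), (w,), (v,))[1]

w = jnp.zeros(3)
x = jnp.ones((5, 3))
y = jnp.ones(5)

# Power iteration on hvp estimates the largest Hessian eigenvalue;
# combined with an estimate of the smallest, this yields a condition
# number without ever forming the matrix.
v = jnp.ones(3) / jnp.sqrt(3.0)
for _ in range(20):
    hv = hvp(w, v, x, y)
    v = hv / jnp.linalg.norm(hv)
lam_max = v @ hvp(w, v, x, y)
print(lam_max)
```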

Thank you and congratulations on the amazing package. Objax is making transitioning from PyTorch very easy!

AlexeyKurakin commented 2 years ago

JAX provides Jacobian and Hessian functions:

- https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax-hessian
- https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html#jax.jacfwd
- https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html#jax.jacrev
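For a quick illustration of those three functions on toy inputs (`f`, `g`, and `x` below are illustrative, not Objax API):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x            # toy vector -> vector function

def g(x):
    return jnp.sum(jnp.sin(x) * x)   # toy vector -> scalar function

x = jnp.arange(1.0, 4.0)
print(jax.jacfwd(f)(x))   # forward mode: efficient for tall Jacobians (few inputs)
print(jax.jacrev(f)(x))   # reverse mode: efficient for wide Jacobians (few outputs)
print(jax.hessian(g)(x))  # jacfwd(jacrev(g)): a (3, 3) matrix here
```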

So we can add Objax wrappers for these operations, similar to objax.Grad. I would prefer to avoid adding anything more complicated than that.

You're welcome to make a pull request for these ops; the implementation should be mostly similar to objax.Grad's. A rough sketch of the wrapping pattern follows.
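The sketch below is NOT Objax's actual implementation, just an assumption about how such a wrapper could mirror objax.Grad: `make_jacobian` is a hypothetical helper that exposes a VarCollection's tensors as an explicit argument so JAX can differentiate with respect to them.

```python
import jax
import objax

# Rough sketch only; make_jacobian is a hypothetical name.
def make_jacobian(f, vc: objax.VarCollection):
    def f_on_tensors(tensors, *args):
        original = vc.tensors()   # save current variable values
        vc.assign(tensors)        # temporarily load the traced values
        out = f(*args)
        vc.assign(original)       # restore the originals
        return out

    jac_raw = jax.jacrev(f_on_tensors)  # differentiate w.r.t. the tensors

    def wrapped(*args):
        return jac_raw(vc.tensors(), *args)

    return wrapped
```

jax.jacrev is used here because model outputs are typically much smaller than the parameter vector; jax.jacfwd would slot into the same pattern.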

AlexeyKurakin commented 2 years ago

I'm actually going to add support for Jacobian and Hessian pretty soon

AlexeyKurakin commented 2 years ago

Hessian and Jacobian are merged into the main branch. I will also add Hessian-vector product and Jacobian-vector product sometime soon.
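For reference, a usage sketch assuming the merged ops follow objax.Grad's constructor convention (the model and input shapes here are illustrative):

```python
import objax
import jax.numpy as jnp

# Toy model; assuming objax.Jacobian takes (function, variables) like objax.Grad.
m = objax.nn.Linear(3, 2)

jac_op = objax.Jacobian(lambda x: m(x), m.vars())
jacs = jac_op(jnp.ones((1, 3)))  # one Jacobian block per variable in m.vars()
```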

gkaissis commented 2 years ago

Amazing!