pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Simultaneous computation of per-sample gradient and per-batch gradient #1012

Open SJShin-AI opened 1 year ago

SJShin-AI commented 1 year ago

Hi. Thanks for providing functorch; it is a very effective library. I am working on code that needs access to both the per-sample gradients within a mini-batch and the per-batch gradient, i.e. the usual gradient.

That said, I believe functorch currently requires two forward passes to obtain both gradients. In other words, I could not recover the per-batch gradient from the per-sample gradients by averaging them, which I verified empirically in torch. This forces me to run the forward pass twice, once for each quantity.

https://backpack.pt/ provides a very similar per-sample-gradient utility, and it also gives access to several other quantities (e.g. the normal gradient, batch gradient, gradient variance, Gauss-Newton diagonal), which does not seem possible with functorch.

I would be grateful for a computationally efficient way to access both 1) the per-sample gradients within a mini-batch and 2) the per-batch gradient. (It would also be helpful to be able to compute other quantities, e.g. the Hessian, gradient variance, ...) Thanks.
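
For context, here is roughly what I do now, which runs the forward pass twice (a minimal sketch; the linear model, MSE loss, and batch size are just stand-ins for my actual setup):

import torch
import torch.nn.functional as F
from functorch import make_functional, grad, vmap

# toy stand-ins for the real model and data
model = torch.nn.Linear(10, 1)
fmodel, params = make_functional(model)
x, y = torch.randn(64, 10), torch.randn(64, 1)

def batch_loss(params, x, y):
  return F.mse_loss(fmodel(params, x), y)

def sample_loss(params, xi, yi):
  # loss of a single example, so grad() gives a per-sample gradient
  return F.mse_loss(fmodel(params, xi.unsqueeze(0)), yi.unsqueeze(0))

# pass 1: per-batch ("normal") gradient
batch_grads = grad(batch_loss)(params, x, y)

# pass 2: per-sample gradients; each entry gains a leading batch dimension
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)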

samdow commented 1 year ago

In other words, I could not recover the per-batch gradient from the per-sample gradients by averaging them, which I verified empirically in torch.

Could you try summing (or averaging, depending on your loss reduction) the per-sample gradients? That should match the per-batch gradient, and if it doesn't, it would be great to have a repro since that's a pretty nasty issue.
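
For what it's worth, here is the kind of check I mean (a sketch, assuming the batch_grads and per_sample_grads from the snippet above; with a mean-reduced loss the batch gradient should match the mean of the per-sample gradients, with a sum-reduced loss their sum):

# compare the per-batch gradient against the aggregated per-sample gradients
for batch_g, per_sample_g in zip(batch_grads, per_sample_grads):
  # mse_loss with reduction='mean' averages over samples, so compare to the mean
  assert torch.allclose(batch_g, per_sample_g.mean(dim=0), atol=1e-6)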

It would also be helpful to be able to compute other quantities, e.g. the Hessian, gradient variance, ...

For gradient variance, I think the best way would be to compute it from the per-sample gradients (there's a sketch after the snippet below). For Hessians it gets a little more complicated, but if your per-sample gradients match the Jacobian, we could compute the Hessian and produce the Jacobian along the way, like this:

from functorch import jacrev, jacfwd

def jacrev_with_aux(f, argnums=0):
  # wrap jacrev so it returns the Jacobian both as the value to be
  # differentiated and as an auxiliary output that is passed through
  def wrapper(*args):
    out = jacrev(f, argnums)(*args)
    return out, out
  return wrapper

def hessian_with_jacrev(f, argnums=0):
  # jacfwd over jacrev gives the Hessian; has_aux=True means the returned
  # function yields (hessian, jacobian) in a single call
  return jacfwd(jacrev_with_aux(f, argnums), argnums, has_aux=True)
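
Usage would look roughly like this, and the gradient variance can be read straight off the per-sample gradients (a sketch; the cubic function is just a toy example, and per_sample_grads is assumed to be the tuple of per-sample gradients you already compute with vmap(grad(...))):

import torch

# Hessian and Jacobian of a toy scalar function in one call
def f(w):
  return (w ** 3).sum()

w = torch.randn(5)
hess, jac = hessian_with_jacrev(f)(w)  # hess: (5, 5), jac: (5,)

# gradient variance: per-parameter elementwise variance across the batch
# dimension of the (assumed) per-sample gradients
grad_variance = [g.var(dim=0, unbiased=False) for g in per_sample_grads]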

Let me know if that helps!