TorchJD / torchjd

Library for Jacobian descent with PyTorch. It enables optimization of neural networks with multiple losses (e.g. multi-task learning).
https://torchjd.org
MIT License

Fix types #114

Closed: ValerianRey closed this 2 weeks ago

ValerianRey commented 2 weeks ago

I think one of the best changes is in get_vjp. It was not at all clear what v was supposed to be. In fact, jac_outputs is a Sequence[Tensor] and v is also a Sequence[Tensor], but each of its tensors has one fewer dimension than the corresponding tensor of jac_outputs. So v is what we refer to as grad_outputs, and I thus renamed v to grad_outputs.
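The shape relationship between jac_outputs and grad_outputs can be illustrated with a minimal sketch. It assumes that get_vjp is a thin wrapper around torch.autograd.grad. The signatures of get_vjp and of the helper get_jacobian below are hypothetical and are not TorchJD's actual code. Each grad_outputs[i] has the same shape as outputs[i], whereas each jac_outputs[i] has an extra leading dimension of size m (one slice per row of the Jacobian). Taking slice j of every tensor in jac_outputs therefore yields a valid grad_outputs.

```python
from collections.abc import Sequence

import torch
from torch import Tensor


def get_vjp(
    outputs: Sequence[Tensor],
    inputs: Sequence[Tensor],
    grad_outputs: Sequence[Tensor],
) -> tuple[Tensor, ...]:
    # Hypothetical signature: grad_outputs[i] must have the same shape as outputs[i].
    # torch.autograd.grad sums the vector-Jacobian products over all outputs.
    return torch.autograd.grad(
        outputs, inputs, grad_outputs=grad_outputs, retain_graph=True
    )


def get_jacobian(
    outputs: Sequence[Tensor],
    inputs: Sequence[Tensor],
    jac_outputs: Sequence[Tensor],
) -> tuple[Tensor, ...]:
    # Hypothetical helper: jac_outputs[i] has shape (m, *outputs[i].shape).
    m = jac_outputs[0].shape[0]
    rows = []
    for j in range(m):
        # Dropping the leading dimension turns jac_outputs into a valid grad_outputs.
        grad_outputs = [jac_output[j] for jac_output in jac_outputs]
        rows.append(get_vjp(outputs, inputs, grad_outputs))
    # Stack the m rows separately for each input: result k has shape (m, *inputs[k].shape).
    return tuple(torch.stack([row[k] for row in rows]) for k in range(len(inputs)))


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
outputs = [x.sum(), x**2]  # shapes () and (3,)

# One vector-Jacobian product: the grad_outputs tensors match the output shapes.
vjp = get_vjp(outputs, [x], [torch.tensor(1.0), torch.ones(3)])
# vjp[0] has values [3., 5., 7.] (1 from x.sum() plus 2 * x from x**2)

# m = 2 rows: the jac_outputs tensors have shapes (2,) and (2, 3).
jac_outputs = [torch.tensor([1.0, 0.0]), torch.stack([torch.zeros(3), torch.ones(3)])]
jacobian = get_jacobian(outputs, [x], jac_outputs)
# jacobian[0] has rows [1., 1., 1.] and [2., 4., 6.]
```

Computing the rows one at a time keeps the sketch simple. torch.autograd.grad also accepts is_grads_batched=True, which treats the leading dimension of each grad_outputs tensor as a batch dimension and could likely replace the explicit loop over m.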