TorchJD / torchjd

Library for Jacobian descent with PyTorch. It enables optimization of neural networks with multiple losses (e.g. multi-task learning).
https://torchjd.org
MIT License

Improve IMTL-G implementation #132

Closed: PierreQuinton closed this issue 2 months ago

PierreQuinton commented 2 months ago

Computing a pseudo-inverse and then multiplying by it is equivalent to, but less efficient than, solving a linear system.
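Stated precisely, for any matrix A and vector b, multiplying b by the pseudo-inverse of A gives the minimum-norm least-squares solution of A x = b, which a least-squares solver can compute directly without forming A^+. When A is square and invertible, this is simply the solution of the linear system:

```latex
A^{+} b \;=\; \operatorname*{arg\,min}_{x \in S} \|x\|_2,
\qquad S \;=\; \operatorname*{arg\,min}_{x} \|A x - b\|_2,
\qquad \text{and if } A \text{ is invertible: } A^{+} b = A^{-1} b.
```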

In IMTL-G, we have the following code: https://github.com/TorchJD/torchjd/blob/ab9208f681c9b91a31fd667c5ca9c74db46a1d59/src/torchjd/aggregation/imtl_g.py#L45

It is equivalent to, but less efficient (and less precise) than:

raw_weights = torch.linalg.lstsq(matrix @ matrix.T, d)

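If we do that, we no longer need the try/except. Note that we need to be careful about the driver: on CUDA, the only available driver is 'gels', which assumes that the matrix is full-rank, whereas matrix @ matrix.T can be singular.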
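A minimal sketch comparing the two approaches on a small example. The inputs are illustrative assumptions (a random 3 x 5 matrix with one row per loss and a random vector d), not the actual IMTL-G inputs, and the choice of the SVD-based 'gelsd' driver is ours:

```python
import torch

torch.manual_seed(0)

# Illustrative inputs (assumptions): a 3 x 5 matrix with one row per loss, and a vector d.
matrix = torch.randn(3, 5, dtype=torch.float64)
d = torch.randn(3, dtype=torch.float64)

gram = matrix @ matrix.T  # 3 x 3; almost surely invertible since the random rows are independent

# Current approach: form the pseudo-inverse explicitly, then do a matrix-vector product.
weights_pinv = torch.linalg.pinv(gram) @ d

# Suggested approach: least-squares solve of gram @ x = d, with B passed as a (3, 1) column.
# "gelsd" is SVD-based and returns the minimum-norm solution even for a singular gram,
# but it only exists for CPU tensors; on CUDA the only driver is "gels", which assumes full rank.
weights_lstsq = torch.linalg.lstsq(gram, d.unsqueeze(-1), driver="gelsd").solution.squeeze(-1)

# Both approaches agree on this well-conditioned example.
assert torch.allclose(weights_pinv, weights_lstsq)
```

Passing d as a column and squeezing the result keeps the shapes explicit; driver is a keyword-only argument of torch.linalg.lstsq.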

ValerianRey commented 2 months ago

I agree, based on this source: https://pytorch.org/docs/stable/generated/torch.linalg.pinv.html

However, I think you forgot the .solution in your suggested code, so it should be:

raw_weights = torch.linalg.lstsq(matrix @ matrix.T, d).solution

ValerianRey commented 2 months ago

Apparently, the try/except is still necessary: matrices with values larger than about 1e18 lead to a RuntimeError when calling lstsq.
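A sketch of keeping the try/except around the least-squares solve. The helper name solve_gram_system is hypothetical, and the uniform fallback in the except branch is an illustrative assumption, since this thread does not say what the fallback should be:

```python
import torch
from torch import Tensor


def solve_gram_system(matrix: Tensor, d: Tensor) -> Tensor:
    """Hypothetical helper: least-squares solve of (matrix @ matrix.T) x = d.

    The uniform fallback in the except branch is an illustrative assumption,
    not necessarily what torchjd does.
    """
    gram = matrix @ matrix.T
    # "gelsd" handles singular Gram matrices but is CPU-only; CUDA only supports "gels".
    driver = "gelsd" if gram.device.type == "cpu" else "gels"
    try:
        return torch.linalg.lstsq(gram, d.unsqueeze(-1), driver=driver).solution.squeeze(-1)
    except RuntimeError:
        # lstsq can still raise, e.g. on matrices with extremely large entries.
        return torch.ones_like(d)


# Usage on a small diagonal example: gram is diag(1, 4).
matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float64)
d = torch.tensor([1.0, 2.0], dtype=torch.float64)
print(solve_gram_system(matrix, d))
```

Catching RuntimeError matches the error type reported above; selecting the driver by device reflects that 'gelsd' is not available for CUDA tensors.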