PierreQuinton closed this 2 months ago
I agree, based on this source: https://pytorch.org/docs/stable/generated/torch.linalg.pinv.html
However, I think you forgot the `.solution` in your suggested code, so it should be:
`raw_weights = torch.linalg.lstsq(matrix @ matrix.T, d).solution`
Apparently, the try/catch is still necessary: matrices with values larger than roughly 1e18 lead to a `RuntimeError` when applying `lstsq`.
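A minimal sketch of the corrected call, keeping the try/catch. The helper name `solve_gram_system` is hypothetical, and returning `None` on failure is an assumption made here so that the caller picks the fallback; the thread does not say what the fallback should be. `torch.linalg.lstsq` returns a named tuple, so the weights have to be read from its `.solution` field.

```python
from typing import Optional

import torch
from torch import Tensor


def solve_gram_system(matrix: Tensor, d: Tensor) -> Optional[Tensor]:
    """Hypothetical helper: solve (matrix @ matrix.T) w = d in the least-squares sense.

    Returns None when torch.linalg.lstsq raises, leaving the fallback to the caller.
    """
    gram = matrix @ matrix.T
    try:
        # lstsq returns (solution, residuals, rank, singular_values); without
        # `.solution` you would get the whole named tuple instead of the weights.
        # d is reshaped to a column and back, so the right-hand side is 2-D.
        return torch.linalg.lstsq(gram, d.unsqueeze(-1)).solution.squeeze(-1)
    except RuntimeError:
        # As reported in this thread, very large entries (around 1e18) can make lstsq fail.
        return None


# Usage with a tiny hand-checkable example: gram = diag(1, 4), so w = [1, 1].
matrix = torch.tensor([[1.0, 0.0], [0.0, 2.0]], dtype=torch.float64)
d = torch.tensor([1.0, 4.0], dtype=torch.float64)
weights = solve_gram_system(matrix, d)
print(weights)
```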
Computing the pseudo-inverse and then multiplying by it gives the same result as solving the least-squares system directly, but it is less efficient.
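A minimal sketch of that equivalence, assuming a random well-conditioned `matrix` and taking `d` as its row norms (an assumption for illustration only; the thread does not say how `d` is built). `torch.linalg.pinv` builds the full pseudo-inverse from an SVD before the product, whereas `lstsq` solves the system directly (QR with column pivoting for the default CPU driver `gelsy`).

```python
import torch

torch.manual_seed(0)

# Hypothetical data: 3 rows of dimension 5; d is assumed to be the row norms.
matrix = torch.randn(3, 5, dtype=torch.float64)
d = torch.linalg.norm(matrix, dim=1)
gram = matrix @ matrix.T  # 3 x 3, full rank for generic random rows

# Explicit pseudo-inverse, then a matrix-vector product.
w_pinv = torch.linalg.pinv(gram) @ d

# Direct least-squares solve of gram @ w = d, without forming the inverse.
w_lstsq = torch.linalg.lstsq(gram, d.unsqueeze(-1)).solution.squeeze(-1)

print(torch.allclose(w_pinv, w_lstsq))
```

For a full-rank Gram matrix both routes give the unique solution, so they agree up to floating-point tolerance.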
In IMTL-G, we have the following code: https://github.com/TorchJD/torchjd/blob/ab9208f681c9b91a31fd667c5ca9c74db46a1d59/src/torchjd/aggregation/imtl_g.py#L45
It is equivalent to, but less efficient and less precise than:
If we do that, we no longer need the try/catch. Note that we need to be careful about which `driver` we pass to `torch.linalg.lstsq`, since some drivers assume that the matrix is full-rank.
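One possible reading of the driver remark, sketched under assumptions: `matrix @ matrix.T` is singular whenever the rows of `matrix` are linearly dependent, and the `torch.linalg.lstsq` documentation states that the `gels` driver assumes a full-rank input and is the only driver available for CUDA tensors. On CPU, the SVD-based `gelsd` driver handles rank-deficient inputs and, like `pinv`, returns the minimum-norm least-squares solution. The example matrix and the `1e-10` threshold for `rcond`/`rtol` are hypothetical choices.

```python
import torch

# Hypothetical rank-deficient example: row 3 = row 1 + row 2, so the 3 x 3 Gram
# matrix has rank 2. Integer entries keep the linear dependency exact.
matrix = torch.tensor(
    [[1.0, 0.0, 2.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 1.0, 0.0],
     [1.0, 1.0, 3.0, 1.0, 1.0]],
    dtype=torch.float64,
)
d = torch.linalg.norm(matrix, dim=1)  # assumed construction of d
gram = matrix @ matrix.T

# 'gelsd' (CPU only) treats singular values below rcond * largest as zero.
result = torch.linalg.lstsq(gram, d.unsqueeze(-1), rcond=1e-10, driver="gelsd")
w_lstsq = result.solution.squeeze(-1)

# pinv with a matching relative tolerance gives the same minimum-norm solution.
w_pinv = torch.linalg.pinv(gram, rtol=1e-10) @ d

print(int(result.rank), torch.allclose(w_lstsq, w_pinv))
```

Here `result.rank` reports the rank of 2, and both routes agree. The `gels` driver, by contrast, assumes a full-rank input, so with this Gram matrix it cannot be relied on.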