bayesiains / nflows

Normalizing flows in PyTorch
MIT License

Solve torch.triangular_solve deprecation warning #57

Closed. francesco-vaselli closed this 2 years ago

francesco-vaselli commented 2 years ago

From the PyTorch 1.11 docs:

UserWarning: torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular and will be removed in a future PyTorch release. torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs. X = torch.triangular_solve(B, A).solution should be replaced with X = torch.linalg.solve_triangular(A, B).
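The argument mapping described in the warning can be sketched as follows. This is a minimal example, assuming PyTorch 1.11 or later, on a made-up 4×4 upper-triangular system (the matrices are illustrative only). One detail the warning glosses over: `torch.triangular_solve` defaulted to `upper=True`, while `torch.linalg.solve_triangular` takes `upper` as a required keyword-only argument, so the sketch passes it explicitly rather than calling `solve_triangular(A, B)` verbatim.

```python
import torch

torch.manual_seed(0)

# Illustrative upper-triangular system A X = B; the diagonal is kept in [1, 2]
# so that A is comfortably invertible.
A = torch.triu(torch.randn(4, 4), diagonal=1) + torch.diag(torch.rand(4) + 1.0)
B = torch.randn(4, 3)

# Deprecated form: B first, A second, upper=True by default.
#   X = torch.triangular_solve(B, A).solution
# Replacement: A first, B second, and `upper` must be given as a keyword.
X = torch.linalg.solve_triangular(A, B, upper=True)

# The solution should satisfy A @ X == B up to floating-point error.
print(torch.allclose(A @ X, B, atol=1e-4))
```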

This PR replaces the call with the new torch.linalg.solve_triangular() (see the PyTorch docs), with the arguments swapped accordingly. I tested it for sample generation in my use case, and everything works fine with no UserWarning!
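Flows that invert an LU-parameterized linear layer typically need two triangular solves in sequence. The sketch below shows how that pattern looks with the new API; `lu_solve_sketch` is a hypothetical helper written for illustration, not the actual nflows code touched by this PR, and it assumes a unit lower-triangular `L` and an upper-triangular `U` with a nonzero diagonal.

```python
import torch


def lu_solve_sketch(lower: torch.Tensor, upper: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: solve (L @ U) @ x = y for a batch of column vectors y."""
    # Forward substitution, L z = y. unitriangular=True makes PyTorch treat the
    # diagonal of `lower` as all ones instead of reading the stored values.
    z = torch.linalg.solve_triangular(lower, y, upper=False, unitriangular=True)
    # Back substitution, U x = z.
    return torch.linalg.solve_triangular(upper, z, upper=True)


torch.manual_seed(0)
d = 5
lower = torch.tril(torch.randn(d, d), diagonal=-1) + torch.eye(d)
upper = torch.triu(torch.randn(d, d), diagonal=1) + torch.diag(torch.rand(d) + 1.0)
y = torch.randn(d, 2)

x = lu_solve_sketch(lower, upper, y)
# Reconstructing y from x checks that the two solves invert L @ U.
print(torch.allclose(lower @ upper @ x, y, atol=1e-3))
```

Splitting the inverse into two triangular solves avoids ever forming or inverting the dense product `L @ U`.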

Cheers, Francesco

francesco-vaselli commented 2 years ago

Hey @arturbekasov, sorry for the ping, but this seems like a trivial contribution and would ensure compatibility with future versions of PyTorch.

Please let me know if anything else needs to be done before this can be merged! Thank you for your work on this useful package.

Cheers, Francesco

arturbekasov commented 2 years ago

Hi Francesco,

Thanks for the ping, and sorry for the delay.

The change LGTM -- thank you for taking the time to fix the warning.

Artur
