bayesiains / nflows

Normalizing flows in PyTorch
MIT License

Fix LULinear.weight_inverse with CUDA #38

Closed mj-will closed 2 years ago

mj-will commented 3 years ago

Hi bayesiains,

The current version of LULinear.weight_inverse calls torch.eye without specifying a device: https://github.com/bayesiains/nflows/blob/639c3a771d57c29a27c307140cc94a1008ee9f55/nflows/transforms/lu.py#L109

This means the identity matrix is always created on the CPU, so the module is not compatible with use on a CUDA device: when the parameters have been moved to the GPU, the subsequent torch.triangular_solve call mixes CPU and CUDA tensors and fails.

This PR simply adds a device argument to torch.eye, using the same device as one of the parameters in LULinear (lower_entries).
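
For reference, a minimal sketch of the change inside weight_inverse (the exact surrounding code lives in lu.py; the self.features attribute shown here is an assumption, while lower_entries is the parameter named above):

# Before: the identity matrix is always created on the CPU
identity = torch.eye(self.features, self.features)

# After (this PR): create it on the same device as the module's parameters,
# so triangular_solve receives tensors on a single device
identity = torch.eye(
    self.features, self.features, device=self.lower_entries.device
)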

Cheers, Michael

Current error

Python 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from nflows.transforms import LULinear
>>> lu = LULinear(4)
>>> lu.weight_inverse()
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000]], grad_fn=<TriangularSolveBackward>)
>>> lu.to('cuda:1')
LULinear()
>>> lu.weight_inverse()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/michael/git_repos/nflows/nflows/transforms/lu.py", line 110, in weight_inverse
    lower_inverse, _ = torch.triangular_solve(
RuntimeError: Expected b and A to be on the same device, but found b on cpu and A on cuda:1 instead.

Fixed version

Python 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from nflows.transforms import LULinear
>>> lu = LULinear(4)
>>> lu.weight_inverse()
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000]], grad_fn=<TriangularSolveBackward>)
>>> lu.to('cuda:1')
LULinear()
>>> lu.weight_inverse()
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], device='cuda:1', grad_fn=<TriangularSolveBackward>)