Current error
Python 3.8.8 (default, Feb 24 2021, 21:46:12)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from nflows.transforms import LULinear
>>> lu = LULinear(4)
>>> lu.weight_inverse()
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000]], grad_fn=<TriangularSolveBackward>)
>>> lu.to('cuda:1')
LULinear()
>>> lu.weight_inverse()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/michael/git_repos/nflows/nflows/transforms/lu.py", line 110, in weight_inverse
    lower_inverse, _ = torch.triangular_solve(
RuntimeError: Expected b and A to be on the same device, but found b on cpu and A on cuda:1 instead.
Fixed version
Python 3.8.8 (default, Feb 24 2021, 21:46:12)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from nflows.transforms import LULinear
>>> lu = LULinear(4)
>>> lu.weight_inverse()
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000]], grad_fn=<TriangularSolveBackward>)
>>> lu.to('cuda:1')
LULinear()
>>> lu.weight_inverse()
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], device='cuda:1', grad_fn=<TriangularSolveBackward>)
Hi bayesiains,
The current version of LULinear.weight_inverse calls torch.eye without specifying a device: https://github.com/bayesiains/nflows/blob/639c3a771d57c29a27c307140cc94a1008ee9f55/nflows/transforms/lu.py#L109
This means that the module cannot be used on a CUDA device: the identity matrix is created on the CPU while the module's parameters live on the GPU, so the triangular solve fails.
This PR simply passes a device to torch.eye, using the same device as one of the parameters in LULinear (lower_entries).
Cheers,
Michael
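
For readers who want to see the pattern in isolation, here is a minimal sketch of the idea, not the actual nflows code. TinyLU is a hypothetical stand-in for LULinear that keeps only a lower_entries parameter, and it uses torch.linalg.solve_triangular rather than the torch.triangular_solve call shown in the traceback. The point is only that torch.eye receives the parameter's device.

```python
import torch
from torch import nn


class TinyLU(nn.Module):
    """Hypothetical, simplified stand-in for LULinear (illustration only)."""

    def __init__(self, features):
        super().__init__()
        self.features = features
        n_lower = features * (features - 1) // 2
        # Strictly-lower-triangular entries, initialised to zero.
        self.lower_entries = nn.Parameter(torch.zeros(n_lower))

    def lower(self):
        # Build a unit lower-triangular matrix on the parameter's device.
        device = self.lower_entries.device
        rows, cols = torch.tril_indices(
            self.features, self.features, offset=-1, device=device
        )
        lower = torch.eye(self.features, device=device)
        lower[rows, cols] = self.lower_entries
        return lower

    def lower_inverse(self):
        # Key point: create the identity on the same device as the parameters,
        # so both operands of the triangular solve live on one device.
        identity = torch.eye(self.features, device=self.lower_entries.device)
        return torch.linalg.solve_triangular(
            self.lower(), identity, upper=False, unitriangular=True
        )


module = TinyLU(4)
if torch.cuda.is_available():
    module.to("cuda")  # parameters move; the identity follows them
inverse = module.lower_inverse()
print(inverse.device == module.lower_entries.device)
```

Reading the device from an existing parameter means the module needs no extra bookkeeping when it is moved with .to(...).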