pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

GPU support for normalizing flows #2761

Open gshartnett opened 3 years ago

gshartnett commented 3 years ago

I am interested in using the Pyro implementation of normalizing flows for my research. However, I cannot find instructions anywhere in the docs on how to enable GPU support. The example page makes no mention of GPUs, and if I modify that code from

dataset = torch.tensor(X, dtype=torch.float)
...
spline_transform = T.spline_coupling(2, count_bins=16)

to

dataset = torch.tensor(X, dtype=torch.float).to('cuda')
...
spline_transform = T.spline_coupling(2, count_bins=16).to('cuda')

and attempt to run it, I get this error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!. How should I move the TransformedDistribution entirely onto the GPU? Tagging @stefanwebb since he is the person responsible for normalizing flow support according to this issue.

fritzo commented 3 years ago

As a quick workaround, you could set CUDA as the default tensor type before creating your spline_coupling object:

torch.set_default_tensor_type(torch.cuda.FloatTensor)
spline_transform = T.spline_coupling(...)

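Another not-ideal workaround is to save and load with map_location:

torch.save(any_python_object, "temp_file.pt")
# Load from the same file that was saved; map_location remaps every stored tensor to the GPU.
# On PyTorch >= 2.6, arbitrary Python objects also need weights_only=False here.
any_python_object = torch.load("temp_file.pt", map_location="cuda:0")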
giovp commented 3 years ago

Hi @fritzo, I tried the first solution (torch.set_default_tensor_type(torch.cuda.FloatTensor)) and it doesn't seem to work. I get this error when using torch.utils.data.DataLoader and pytorch_lightning.LightningDataModule:

RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

Is there some other workaround to try out? Also, what's the root cause here? Looking at the traceback, it fails during the log_prob() evaluation.
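A sketch of one way around the generator error, under the assumption that it comes from the shuffling sampler: with `shuffle=True`, `DataLoader` draws a permutation using its `generator`, and by default that generator lives on the CPU while new tensors default to CUDA. Passing a generator created on the current default device keeps the two consistent. This has not been checked against the Lightning setup in the comment, and the random data is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Whatever device new tensors currently default to (cuda after
# torch.set_default_tensor_type(torch.cuda.FloatTensor), cpu otherwise).
default_device = torch.zeros(()).device

# Hypothetical data; the real code would use its own Dataset.
data = TensorDataset(torch.randn(64, 2))

# With shuffle=True the sampler calls torch.randperm(..., generator=...);
# a generator on the default device keeps the two consistent.
loader = DataLoader(data, batch_size=16, shuffle=True,
                    generator=torch.Generator(device=default_device))

batches = [batch for (batch,) in loader]
```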

Traceback:

```
    200 z = posterior.sample()
    201 log_qzx = posterior.log_prob(z)
--> 202 log_pz = prior_trans.log_prob(z)
    203 
    204 kl = log_pz - log_qzx.sum(-1)

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
    141 y = value
    142 for transform in reversed(self.transforms):
--> 143     x = transform.inv(y)
    144     event_dim += transform.domain.event_dim - transform.codomain.event_dim
    145     log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in __call__(self, x)
    340 def __call__(self, x):
    341     for part in self.parts:
--> 342         x = part(x)
    343     return x
    344 

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in __call__(self, x)
    247 def __call__(self, x):
    248     assert self._inv is not None
--> 249     return self._inv._inv_call(x)
    250 
    251 def log_abs_det_jacobian(self, x, y):

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in _inv_call(self, y)
    159 if y is y_old:
    160     return x_old
--> 161 x = self._inverse(y)
    162 self._cached_x_y = x, y
    163 return x

~/miniconda3/envs/torch/lib/python3.8/site-packages/pyro/distributions/transforms/permute.py in _inverse(self, y)
     91     Inverts y => x.
     92     """
---> 93     return y.index_select(self.dim, self.inv_permutation)
     94 
     95 def log_abs_det_jacobian(self, x, y):
```
fritzo commented 3 years ago

Hi @gshartnett, it's difficult to diagnose where the stray CPU tensor is coming from without actually diving into a debugger. I'd recommend running under pdb and inspecting each tensor's device. You might also try updating PyTorch, since it seems to be getting more consistent over time.
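As a complement to stepping through pdb, a small hypothetical helper (not part of Pyro or PyTorch) can list the device of every tensor reachable from a transform or distribution by walking its attributes. It would surface, for example, an index tensor attribute that stayed on the CPU after the surrounding modules were moved.

```python
import types
import torch


def tensor_devices(obj, path="obj", seen=None):
    """Map attribute paths to devices for every tensor reachable from `obj`.

    Hypothetical debugging helper: walks instance __dict__ entries, dicts,
    lists and tuples, which covers nn.Module parameters/buffers as well as
    plain tensor attributes.
    """
    if seen is None:
        seen = set()
    if isinstance(obj, torch.Tensor):
        return {path: obj.device}
    # Skip already-visited objects (cycles) and Python modules.
    if id(obj) in seen or isinstance(obj, types.ModuleType):
        return {}
    seen.add(id(obj))
    if isinstance(obj, dict):
        children = [(f"{path}[{k!r}]", v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [(f"{path}[{i}]", v) for i, v in enumerate(obj)]
    else:
        attrs = getattr(obj, "__dict__", None)
        if not isinstance(attrs, dict):  # e.g. ints, strings, classes
            return {}
        children = [(f"{path}.{k}", v) for k, v in attrs.items()]
    found = {}
    for child_path, child in children:
        found.update(tensor_devices(child, child_path, seen))
    return found


# Usage: print every tensor that is still on the CPU.
layer = torch.nn.Linear(2, 3)
for name, dev in tensor_devices(layer, "layer").items():
    if dev.type == "cpu":
        print(name, dev)
```

Calling it on a flow right before the failing `log_prob` narrows the search to the attributes printed.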

One thing you could try is to use torch.save and torch.load, but you'd only be able to do that once before training:

x = my_complex_data_structure()
torch.save(x, "temp.pt")
x = torch.load("temp.pt", map_location="cuda:0")
giovp commented 3 years ago

Thanks for the very prompt reply @fritzo, I was able to send some flows to CUDA by calling .to(device) on them.

E.g.

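from pyro.distributions.transforms import AffineAutoregressive
from pyro.nn import AutoRegressiveNN

# latent_dim, hidden_units, n_hidden and device are defined elsewhere in the model
AffineAutoregressive(
    AutoRegressiveNN(
        latent_dim,
        [hidden_units for _ in range(n_hidden)],
        skip_connections=True,
    )
).to(device)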

Among the ones I tried, this also works for BatchNorm and AffineCoupling, but it does not work for Permute (which raises an error like "permute does not have to() method"). From a very high-level inspection, Permute is the only one of these that inherits from torch.distributions.transforms.Transform instead of pyro.distributions.torch_transform.TransformModule (the latter inherits from nn.Module). I might be off track, but I wanted to report this anyway.

EDIT: I was wrong; subclassing TransformModule doesn't solve the problem. The failure is in this line of Permute._inverse:

return y.index_select(self.dim, self.inv_permutation)

SOLUTION: simply move the permutation index to the device:

Permute(torch.tensor(perm, dtype=torch.long, device=device))

Thanks again for the help!