bayesiains / nflows

Normalizing flows in PyTorch
MIT License

Fix sampling from distributions on a GPU #9

Closed: arturbekasov closed this 4 years ago

arturbekasov commented 4 years ago

PR for ongoing work on fixing #8.

arturbekasov commented 4 years ago

Committed some changes that fix ConditionalDiagonalNormal and StandardNormal. To test:

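```python
import torch

from nflows.distributions import ConditionalDiagonalNormal
from nflows.distributions import StandardNormal

# These checks require a CUDA-capable GPU.
device = torch.device('cuda:0')

# Conditional sampling: the (1, 10) context on the GPU supplies the means
# and log-stds for the 5 dimensions, so samples should land on the GPU.
d = ConditionalDiagonalNormal(shape=(5,))
context = torch.randn(1, 10).to(device)
samples = d.sample(num_samples=10, context=context)
assert samples.shape == (1, 10, 5)
assert samples.device == device

# Unconditional sampling: the module itself has been moved to the GPU.
d = StandardNormal(shape=(5,)).to(device)
samples = d.sample(num_samples=10)
assert samples.shape == (10, 5)
assert samples.device == device

# The module stays on the CPU; samples should follow the context's device.
d = StandardNormal(shape=(5,))
context = torch.randn(1, 10).to(device)
samples = d.sample(num_samples=10, context=context)
assert samples.shape == (1, 10, 5)
assert samples.device == device
```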

One problem with using a dummy buffer in StandardNormal is that it will show up in state_dict. As a consequence, existing checkpoints won't load without additional work. This could be fixed by registering the buffer with persistent=False, but that flag was merged into PyTorch only about a month ago, so it will take a while to reach a stable release.
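To illustrate the state_dict problem, here is a minimal sketch with two hypothetical modules (WithPersistentBuffer, WithNonPersistentBuffer and loads_old_checkpoint are illustrative names, not part of nflows). It assumes a PyTorch version that already has the persistent flag on register_buffer. A buffer registered the default way is saved in state_dict, so a strict load of an older checkpoint that lacks it fails; with persistent=False the buffer still moves with .to(device) but is neither saved nor expected on load.

```python
import torch
from torch import nn


class WithPersistentBuffer(nn.Module):
    """Hypothetical module holding a dummy buffer used only to track the device."""

    def __init__(self):
        super().__init__()
        # Default registration: the buffer is included in state_dict.
        self.register_buffer("_dummy", torch.empty(0))


class WithNonPersistentBuffer(nn.Module):
    """Same as above, but the buffer is excluded from state_dict."""

    def __init__(self):
        super().__init__()
        # persistent=False: still moved by .to(device), but not saved or loaded.
        self.register_buffer("_dummy", torch.empty(0), persistent=False)


def loads_old_checkpoint(module):
    """Attempt a strict load of a checkpoint saved before the buffer existed."""
    old_checkpoint = {}  # stands in for a pre-existing checkpoint without "_dummy"
    try:
        module.load_state_dict(old_checkpoint)  # strict=True by default
        return True
    except RuntimeError:  # raised for missing keys under strict loading
        return False


print(list(WithPersistentBuffer().state_dict()))        # ['_dummy']
print(list(WithNonPersistentBuffer().state_dict()))     # []
print(loads_old_checkpoint(WithPersistentBuffer()))     # False
print(loads_old_checkpoint(WithNonPersistentBuffer()))  # True
```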

I've also discovered that sampling is completely broken in DiagonalNormal. I'd like to fix that in this PR as well.

JamesRitchie commented 4 years ago

1. Non-persistent buffers are now in stable: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
2. Rather than using a dummy buffer, why not put _log_z into a buffer? That way _log_prob will also work on a GPU without extra effort.
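The suggestion in point 2 can be sketched as follows. SketchStandardNormal is a hypothetical, simplified class written for illustration; it is not the nflows StandardNormal. The log normalizing constant of a D-dimensional standard normal, (D / 2) · log(2π), is stored as a non-persistent buffer, so it follows .to(device) (which makes log_prob device-safe) while staying out of state_dict, and its device also tells sample where to allocate.

```python
import math

import torch
from torch import nn


class SketchStandardNormal(nn.Module):
    """Hypothetical standard normal with _log_z stored as a non-persistent buffer."""

    def __init__(self, shape):
        super().__init__()
        self._shape = torch.Size(shape)
        # Log normalizing constant: (D / 2) * log(2 * pi), with D = prod(shape).
        log_z = 0.5 * math.prod(self._shape) * math.log(2 * math.pi)
        # As a buffer, _log_z follows .to(device); persistent=False keeps it out of state_dict.
        self.register_buffer("_log_z", torch.tensor(log_z), persistent=False)

    def log_prob(self, inputs):
        # Sum -x^2 / 2 over every dimension except the batch dimension.
        neg_energy = -0.5 * inputs.pow(2).reshape(inputs.shape[0], -1).sum(dim=1)
        return neg_energy - self._log_z

    def sample(self, num_samples):
        # The buffer's device tells us where the module currently lives.
        return torch.randn(num_samples, *self._shape, device=self._log_z.device)


dist = SketchStandardNormal((5,))
if torch.cuda.is_available():
    dist = dist.to("cuda")
x = dist.sample(4)
print(x.shape)                 # torch.Size([4, 5])
print(dist.log_prob(x).shape)  # torch.Size([4])
```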

arturbekasov commented 4 years ago

Great, thanks for letting me know, James. This simplifies things a lot. Good point about _log_z: I don't see any downside to making it a (non-persistent) buffer. I'll try to find some time to finish this PR.

andersdot commented 4 years ago

On the sampling-on-gpu branch, I tried

```python
from nflows.distributions import StandardNormal

base_dist = StandardNormal(shape=(5,))
samples = base_dist.sample(num_samples=10)
print(samples.device)
```

which prints cpu.

```python
from nflows.flows.base import Flow

# `transform` and `device` come from the reporter's own code (not shown).
flow = Flow(transform, base_dist).to(device)
samples = base_dist.sample(num_samples=10)
print(samples.device)
print([p.is_cuda for p in flow.parameters()])
```

which prints cuda:0 and a list where every entry is True.

But when I try to train my model with trainloss = -flow.log_prob(inputs=X).mean(), I get RuntimeError: expected device cpu but got device cuda:0.
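Since every parameter is already on the GPU, one possible cause of this error (an assumption, as the transform isn't shown) is a tensor that Module.to(device) never moves: .to() moves registered parameters and buffers, but not tensors stored as plain attributes. The hypothetical helper unregistered_tensors below is a debugging sketch, not an nflows or PyTorch API; it lists such attributes so they can be turned into buffers.

```python
import torch
from torch import nn


def unregistered_tensors(module):
    """Hypothetical debugging helper: list tensor attributes that .to() won't move.

    Parameters and buffers live in the module's internal registries, so any
    torch.Tensor found directly in a submodule's __dict__ is a plain attribute.
    """
    found = []
    for prefix, submodule in module.named_modules():
        for name, value in vars(submodule).items():
            if isinstance(value, torch.Tensor):
                found.append(f"{prefix}.{name}" if prefix else name)
    return found


class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(3))       # moved by .to()
        self.register_buffer("shift", torch.zeros(3))  # moved by .to()
        self.offset = torch.zeros(3)                   # NOT moved by .to()


print(unregistered_tensors(Example()))  # ['offset']
```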

arturbekasov commented 4 years ago

Polished the changes from @janfb (thanks!). Made more edits to infer the device when creating tensors; hopefully this fixes most of the GPU-related problems. Unfortunately, I don't know of an easy way to add automated tests for such device-related problems short of making the unit tests runnable only on a machine with a GPU, which I'd rather avoid.
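The kind of edit described here, inferring the device when creating tensors, can be sketched as below. Both functions are hypothetical; their (outputs, logabsdet) return pair mirrors the convention of nflows transforms for illustration only. The first allocates on the default device regardless of the inputs, while the second takes both device and dtype from the inputs.

```python
import torch


def identity_transform_cpu_only(inputs):
    """Hypothetical identity transform that allocates logabsdet on the default device."""
    logabsdet = torch.zeros(inputs.shape[0])  # default device (usually CPU) and dtype
    return inputs, logabsdet


def identity_transform(inputs):
    """Same transform, but new tensors follow the device and dtype of the inputs."""
    logabsdet = inputs.new_zeros(inputs.shape[0])
    # Equivalent: torch.zeros(inputs.shape[0], device=inputs.device, dtype=inputs.dtype)
    return inputs, logabsdet


x = torch.randn(8, 5, dtype=torch.float64)
if torch.cuda.is_available():
    x = x.to("cuda")

_, lad = identity_transform(x)
print(lad.device == x.device, lad.dtype == x.dtype)  # True True

_, lad_cpu = identity_transform_cpu_only(x)
print(lad_cpu.device)  # cpu, even when x is on the GPU
```

On a machine with a GPU, adding lad_cpu to a CUDA tensor of the same shape raises a device-mismatch RuntimeError, the same class of error as the one reported above.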

@andersdot Could you try running your code again? If it still doesn't work, could you tell us exactly which transform you're using? Thanks.

I decided not to keep this PR blocked on the broken sampling in DiagonalNormal, so I've removed the broken code for now; we can come back to it later (or when someone flags it).
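For reference only, here is a hypothetical sketch of device-aware sampling from a diagonal Gaussian with learnable means and log-stds, using the standard reparameterization x = mean + std · eps with eps drawn from N(0, I). SketchDiagonalNormal and its attribute names are illustrative; this is not the nflows DiagonalNormal implementation, nor the code that was removed.

```python
import math

import torch
from torch import nn


class SketchDiagonalNormal(nn.Module):
    """Hypothetical Gaussian with learnable means and log-stds (not the nflows class)."""

    def __init__(self, shape):
        super().__init__()
        self._shape = torch.Size(shape)
        self.means = nn.Parameter(torch.zeros(self._shape))
        self.log_stds = nn.Parameter(torch.zeros(self._shape))

    def sample(self, num_samples):
        # eps is created on the parameters' device and dtype, so this works after .to(device).
        eps = torch.randn(num_samples, *self._shape,
                          device=self.means.device, dtype=self.means.dtype)
        # Reparameterization: x = mean + std * eps.
        return self.means + torch.exp(self.log_stds) * eps

    def log_prob(self, inputs):
        # Standardize, then sum the per-dimension log densities over non-batch dims.
        z = (inputs - self.means) / torch.exp(self.log_stds)
        log_density = -0.5 * z.pow(2) - self.log_stds - 0.5 * math.log(2 * math.pi)
        return log_density.reshape(inputs.shape[0], -1).sum(dim=1)
```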

The tests pass, and most changes are either trivial or have been reviewed by someone else, so I'll merge this.