Closed: arturbekasov closed this 4 years ago
Committed some changes that fix `ConditionalDiagonalNormal` and `StandardNormal`. To test:
```python
import torch

from nflows.distributions import ConditionalDiagonalNormal, StandardNormal

device = torch.device('cuda:0')

# ConditionalDiagonalNormal left on CPU, context on the GPU:
# samples should follow the context's device.
d = ConditionalDiagonalNormal(shape=(5,))
context = torch.randn(1, 10).to(device)
samples = d.sample(num_samples=10, context=context)
assert samples.shape == (1, 10, 5)
assert samples.device == device

# StandardNormal moved to the GPU, no context:
# samples should be on the GPU.
d = StandardNormal(shape=(5,)).to(device)
samples = d.sample(num_samples=10)
assert samples.shape == (10, 5)
assert samples.device == device

# StandardNormal left on CPU, context on the GPU:
# samples should follow the context's device.
d = StandardNormal(shape=(5,))
context = torch.randn(1, 10).to(device)
samples = d.sample(num_samples=10, context=context)
assert samples.shape == (1, 10, 5)
assert samples.device == device
```
One problem with using a dummy buffer in `StandardNormal` is that it will show up in `state_dict`, which means existing checkpoints won't load without additional work. This would be fixed by passing the `persistent=False` flag to `register_buffer`, but that flag was only merged into PyTorch about a month ago, so it will take a while to reach a stable release.
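For reference, a minimal sketch of what a non-persistent dummy buffer could look like once that flag is available; the class and the `_dummy` name are illustrations, not the actual code in this PR:

```python
import torch
from torch import nn


class DeviceAwareModule(nn.Module):
    """Illustrative module that tracks its device via a non-persistent dummy buffer."""

    def __init__(self):
        super().__init__()
        # persistent=False (PyTorch >= 1.6) keeps the buffer out of state_dict,
        # so existing checkpoints still load, while the buffer still follows
        # .to(device) and can be used to infer the module's device.
        self.register_buffer("_dummy", torch.empty(0), persistent=False)

    @property
    def device(self):
        return self._dummy.device
```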
I've also discovered that sampling is completely broken in `DiagonalNormal`. I would like to fix that in this PR as well.
1.) Non-persistent buffers are now in stable: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer

2.) Rather than use a dummy buffer, why not put `_log_z` into a buffer? That way `_log_prob` will also easily work on a GPU.
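A sketch of what that could look like; the class below is a simplified stand-in, not the actual `StandardNormal` implementation in nflows:

```python
import numpy as np
import torch
from torch import nn


class StandardNormalSketch(nn.Module):
    """Illustrative stand-in for StandardNormal with _log_z stored as a buffer."""

    def __init__(self, shape):
        super().__init__()
        self._shape = torch.Size(shape)
        # Registering the normalising constant as a (non-persistent) buffer means
        # it is moved by .to(device) together with the module, so _log_prob can
        # run on the GPU without creating CPU tensors on the fly.
        log_z = 0.5 * float(np.prod(shape)) * np.log(2 * np.pi)
        self.register_buffer("_log_z", torch.tensor(float(log_z)), persistent=False)

    def _log_prob(self, inputs):
        # Assumes a 1-D event shape for simplicity, e.g. shape=(5,).
        neg_energy = -0.5 * torch.sum(inputs ** 2, dim=-1)
        return neg_energy - self._log_z
```

With this, something like `StandardNormalSketch(shape=(5,)).to(device)._log_prob(x)` works when `x` lives on `device`, since the buffer is moved along with the module.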
Great, thanks for letting me know, James. This simplifies things a lot. Good point about `_log_z`: I don't see any downsides to making it a (non-persistent) buffer. I'll try to find some time to finish this PR.
On the `sampling-on-gpu` branch, I tried

```python
base_dist = StandardNormal(shape=(5,))
samples = base_dist.sample(num_samples=10)
print(samples.device)
```

with output `cpu`, and then

```python
flow = Flow(transform, base_dist).to(device)
samples = base_dist.sample(num_samples=10)
print(samples.device)
print([p.is_cuda for p in flow.parameters()])
```

with output `cuda:0` and all `True`s, but when I try to run my model with `trainloss = -flow.log_prob(inputs=X).mean()` I get `RuntimeError: expected device cpu but got device cuda:0`.
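(Side note on debugging this kind of mismatch: a quick check like the one below, assuming `X` and `flow` are the objects from the snippet above, usually shows which tensor is still on the CPU when the error is raised.)

```python
# Hypothetical debugging snippet: print the device of every tensor involved
# to find which one is still on the CPU.
print("inputs:", X.device)
print("params:", {p.device for p in flow.parameters()})
print("buffers:", {b.device for b in flow.buffers()})
```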
Polished the changes from @janfb (thanks!). Made more edits to infer the device when creating tensors -- hopefully this fixed the majority of GPU-related problems. Unfortunately I don't know of an easy way to add automatic tests for such device-related problems, unless we want to make unit tests only runnable on a machine with a GPU, which I don't want to do.
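To illustrate the general pattern behind "infer the device when creating tensors" (a sketch, not the exact diff in this PR; `sample_noise_like` is a hypothetical helper):

```python
import torch


def sample_noise_like(context, num_samples, features):
    # Create new tensors on the device (and with the dtype) of an existing
    # tensor such as the context, instead of on the default CPU device.
    return torch.randn(
        context.shape[0], num_samples, features,
        device=context.device, dtype=context.dtype,
    )
```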
@andersdot If you could try running your code again -- that'd be great. If it doesn't work still, could you tell us what transform you're using exactly? Thanks.
Decided not to keep this blocked because of the broken sampling in `DiagonalNormal`. Removed the broken code for now -- we can come back to this later (or when someone flags it up).
The tests pass, and most changes are either trivial or have been looked at by someone else, so I will merge this.
PR for ongoing work on fixing #8.