tonyduan / normalizing-flows

Neural Spline Flow, RealNVP, Autoregressive Flow, 1x1Conv in PyTorch.
MIT License

Add GPU support for model and flows #6

Closed Baukebrenninkmeijer closed 1 month ago

Baukebrenninkmeijer commented 4 years ago

I added GPU support for the forward/inverse passes of all flows that support it, and for NormalizingFlowModel. Tensors are moved to the same device as the input tensor (wherever x or z lives). For now, the result of sampling is moved to the CPU, since I expect that to be the most common use case. Let me know if you have any other suggestions.

To get it working, put your x, prior and model on the GPU:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
prior = MultivariateNormal(torch.zeros(1).to(device), torch.eye(1).to(device))
model = NormalizingFlowModel(prior, flows).to(device)
x = torch.Tensor(gen_data(args.n)).to(device)

From there on, it should work without any problems.

I didn't change the examples, but I can make them use the GPU as well if possible. Let me know your preference.
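For readers following along, the idea of moving tensors to wherever x or z lives can be sketched as below. This is only an illustration under assumptions: `ToyAffineFlow` is a hypothetical class written for this example and is not one of the flows in this repository, and its `forward`/`inverse` return convention of `(output, log_det)` is assumed rather than taken from the PR diff.

```python
import torch
import torch.nn as nn


class ToyAffineFlow(nn.Module):
    """Hypothetical elementwise affine flow, used only to illustrate device handling."""

    def __init__(self, dim):
        super().__init__()
        # Registered parameters follow model.to(device) automatically.
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        z = x * torch.exp(self.log_scale) + self.shift
        # Allocate any fresh tensor on the input's device and dtype.
        log_det = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
        log_det = log_det + self.log_scale.sum()
        return z, log_det

    def inverse(self, z):
        x = (z - self.shift) * torch.exp(-self.log_scale)
        log_det = torch.zeros(z.shape[0], device=z.device, dtype=z.dtype)
        log_det = log_det - self.log_scale.sum()
        return x, log_det


device = "cuda" if torch.cuda.is_available() else "cpu"
flow = ToyAffineFlow(2).to(device)
x = torch.randn(5, 2, device=device)

z, log_det = flow(x)
x_rec, inv_log_det = flow.inverse(z)
```

Because every tensor created inside `forward` and `inverse` takes its device from the input, the caller only has to place the model and the data on the same device.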

Baukebrenninkmeijer commented 4 years ago

Moved the device placement into tensor initialization instead of calling .to() afterwards. Removed the CPU as the default location for sampling; samples now stay on the same device as the prior. Should be good now :).
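A minimal sketch of both points, assuming a standard `torch.distributions.MultivariateNormal` prior as in the snippet earlier in this thread (the dimension of 2 and the sample count of 4 are arbitrary choices for the example): the prior's tensors are created directly on the target device, and samples drawn from it come back on that same device.

```python
import torch
from torch.distributions import MultivariateNormal

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create the tensors on the target device at initialization,
# rather than creating them on the CPU and calling .to(device) afterwards.
loc = torch.zeros(2, device=device)
cov = torch.eye(2, device=device)
prior = MultivariateNormal(loc, cov)

# Samples are drawn on the device that holds the prior's parameters.
z = prior.sample((4,))
```

If the samples are needed on the CPU (for plotting, say), the caller can still move them explicitly with `z.cpu()`.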

Baukebrenninkmeijer commented 4 years ago

The parameters in OneByOneConv are not registered correctly as parameters, so they are not moved to the correct device when we call model.to(device). For now I'm calling .to() in the forward and backward calls individually.

tonyduan commented 4 years ago

Thanks for making the requested changes.

> The parameters in OneByOneConv are not registered correctly as parameters, so they are not moved to the correct device when we call model.to(device). For now I'm calling .to() in the forward and backward calls individually.

I think we can fix this by replacing the line:

self.P = torch.tensor(P, dtype = torch.float)

with:

self.P = nn.Parameter(torch.tensor(P, dtype = torch.float), requires_grad = False)

Could you take a stab at this and let me know if this fixes the issue?
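To illustrate why the suggested change should help, here is a small sketch comparing three ways of storing a fixed matrix on a module. The class names are hypothetical and the 2x2 matrix is a stand-in for the real P. The sketch uses `.double()` instead of `.to("cuda")` so it runs without a GPU; both go through the same module-wide conversion, which only touches registered parameters and buffers. `register_buffer` is shown as an alternative that is not proposed in this thread.

```python
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    """Stores P as a plain tensor attribute (the current behavior)."""

    def __init__(self, P):
        super().__init__()
        self.P = torch.tensor(P, dtype=torch.float)


class FrozenParam(nn.Module):
    """Stores P as a non-trainable parameter (the suggestion above)."""

    def __init__(self, P):
        super().__init__()
        self.P = nn.Parameter(torch.tensor(P, dtype=torch.float), requires_grad=False)


class Buffered(nn.Module):
    """Stores P as a buffer (an alternative, not from this thread)."""

    def __init__(self, P):
        super().__init__()
        self.register_buffer("P", torch.tensor(P, dtype=torch.float))


P = [[0.0, 1.0], [1.0, 0.0]]  # stand-in permutation matrix

# .double() converts registered parameters and buffers, like .to(device) does.
plain = PlainAttr(P).double()
frozen = FrozenParam(P).double()
buffered = Buffered(P).double()

print(plain.P.dtype, frozen.P.dtype, buffered.P.dtype)
```

The plain attribute is left behind by the conversion, while the frozen parameter and the buffer both follow the module. One difference between the latter two: the frozen parameter still appears in `model.parameters()`, whereas a buffer does not.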

Baukebrenninkmeijer commented 4 years ago

Yes, I'll have a look when I have time. Hopefully sometime later this week.