pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Update auto_reg_nn.sample_mask_indices() to be default device-aware #3344

Closed: cafletezbrant closed this 8 months ago

cafletezbrant commented 8 months ago

Hi Pyro team, thank you for making such a useful and cool library. I encountered a small bug with an easy fix and wanted to share.

As described in my Pyro forum post, there is a device mismatch in auto_reg_nn.sample_mask_indices(). The line

indices = torch.linspace(1, input_dim, steps=hidden_dim, device="cpu").to(
    torch.Tensor().device
)

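creates tensors on the CPU even when torch.set_default_device('cuda') is used: linspace is explicitly placed on "cpu", and torch.Tensor().device still reports cpu, so the .to() call does not move the result (I believe this is because torch.Tensor is an alias for torch.FloatTensor, which is not the same as torch.cuda.FloatTensor). Minimal working example (from the Pyro docs):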

import torch
import pyro
from pyro.nn import AutoRegressiveNN

torch.set_default_device('cuda')

x = torch.randn(100, 10)
print(x.device)
# cuda:0
print(torch.Tensor().device)
# cpu
print(torch.tensor(0.).device)
# cuda:0
arn = AutoRegressiveNN(10, [50], param_dims=[1])
p = arn(x)

Instantiating an AutoRegressiveNN object fails with the error

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The proposed fix is to replace torch.Tensor().device with torch.tensor(0.0).device (lowercase tensor, with a scalar value passed because torch.tensor() requires data). With this change the object can be instantiated. This is the only change in this PR.
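To illustrate why the proposed line helps, here is a small sketch. It uses the "meta" device as a stand-in for "cuda" so it runs without a GPU (an assumption on my part; the thread itself only tested CUDA), and it relies on PyTorch 2.x, where torch.device(...) can be used as a context manager that sets the default device for factory functions such as torch.linspace and torch.tensor, except when device= is passed explicitly. The helper name proposed_indices is hypothetical.

```python
import torch


def proposed_indices(input_dim, hidden_dim):
    # Hypothetical helper mirroring the PR's proposed line: build on CPU, then
    # move to whatever device a default-device-aware factory (torch.tensor) uses.
    return torch.linspace(1, input_dim, steps=hidden_dim, device="cpu").to(
        torch.tensor(0.0).device
    )


# "meta" stands in for "cuda" so this runs on machines without a GPU.
with torch.device("meta"):
    # An explicit device= argument overrides the default device.
    pinned = torch.linspace(1, 10, steps=50, device="cpu")
    # torch.tensor(0.0) picks up the default device, so .to() moves the result.
    fixed = proposed_indices(10, 50)

print(pinned.device, fixed.device)
```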

martinjankowiak commented 8 months ago

thanks @cafletezbrant

this is pretty old code...

wouldn't this be sufficient? torch.linspace(1, input_dim, steps=hidden_dim)
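A sketch of what sample_mask_indices() could look like with the simpler line. Only the linspace call comes from this thread; the `simple` flag and the two rounding branches are my assumption about the surrounding function and may not match Pyro's actual source.

```python
import torch


def sample_mask_indices(input_dim, hidden_dim, simple=True):
    # Sketch only: no explicit device, so torch.linspace follows
    # torch.set_default_device() / torch.device(...) contexts.
    indices = torch.linspace(1, input_dim, steps=hidden_dim)
    if simple:
        # Space fractional indices evenly, then round to the nearest integer.
        return torch.round(indices)
    # Otherwise round down, then round up with probability equal to the
    # fractional part (assumed "non-simple" variant).
    ints = indices.floor()
    ints += torch.bernoulli(indices - ints)
    return ints


idx = sample_mask_indices(10, 50)
print(idx.shape, idx.min().item(), idx.max().item())
```

Because every tensor here comes from a factory call without device=, the result lands on the default device without needing a probe tensor such as torch.tensor(0.0).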

cafletezbrant commented 8 months ago

@martinjankowiak pretty old as in almost deprecated? Or just not recently updated?

Also, yes: I just tested it, and your proposal also works. I can update to that if you'd prefer.

martinjankowiak commented 8 months ago

yes please use the simpler version, thanks!

> pretty old as in almost deprecated? Or just not recently updated?

not recently updated, and therefore using oldish pytorch idioms

martinjankowiak commented 8 months ago

does arn.to(...) work as expected?

cafletezbrant commented 8 months ago

Yes, arn.to() works as expected:

arn.to('cpu')
next(arn.parameters()).is_cuda
# False
p = arn(x.cpu())
p.device
# device(type='cpu')
arn.to('cuda')
next(arn.parameters()).is_cuda
# True
p = arn(x)
p.device
# device(type='cuda', index=0)
p[0, 0:5]
# tensor([-0.2749,  0.0823,  0.1205, -0.1107,  0.1880], device='cuda:0',
#       grad_fn=<SliceBackward0>)
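One plausible reason arn.to(...) behaves this way is that fixed masks stored as module buffers are moved by Module.to() together with the parameters. The sketch below assumes that design; MaskedLinearSketch is a hypothetical class, not Pyro's own layer, and it casts to float64 instead of moving to "cuda" so it runs on a CPU-only machine (Module.to() treats floating-point buffers the same way in both cases).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinearSketch(nn.Linear):
    """Hypothetical masked linear layer whose mask is registered as a buffer."""

    def __init__(self, in_features, out_features, mask, bias=True):
        super().__init__(in_features, out_features, bias)
        # Buffers are not trained, but Module.to() moves/casts them with the parameters.
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Zero out masked connections before applying the affine map.
        return F.linear(x, self.weight * self.mask, self.bias)


layer = MaskedLinearSketch(3, 3, torch.tril(torch.ones(3, 3)))
layer.to(torch.float64)  # stands in for layer.to("cuda")
print(layer.mask.dtype, layer.weight.dtype)
```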

I've pushed the requested simpler version. I was asking about age because if this is relatively unused code, I might stub my toe a few more times, which could turn into one or more additional PRs.

martinjankowiak commented 8 months ago

not sure what your goals are, but there are certainly more up-to-date normalizing flow libraries out there, some of which have some amount of pyro integration; see e.g. https://github.com/pyro-ppl/pyro/blob/dev/pyro/contrib/zuko.py

cafletezbrant commented 8 months ago

Ah, interesting. Is that the more recommended way to do things [1]? I was just trying to test whether a normalizing flow (NF) would help my model fit (I'm not sure it will), which is why I was originally trying to use an AutoGuide. I suppose the way forward would be to write a guide using e.g. Zuko for the parameters I want to estimate via an NF and add it to my AutoGuideList?

[1] Just to be clear, I meant no criticism of the state of this code base; I only meant that if it is less maintained than other parts, I might be posting here again.

martinjankowiak commented 8 months ago

i think using the machinery in pyro is probably a reasonable place to start, but if you want to explore a more diverse and/or more recent set of flows it may be a good idea to look at other pytorch-based flow libraries like zuko

cafletezbrant commented 8 months ago

Got it, thanks for the pointer. I'll explore the built-in flows first and see where that goes.

martinjankowiak commented 8 months ago

looks like you deleted before i could merge

cafletezbrant commented 8 months ago

Sorry, brainfart! I will fix on Monday