jonbarron / robust_loss_pytorch

A pytorch port of google-research/google-research/robust_loss/
Apache License 2.0

added torch.cuda.set_device for GPU devices #3

Closed · relh closed this 5 years ago

relh commented 5 years ago

Hi, I really like your paper and the repo.

If someone is using this repo on a GPU other than cuda:0 and hasn't run torch.cuda.set_device(X), where X is their GPU, then all of the torch.as_tensor(Y) calls in util.py end up on the wrong device and cause an error. Specifically this one:

File "/home/relh/solar/robust_loss_pytorch/distribution.py", line 201, in nllfun
    assert (scale >= 0).all()
RuntimeError: CUDA error: an illegal memory access was encountered
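For context, here is a minimal sketch (illustrative only, not code from this repo) of how the mismatch can arise: a tensor created without an explicit device lands on the current default CUDA device, which stays cuda:0 unless torch.cuda.set_device has been called, while the rest of the model lives on another GPU.

```python
import torch

# Illustrative sketch, not code from this repo; assumes a machine with at
# least two GPUs. The model's tensors live on cuda:1, but a tensor created
# without an explicit device ends up on the default device (cuda:0).
params = torch.randn(4, device='cuda:1')
scale = torch.as_tensor([1.0, 2.0, 3.0, 4.0]).cuda()  # lands on cuda:0

try:
    _ = params * scale  # devices don't match
except RuntimeError as e:
    print(e)  # e.g. "Expected all tensors to be on the same device ..."
```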

Since the code in adaptive.py is structured around accepting a torch device, and PyTorch has three ways of specifying one (https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device), this tiny fix accepts any of the three, checks whether it refers to a CUDA device, and if so sets the device accordingly.
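A rough sketch of that check (the helper name is mine, not necessarily the exact code in this PR), assuming the device can arrive as a string, an integer ordinal, or a torch.device:

```python
import torch

def _set_cuda_device_if_needed(device):
    # Hypothetical helper, not necessarily the exact code in this PR.
    # Normalize the three accepted forms (str, int ordinal, torch.device)
    # into a torch.device, then make it the current CUDA device so that
    # tensors created without an explicit device land on the right GPU.
    if isinstance(device, int):
        device = torch.device('cuda', device)
    elif not isinstance(device, torch.device):
        device = torch.device(device)  # e.g. 'cuda:1'
    if device.type == 'cuda':
        torch.cuda.set_device(device)

# e.g. _set_cuda_device_if_needed('cuda:1'), called once before any tensors are created.
```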

I'm no PyTorch guru, and I'm not sure whether this is the preferred approach, but it works for me. It doesn't check torch.cuda.is_available(), because if CUDA isn't available I think the program should crash.

On the other hand, maybe it shouldn't be the role of a library to call torch.cuda.set_device.
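For comparison, the alternative hinted at here would be for the library to thread the device through explicitly wherever tensors are created, rather than changing global state; a sketch (again mine, not the repo's code):

```python
import torch

# Sketch of the alternative: pass the device explicitly to each constructor
# instead of relying on torch.cuda.set_device. torch.as_tensor takes a
# device keyword, so util.py-style helpers could simply forward it.
def as_tensor_on(y, device):
    return torch.as_tensor(y, device=device)

x = as_tensor_on([0.5, 1.5], device=torch.device('cuda', 1))  # assumes a second GPU
```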

jonbarron commented 5 years ago

Looks great, thanks! I'm no pytorch guru either (I learned pytorch while writing this repo), so if this fixes a crash, it's good by me.