eyalbetzalel closed this issue 4 years ago.
Hey @eyalbetzalel, thanks for the bug report!
Are you only hitting this in the PixelSnail model? It seems the model's positional encodings were never being transferred to the GPU. I've fixed that in 9497c7116811c7d0a782a4593a177add1c43c119.
Please let me know if you still see the issue.
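The commit itself isn't shown here, so the sketch below is not its actual diff; it only illustrates the general PyTorch pattern behind this class of bug. A tensor stored as a plain attribute on an `nn.Module` is ignored by `.to(device)` and `.cuda()`, while a tensor registered with `register_buffer` moves (and changes dtype) together with the module. `PositionalEncodingExample` and its shapes are hypothetical.

```python
import torch
from torch import nn


class PositionalEncodingExample(nn.Module):
    """Hypothetical module (not the PixelSnail code) that adds a fixed encoding."""

    def __init__(self, n_channels: int, height: int, width: int):
        super().__init__()
        pos = torch.linspace(-1.0, 1.0, height * width).view(1, 1, height, width)
        # A plain attribute (self.pos = ...) would stay on the CPU after
        # module.to("cuda"); a registered buffer follows the module instead.
        self.register_buffer("pos", pos.expand(1, n_channels, height, width).clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.pos is on the same device as the module, so no cpu/cuda mixing.
        return x + self.pos


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = PositionalEncodingExample(n_channels=3, height=4, width=4).to(device)
out = module(torch.zeros(2, 3, 4, 4, device=device))
print(out.device, module.pos.device)
```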
I think the problem still happens because of the `__init__` function in trainer.py (line 26):

```python
def __init__(self, model, loss_fn, optimizer, train_loader, eval_loader, lr_scheduler=None, log_dir='/tmp/runs', save_checkpoint_epochs=1, device=torch.device('cpu')):
```

The `__init__` function in trainer.py always moves the model to the CPU, and there isn't a way to configure it to use CUDA through the args.
After changing the default to `cuda`, the model seems to run on the GPU.
> The `__init__` function in trainer.py always moves the model to the CPU
Yea, we default the `device` argument to `cpu`. If no `device` argument is passed to `Trainer`, then by default the model will run on CPU.
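To illustrate what a `device` argument has to do inside a trainer, here is a minimal sketch assuming a simplified training step. `MiniTrainer` and `train_step` are hypothetical names, not this library's API: the model is moved to `device` once in `__init__`, and every batch is moved inside the step.

```python
import torch
from torch import nn


class MiniTrainer:
    """Hypothetical, stripped-down trainer showing only the device handling.

    Not the library's Trainer; names and methods here are illustrative.
    """

    def __init__(self, model: nn.Module, loss_fn, optimizer,
                 device=torch.device("cpu")):
        # CPU default, as in the Trainer signature quoted earlier; accepts
        # either a torch.device or a string such as "cuda".
        self.device = torch.device(device)
        # .to() moves parameters and registered buffers in place by default,
        # so the optimizer's references to the parameters stay valid.
        self.model = model.to(self.device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
        # DataLoaders yield CPU tensors, so every batch must be moved as well.
        x, y = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()
        loss = self.loss_fn(self.model(x), y)
        loss.backward()
        self.optimizer.step()
        return loss.item()


model = nn.Linear(4, 1)
mini = MiniTrainer(model, nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.1),
                   device="cuda" if torch.cuda.is_available() else "cpu")
loss = mini.train_step(torch.randn(8, 4), torch.randn(8, 1))
print(mini.device, loss)
```

Forgetting either step (moving the model, or moving each batch) leaves a mix of CPU and CUDA tensors in the forward pass.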
> there isn't a way to configure it to use CUDA through the args.
Does explicitly passing the device as an argument to your trainer not work? E.g.:

```python
device = torch.device('cuda')
model_trainer = trainer(..., device=device)
model_trainer.interleaved_train_and_eval(n_epochs)
```
Alternatively, you could use the `colab_utils.get_device()` function to automatically handle `cpu` vs `cuda` for you, e.g.:

```python
model_trainer = trainer(..., device=colab_utils.get_device())
model_trainer.interleaved_train_and_eval(n_epochs)
```
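The body of `colab_utils.get_device()` isn't shown in this thread; assuming it performs the usual CUDA availability check, an equivalent helper might look like this (`pick_device` is a hypothetical name):

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical stand-in for colab_utils.get_device(): prefer CUDA if present."""
    # Falls back to the CPU when no CUDA device (or CUDA build) is available.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()
print(device)
```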
@eyalbetzalel going to close this out as it seems the issue is resolved. Please reopen if this is not the case.
Hi!
In trainer.py the default setting is `device=torch.device('cpu')`.
When I change it to `'cuda'`, I get this error:

```
RuntimeError: All input tensors must be on the same device. Received cpu and cuda:0
```
Do you know what should I do in order to fix this?
Thanks!
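That error is typically raised when an op such as `torch.cat` receives tensors on different devices, which usually means some tensor was left behind on the CPU: an input batch that was never moved, or a plain tensor attribute that `.to(device)` does not touch (the same kind of bug as the positional encodings fixed earlier in this thread). The helper below is a hypothetical debugging sketch, not part of the library; it lists parameters, buffers, and plain tensor attributes that are not on the expected device. `Demo` is a toy module for illustration.

```python
from typing import List

import torch
from torch import nn


def find_off_device_tensors(module: nn.Module, device) -> List[str]:
    """Return names of tensors in `module` that are not on `device`."""
    expected = torch.device(device)

    def off(t: torch.Tensor) -> bool:
        # "cuda" matches any CUDA index; "cuda:1" must match exactly.
        if t.device.type != expected.type:
            return True
        return expected.index is not None and t.device.index != expected.index

    bad = [f"param {n}" for n, p in module.named_parameters() if off(p)]
    bad += [f"buffer {n}" for n, b in module.named_buffers() if off(b)]
    for mod_name, mod in module.named_modules():
        # Parameters and buffers live in _parameters/_buffers, so any tensor
        # found directly in __dict__ is an unregistered plain attribute.
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor) and off(value):
                prefix = f"{mod_name}." if mod_name else ""
                bad.append(f"plain attribute {prefix}{attr}")
    return bad


class Demo(nn.Module):
    """Toy module with one tensor of each kind."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)                          # parameters: moved by .to()
        self.register_buffer("registered", torch.zeros(2))  # buffer: moved by .to()
        self.plain = torch.zeros(2)                         # plain attribute: NOT moved


demo = Demo()
print(find_off_device_tensors(demo, "cpu"))
if torch.cuda.is_available():
    print(find_off_device_tensors(demo.to("cuda"), "cuda"))
```

Any name reported after `model.to(device)` is a candidate for `register_buffer` (or an explicit `.to(device)` inside `forward`).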