I have changed the numpy methods to torch and played around for a bit.
If I understand it correctly, PyTorch Lightning is supposed to handle devices without us having to specify anything. That does not seem to happen here: when I call torch.linspace() in the __init__ of a LightningModule, the resulting tensor does not go to cuda:0 but stays on cpu. So I guess everything our model creates is on cpu, while everything that comes from the dataloader is on cuda:0. :(
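A likely explanation, offered as an assumption rather than a confirmed diagnosis: Lightning moves the model by calling something like `model.to(device)`, and `nn.Module.to()` only touches registered parameters and buffers. A tensor stored as a plain attribute (e.g. `self.grid = torch.linspace(...)`) is not tracked, so it stays on cpu. Registering it with `register_buffer` should make it move along with the module. The sketch below uses plain `torch.nn.Module` rather than Lightning, and `GridModel`, `grid_plain` and `grid_buffer` are hypothetical names. To keep it runnable on a CPU-only machine, a dtype cast stands in for the device move, since `.to()` treats buffers the same way in both cases.

```python
import torch
import torch.nn as nn


class GridModel(nn.Module):  # hypothetical stand-in for our LightningModule
    def __init__(self, n_points: int = 5):
        super().__init__()
        # Plain attribute: not tracked by nn.Module, so .to(...) ignores it.
        self.grid_plain = torch.linspace(0.0, 1.0, n_points)
        # Registered buffer: tracked by nn.Module, so .to(...) moves/casts it
        # together with the parameters.
        self.register_buffer("grid_buffer", torch.linspace(0.0, 1.0, n_points))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Use the buffer, which always follows the module's device/dtype.
        return x * self.grid_buffer


model = GridModel()
# Stand-in for Lightning's device move; on a GPU machine this would be
# model.to("cuda:0") and the same difference would show up in .device.
model = model.to(torch.float64)
print(model.grid_plain.dtype)   # torch.float32: left behind
print(model.grid_buffer.dtype)  # torch.float64: moved with the module
```

If the tensor really has to be created on the fly, another option is to build it inside `forward`/`training_step` with `device=self.device`, since a LightningModule exposes a `device` property. Note that in `__init__` that property will still report cpu, because the move happens later.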