tksmatsubara / symplectic-adjoint-method

Code for "Symplectic Adjoint Method for Exact Gradient of Neural ODE with Minimal Memory," NeurIPS, 2021.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #2

yaomz16 closed this issue 2 years ago

yaomz16 commented 2 years ago

I tried to use `odeint_symplectic_adjoint(func, y0, t)` and encountered the following error:

```
Traceback (most recent call last):
  File "train.py", line 147, in <module>
    main()
  File "train.py", line 61, in main
    train_one_epoch(func, optimizer, epoch, is_train=True)
  File "train.py", line 78, in train_one_epoch
    train_one_epoch_one_sampling(func, optimizer, epoch, iter_sampling, sampling_time, is_train)
  File "train.py", line 119, in train_one_epoch_one_sampling
    w_and_T_pred = odeint(func, w_and_T_0, time).cuda() # size: [2, 1, 2, numpts, numpts], [timesteps(t_0&t_out), batch_size, channels(w&T), H, W]
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/symplectic_adjoint.py", line 83, in odeint_symplectic_adjoint
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/symplectic_adjoint.py", line 23, in forward
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/integrators/solvers.py", line 34, in integrate
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 181, in _advance
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 237, in _adaptive_step
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 76, in _runge_kutta_step
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/integrators/misc.py", line 206, in forward
  File "/jet/home/mingzeya/tools/miniconda3/envs/abnn/lib/python3.7/site-packages/torch_symplectic_adjoint-0.0.1-py3.7.egg/torch_symplectic_adjoint/_impl/integrators/misc.py", line 312, in _nextafter
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument other in method wrapper_nextafter)
```

I'm pretty sure I put everything on the GPU: printing `print(next(func.parameters()).device, y0.device, t.device)` gives `cuda:0 cuda:0 cuda:0`. I also added two print statements to the forward method of `func`, as follows, to make sure everything stays on the GPU:

```python
class model(nn.Module):
    ...
    def forward(self, t, inp):
        print(t.device, inp.device)
        out = self.conv(...)
        ...
        print("*****", out.device)
        return out
```

and found that the two print statements ran four times before the RuntimeError above was raised:

```
cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0
```

I wonder if there's a line somewhere that drops the model or some intermediate tensor off the GPU. Thanks!
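(For reference, a fuller device audit could look like the sketch below; `report_devices` is a hypothetical helper written for this issue, not part of the repository.)

```python
import torch
import torch.nn as nn

def report_devices(module: nn.Module, **tensors: torch.Tensor) -> None:
    # List the device of every parameter, buffer, and given tensor,
    # so a stray CPU tensor stands out at a glance.
    for name, p in module.named_parameters():
        print("param ", name, p.device)
    for name, b in module.named_buffers():
        print("buffer", name, b.device)
    for name, x in tensors.items():
        print("tensor", name, x.device)

# e.g., report_devices(func, y0=y0, t=t) right before calling the solver
```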

tksmatsubara commented 2 years ago

Thank you for your report!

After fixing issue #1, I confirmed that the code in this repository works as expected and reproduces the results in the conference proceedings.

Could you please check whether the error also occurs with other functions (e.g., `odeint` and `odeint_adjoint`) or with the original torchdiffeq v0.1.1? While I omitted some functions for a fair performance comparison, `odeint` and `odeint_adjoint` remain almost identical to the originals.
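(A quick way to run that comparison, as a sketch: `func`, `y0`, and `t` are the objects from the report above, and it assumes the package exposes all three integrators at the top level.)

```python
from torch_symplectic_adjoint import odeint, odeint_adjoint, odeint_symplectic_adjoint

# Run the same problem through each integrator to see which ones
# trigger the device error.
for solver in (odeint, odeint_adjoint, odeint_symplectic_adjoint):
    try:
        solver(func, y0, t)
        print(solver.__name__, "ok")
    except RuntimeError as e:
        print(solver.__name__, "failed:", e)
```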

yaomz16 commented 2 years ago

Hi, I tried `odeint` and `odeint_adjoint` and found that a similar error occurred, but `odeint_adjoint` from the original torchdiffeq (I used v0.2.2) worked normally.

tksmatsubara commented 2 years ago

How about the original torchdiffeq v0.1.1? Many features have been added and modified from v0.1.1 to v0.2.2.

yaomz16 commented 2 years ago

Hi, I confirmed that my code runs normally with torchdiffeq v0.1.1, without any modification.

tksmatsubara commented 2 years ago

Are you using a newer version of PyTorch? I have only tested my code with PyTorch v1.7.1. Newer versions of PyTorch check that the arguments of `nextafter` are on the same device, but PyTorch v1.7.1 does not.
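(For illustration, the check can be triggered directly; a minimal sketch assuming a CUDA-capable machine and a recent PyTorch.)

```python
import torch

# a is on cuda:0, b is on the CPU; newer PyTorch rejects the pair.
a = torch.tensor([1.0], device="cuda:0")
b = torch.tensor([2.0])
torch.nextafter(a, b)  # RuntimeError: Expected all tensors to be on the same device ...
```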

I forked the project from what was then the latest torchdiffeq. Its version number was v0.1.1, but I think it was actually older than the released v0.1.1.

I imported the following change from the original torchdiffeq and confirmed that it apparently works with PyTorch v1.10.1: https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L180-L188

Could you please use PyTorch v1.7.1 or (although I cannot guarantee it will work properly) import the above changes?
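(For reference, the guard could look roughly like this; a hedged sketch of the idea only, not the verbatim torchdiffeq code linked above.)

```python
import torch

def _nextafter(x1, x2):
    # Move the second argument onto the first argument's device before the
    # call, so the same-device check in newer PyTorch is satisfied.
    return torch.nextafter(x1, x2.to(x1.device))
```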

yaomz16 commented 2 years ago

Yes, I was using PyTorch 1.9. I created a new conda env with Python 3.7.3 and PyTorch 1.7.1 installed, and now the problem is solved; my training script is running. Thanks!

yaomz16 commented 2 years ago

A small suggestion: maybe you could modify the `nextafter` part of your code so that it is compatible with newer versions of PyTorch :-)

tksmatsubara commented 2 years ago

Thank you for your suggestion. I created a new branch, "beta", which imports the update for newer versions of PyTorch. For the sake of reproducibility, I have left the main branch untouched.