tksmatsubara / symplectic-adjoint-method

Code for "Symplectic Adjoint Method for Exact Gradient of Neural ODE with Minimal Memory," NeurIPS, 2021.

Possible bug in rk4 and dopri5 on CUDA device #3

Open Tifus15 opened 1 year ago

Tifus15 commented 1 year ago

Hello, I was comparing the performance of torchdiffeq's odeint and your symplectic odeint, and I get the following error in this code segment:

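import torch
import torch.nn as nn
from torch_symplectic_adjoint import odeint_symplectic_adjoint as symp_odeint

# A_ode, x0_ode, t, ndim, steps and device are defined elsewhere in the script.
class Lambda(nn.Module):
    def forward(self, t, y):
        return torch.mm(y, A_ode)

y_euler = symp_odeint(Lambda().to(device), x0_ode, t, method='euler').reshape(ndim, steps).T
y_rk4 = symp_odeint(Lambda().to(device), x0_ode, t, method='rk4').reshape(ndim, steps).T
y_dopri5 = symp_odeint(Lambda().to(device), x0_ode, t, method='dopri5').reshape(ndim, steps).T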

On the CUDA device, Euler works without issues, but rk4 and dopri5 fail with an error saying the tensors are not on the same device. The plain odeint from torchdiffeq doesn't show this behaviour.

On the CPU, everything works.

The error message:

Traceback (most recent call last):
  File "/home/denis/Desktop/Neural Networks/server stuff/relative_plots.py", line 72, in <module>
    y_rk4 = symp_odeint(Lambda().to(device),x0_ode,t,method='rk4').reshape(ndim,steps).T
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/symplectic_adjoint.py", line 83, in odeint_symplectic_adjoint
    solution = OdeintSymplecticAdjoint.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_params)
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/symplectic_adjoint.py", line 23, in forward
    y = solver.integrate(t)
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/integrators/solvers.py", line 34, in integrate
    solution[i] = self._advance(t[i])
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 170, in _advance
    self.rk_state = self._nonadaptive_step(rk_state)
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 284, in _nonadaptive_step
    y1, f1, _, _ = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau, no_f1=no_f1)
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 76, in _runge_kutta_step
    f = func(ti, yi, perturb=perturb)
  File "/home/denis/anaconda3/envs/hnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/integrators/misc.py", line 206, in forward
    t = _nextafter(t, self._neginf)
  File "/home/denis/Desktop/symplectic-adjoint-method/torch_symplectic_adjoint/_impl/integrators/misc.py", line 312, in _nextafter
    out = torch.nextafter(x1, x2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument other in method wrapper_nextafter)
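The last frame of the traceback above shows `torch.nextafter(x1, x2)` failing on its `other` argument, which suggests that `t` is on `cuda:0` while the `-inf` constant `self._neginf` was created on the CPU. Assuming that is the cause, a minimal sketch of a device-safe call follows; `nextafter_same_device` is a hypothetical helper for illustration, not a function from this repository or its beta branch.

```python
import torch


def nextafter_same_device(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: torch.nextafter with x2 moved to x1's device and dtype.

    torch.nextafter expects both arguments on the same device, so a constant
    created on the CPU (such as -inf) has to follow x1 onto the GPU first.
    """
    x2 = x2.to(device=x1.device, dtype=x1.dtype)
    return torch.nextafter(x1, x2)


# Assumed setup: the time tensor lives on the GPU when one is available,
# while the -inf constant is created on the CPU, as the error message suggests.
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.tensor([0.5], device=device)
neginf = torch.tensor(float("-inf"))  # CPU tensor

# Largest representable value just below t, computed on t's device.
t_prev = nextafter_same_device(t, neginf)
print(t_prev.device, t_prev.item() < 0.5)
```

Moving the constant on every call costs a tiny host-to-device copy; creating it once on the same device as `t` would avoid that, at the price of tracking which device the solver runs on.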

tksmatsubara commented 1 year ago

Please check #2 and try the beta branch.

Tifus15 commented 1 year ago

I switched to beta, but it doesn't help. I have a test script that uses torchdiffeq's odeint_adjoint with rk4; when I switch to your symplectic integrator, I get the same error as before. Strangely, Euler now raises a different error, although it worked on master:

Traceback (most recent call last):
  File "/home/andric/Desktop/hamiltonian-neural-dynamics/Masterthesis/numerical integration/neural_numerical.py", line 101, in <module>
    loss.backward()
  File "/home/andric/anaconda3/envs/hnn/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/andric/anaconda3/envs/hnn/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/andric/anaconda3/envs/hnn/lib/python3.8/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/home/andric/anaconda3/envs/hnn/lib/python3.8/site-packages/torch_symplectic_adjoint-0.0.1-py3.8.egg/torch_symplectic_adjoint/_impl/symplectic_adjoint.py", line 59, in backward
  File "/home/andric/anaconda3/envs/hnn/lib/python3.8/site-packages/torch_symplectic_adjoint-0.0.1-py3.8.egg/torch_symplectic_adjoint/_impl/integrators/rk_common.py", line 324, in _nonadaptive_step_symplectic_adjoint
UnboundLocalError: local variable 'yi' referenced before assignment
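The source of line 324 in `rk_common.py` is not shown in this traceback (it runs from an installed egg), so the following is only an illustration of the general Python pattern behind this `UnboundLocalError`: a local variable assigned only inside a loop body is read after a loop that ran zero times. Whether a single-stage method such as Euler produces an empty stage loop in this library is an assumption, and `combine_stages_buggy` and `combine_stages_fixed` are hypothetical names.

```python
def combine_stages_buggy(y0, stage_increments):
    # 'yi' is bound only inside the loop body.
    for dy in stage_increments:
        yi = y0 + dy
    # If stage_increments is empty, 'yi' was never bound -> UnboundLocalError.
    return yi


def combine_stages_fixed(y0, stage_increments):
    # Binding 'yi' before the loop makes the zero-iteration case well defined.
    yi = y0
    for dy in stage_increments:
        yi = y0 + dy
    return yi


try:
    combine_stages_buggy(1.0, [])
except UnboundLocalError as err:
    print("buggy version:", err)

print("fixed version:", combine_stages_fixed(1.0, []))
```

The exact message text differs between Python versions (3.8 prints "referenced before assignment"), but the exception type is the same.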

tksmatsubara commented 1 year ago

I set up a new environment with Python v3.7.3 using Anaconda and installed PyTorch v1.7.1 and this repository. I tried running the code attached at the end of this comment, and confirmed that it works without any errors.

As mentioned in #2, the following error occurs with PyTorch v1.13.1 and the master branch, but not with the beta branch:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument other in method wrapper_nextafter)

Please try changing the versions of Python or PyTorch.

import torch
import torch.nn as nn
from torch_symplectic_adjoint import odeint, odeint_adjoint, odeint_symplectic_adjoint

device = "cuda"
N = 5
B = 100
A = torch.randn(N, N, device=device)

class Lambda(nn.Module):
    def forward(self, t, y):
        return torch.mm(y, A)

f = Lambda().to(device)

x0 = torch.randn(B, N, device=device)
t = torch.arange(0, 200, device=device) / 200

y_euler = odeint(f, x0, t, method="euler")
y_rk4 = odeint(f, x0, t, method="rk4")
y_dopri5 = odeint(f, x0, t, method="dopri5")

y_euler = odeint_adjoint(f, x0, t, method="euler")
y_rk4 = odeint_adjoint(f, x0, t, method="rk4")
y_dopri5 = odeint_adjoint(f, x0, t, method="dopri5")

y_euler = odeint_symplectic_adjoint(f, x0, t, method="euler")
y_rk4 = odeint_symplectic_adjoint(f, x0, t, method="rk4")
y_dopri5 = odeint_symplectic_adjoint(f, x0, t, method="dopri5")
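Since the maintainer's diagnosis hinges on versions (Python 3.7.3 with PyTorch 1.7.1 working; PyTorch 1.13.1 with the master branch failing), it may help to report the exact environment when replying. A minimal sketch using only standard `platform` and `torch` attributes; the installed branch of this repository still has to be stated by hand.

```python
import platform

import torch

# Collect the versions the maintainer's diagnosis depends on.
info = {
    "python": platform.python_version(),
    "torch": str(torch.__version__),
    "cuda_available": torch.cuda.is_available(),
    "torch_cuda": torch.version.cuda,  # None for CPU-only PyTorch builds
}

for key, value in info.items():
    print(f"{key}: {value}")
```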