rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

ODE demo not working #168

Closed. azournas closed this issue 3 years ago.

azournas commented 3 years ago

Hello. This is a similar issue to #167. I am hitting the same problem when trying to run the ode_demo.py example locally (RTX 3080, pytorch 1.9.0.dev20210505, torchdiffeq 0.2.1). It works fine on the CPU, but on the GPU I get the following error:

```
Traceback (most recent call last):
  File "C:\Users\apost\OneDrive\Documents\GitHub\torchdiffeq\examples\ode_demo.py", line 41, in <module>
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\solvers.py", line 30, in integrate
    solution[i] = self._advance(t[i])
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\rk_common.py", line 194, in _advance
    self.rk_state = self._adaptive_step(self.rk_state)
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\rk_common.py", line 255, in _adaptive_step
    y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau)
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\rk_common.py", line 76, in _runge_kutta_step
    f = func(ti, yi, perturb=perturb)
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\misc.py", line 187, in forward
    t = _nextafter(t, self._neginf)
  File "C:\Users\apost\miniconda3\envs\torch_nightly\lib\site-packages\torchdiffeq-0.2.1-py3.8.egg\torchdiffeq\_impl\misc.py", line 321, in _nextafter
    out = torch.nextafter(x1, x2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument other in method wrapper_nextafter)
```
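For anyone trying to reproduce this outside the example, here is a minimal sketch of the mismatch behind the last two frames (it assumes a CUDA device is available; the variable names are made up for illustration):

```python
import math
import torch

# The integration time `t` lives on the GPU, while the module-level `_neginf`
# constant is created on the CPU, so torch.nextafter sees two devices and
# raises the same RuntimeError as in the traceback above.
t = torch.tensor(0.0, device="cuda")
neginf = torch.tensor(-math.inf)   # stays on the CPU
torch.nextafter(t, neginf)         # RuntimeError: Expected all tensors to be on the same device ...
```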

azournas commented 3 years ago

Update: I found the likely source of this error.

In the class `_PerturbFunc`, the variables `_inf` and `_neginf` are defined as class attributes but never transferred to the GPU. Making the following change seems to solve the issue, but I am pretty sure this is not the most efficient way to do it:

In misc.py, from:

```python
class _PerturbFunc(torch.nn.Module):
    _inf = torch.tensor(math.inf)
    _neginf = torch.tensor(-math.inf)
```

to:

```python
class _PerturbFunc(torch.nn.Module):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    _inf = torch.tensor(math.inf).to(device)
    _neginf = torch.tensor(-math.inf).to(device)
```
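For what it's worth, an alternative to hard-coding a device at class-definition time (only a sketch, and the `base_func` attribute name is assumed here, not taken from the repo) would be to register the constants as buffers in `__init__`, so they move together with the module whenever it is transferred with `.to()` or `.cuda()`:

```python
import math
import torch

class _PerturbFunc(torch.nn.Module):
    def __init__(self, base_func):
        super().__init__()
        self.base_func = base_func
        # Buffers follow the module across .to(device)/.cuda() calls,
        # so _inf/_neginf end up on whatever device the module is moved to.
        self.register_buffer("_inf", torch.tensor(math.inf))
        self.register_buffer("_neginf", torch.tensor(-math.inf))
```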

Instead of the change to the `_PerturbFunc` class attributes above, I also tried the following, in the `_check_inputs` function in misc.py.

From:

```python
# Add perturb argument to func.
func = _PerturbFunc(func)

return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
```

to:

```python
# Add perturb argument to func.
func = _PerturbFunc(func)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
func._inf = func._inf.to(device)
func._neginf = func._neginf.to(device)

return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
```

This, however, makes the example significantly slower when running on a GPU. I am not sure whether the problem in the example is simply too small, so that transferring data to the GPU becomes the bottleneck, or whether there is a way to do this more efficiently.
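For reference, a variant of the same patch (again just a sketch; it reuses the `_PerturbFunc` attributes shown above) takes the device from the input state `y0` instead of calling `torch.cuda.is_available()`, so the constants land on whichever device the caller's tensors actually live on:

```python
func = _PerturbFunc(func)
# Follow the device of the user's initial state rather than assuming cuda:0.
func._inf = func._inf.to(y0.device)
func._neginf = func._neginf.to(y0.device)
```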

rtqichen commented 3 years ago

Thanks for looking into this! I'll fix this asap.

Right, the GPU is best for speeding up neural network evaluations within the ODE function. It doesn't deal with sequential operations very well and every operation on the GPU has a small overhead.
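A rough illustration of that overhead (just a sketch; the timings will vary a lot between machines): running many tiny operations in a Python loop is often faster on the CPU than on the GPU, because each GPU kernel launch has a fixed cost that dominates when the work per operation is small.

```python
import time
import torch

def many_small_ops(device, n=10_000):
    x = torch.zeros(4, device=device)
    start = time.perf_counter()
    for _ in range(n):
        x = x + 1.0  # one tiny kernel launch per iteration on the GPU
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return time.perf_counter() - start

print("cpu :", many_small_ops("cpu"))
if torch.cuda.is_available():
    print("cuda:", many_small_ops("cuda"))
```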

rtqichen commented 3 years ago

Should be fixed on master and pypi now. Thanks!

azournas commented 3 years ago

Thank you!