rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

torchdiffeq is slower than scipy's solve_ivp #200

Open ma-sadeghi opened 2 years ago

ma-sadeghi commented 2 years ago

Here's a sample script that solves the 1D heat equation, discretized using the method of lines. torchdiffeq turned out to be ~7x slower than scipy's solve_ivp, which is not what I expected. I expected comparable performance at worst, and ideally better, since torchdiffeq was run on GPU. Am I missing something? Thank you so much!

import time
import numpy as np  # needed for np.arange and np.zeros_like below
import torch
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
print(f"GPU available: {torch.cuda.is_available()}")

def func(t, y, dy):
    N = len(y)
    h = 1/(N-1)
    idx = np.arange(1, N-1)
    dy[idx] = (y[idx-1] - 2*y[idx] + y[idx+1]) / h**2
    return dy

N = 300
y0 = torch.ones(N)
y0[0] = y0[-1] = 0.0
tspan = (0, 0.05)
t = torch.linspace(*tspan, 10)
dy = torch.zeros_like(y0)

start = time.time()
sol = odeint(lambda t, y: func(t, y, dy), y0, t)
print(f"torchdiffeq: {time.time()-start:.2f} s")

start = time.time()
dy = np.zeros_like(y0)
odefunc = lambda t, y: func(t, y, dy)
sol2 = solve_ivp(odefunc, tspan, y0.numpy(), t_eval=t, method="RK45")
print(f"solve_ivp: {time.time()-start:.2f} s")

Output:

GPU available: True
torchdiffeq: 8.37 s
solve_ivp: 1.16 s

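A possible variation on the right-hand side above, as a minimal sketch: the same second-difference stencil written with slices, allocating a fresh output on every call instead of writing into the shared `dy` buffer. Returning one shared buffer can be a hazard with any solver that holds on to the returned stage derivatives, because every stage would then alias the same memory. The name `heat_rhs` is an assumption made for illustration and does not appear in the original post. The block uses NumPy; the same slicing should carry over to torch tensors with `torch.zeros_like(y)`, which allocates on the device of `y`. Note also that the script as posted creates `y0` and `t` without a `device` argument, so by default they are CPU tensors even when `torch.cuda.is_available()` returns True.

```python
import numpy as np


def heat_rhs(t, y):
    """Method-of-lines RHS for the 1D heat equation on a uniform grid over [0, 1].

    The two boundary entries get a zero time derivative, so their values stay
    fixed, matching the behavior of the original `func`.
    """
    h = 1.0 / (y.shape[0] - 1)
    dy = np.zeros_like(y)  # fresh array on every call, no shared buffer
    # Second difference on interior points, expressed with slices
    # instead of an explicit index array.
    dy[1:-1] = (y[:-2] - 2.0 * y[1:-1] + y[2:]) / h**2
    return dy


# Same initial condition as in the post: ones in the interior, zeros at the ends.
N = 300
y0 = np.ones(N)
y0[0] = y0[-1] = 0.0
dy0 = heat_rhs(0.0, y0)
```

Because the function takes `(t, y)` and returns a new array, it can be passed to `solve_ivp` directly, without the lambda wrapper.
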
MadHuslista commented 1 year ago

I'm not remotely an expert, but maybe the improvement only shows up when comparing backpropagation speed and error between the two methods. It also appears that you're not taking advantage of the adjoint method when solving with odeint.