jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Gradients with `odeint` slow on GPU #5006

Open spenrich opened 4 years ago

spenrich commented 4 years ago

The following MWE trains a simple neural ODE model with gradient descent to match a 2-D dynamical system (a Van der Pol oscillator) from data sampled along a single trajectory. Each iteration of the training loop runs much more slowly on my GPU than on my CPU (roughly 17 iterations/sec on GPU vs. upwards of 800 iterations/sec on CPU, estimated with tqdm; a more careful timing sketch follows the MWE).

Any first impressions about what might be going on? I can look into doing better profiling if need be.

Versions: jax 0.2.6, jaxlib 0.1.57+cuda102, cuda 10.2

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

# Uncomment this line to force using the CPU
# jax.config.update('jax_platform_name', 'cpu')

# Some utilities for dealing with PyTrees of parameters
def tree_axpy(a, x_tree, y_tree):
    """Compute `y = a*x` for two PyTrees `(x, y)` and a scalar `a`."""
    ax = jax.tree_util.tree_map(lambda x: a * x, x_tree)
    axpy = jax.tree_util.tree_multimap(lambda x, y: x + y, ax, y_tree)
    return axpy

def tree_normsq(x_tree):
    """Compute sum of squared norms across a PyTree."""
    normsq = jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2), x_tree, 0.)
    return normsq

# Define true ODE, our approximator, and the loss function
def f(x, t):
    """Compute state derivative of a Van der Pol oscillator."""
    mu = 1.
    dx = jnp.hstack([
        mu*(x[0] - x[0]**3/3 - x[1]),
        x[0]/mu
    ])
    return dx

def f_est(x, t, params):
    """Estimate state derivative with a two-layer neural network."""
    W = params['W']
    b = params['b']
    y = W[0]@x + b[0]
    y = W[1]@jnp.tanh(y) + b[1]
    return y

def loss(params, x, t, reg_coeff):
    """Compute the sum of squared losses along a queried trajectory."""
    x_hat = odeint(f_est, x[0], t, params)
    error = jnp.sum((x - x_hat)**2)
    loss_value = error + reg_coeff*tree_normsq(params)
    return loss_value

# Generate data along a trajectory of the true system
x0 = jnp.array([1., 0.])
t0, tf = (0., 5.)
dt = 0.1
num_steps = int((tf - t0) / dt) + 1
t = jnp.linspace(t0, tf, num_steps)
x = odeint(f, x0, t)

# Initialize neural network parameters
n = 2
hdim = 32  # size of hidden layer
key = jax.random.PRNGKey(0)
params = {
    'W': [
        0.1*jax.random.normal(key, (hdim, n)),
        0.1*jax.random.normal(key, (n, hdim)),
    ],
    'b': [
        0.1*jax.random.normal(key, (hdim,)),
        0.1*jax.random.normal(key, (n,)),
    ]
}

# Training
loss_buffer = []
step_size = 1e-4
reg_coeff = 1e-6
value_and_grad = jax.jit(jax.value_and_grad(loss))
for _ in tqdm(range(5000)):
    value, grad = value_and_grad(params, x, t, reg_coeff)
    loss_buffer.append(value)
    params = tree_axpy(-step_size, grad, params)  # gradient descent step
print('Regularized fit loss:', loss_buffer[-1])

# Plotting (optional)
try:
    import matplotlib.pyplot as plt

    x_est = odeint(f_est, x0, t, params)

    fig, axes = plt.subplots(1, 2, figsize=(15,5))
    axes[0].plot(x[:,0], x[:,1], '--x')
    axes[0].plot(x_est[:,0], x_est[:,1], '-')
    axes[1].plot(loss_buffer)
    axes[1].set_yscale('log')
    plt.show()
except ImportError:
    print('Package `matplotlib` not found! Skipping plots.')
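
For a more careful measurement than the tqdm rate, here is a minimal timing sketch, assuming the MWE above has already run so that `value_and_grad`, `params`, `x`, `t`, and `reg_coeff` are defined (the iteration count is arbitrary). Calling `block_until_ready` keeps JAX's asynchronous dispatch and the first call's compilation time out of the estimate:

import time

# Warm-up call: the first invocation includes JIT compilation, so exclude it.
value, grad = value_and_grad(params, x, t, reg_coeff)
value.block_until_ready()

# Steady-state timing. Work is dispatched asynchronously, so block on the
# last result to make sure all device work has actually finished.
num_iters = 100
start = time.perf_counter()
for _ in range(num_iters):
    value, grad = value_and_grad(params, x, t, reg_coeff)
value.block_until_ready()
elapsed = time.perf_counter() - start
print(f'{num_iters / elapsed:.1f} iterations/sec')
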
shoyer commented 4 years ago

The short answer is that, unfortunately, XLA GPU is not currently great at code generation for tight loops like those in `odeint`. The body of the `while_loop` is compiled into one or more GPU kernels, which incur significant launch overhead because control flow returns to the CPU on each iteration.
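
One common workaround, sketched below rather than anything `odeint` offers directly: for a fixed time grid like the one in the MWE, swap the adaptive integrator for a fixed-step RK4 scheme built on `lax.scan`. The trip count is then static and the per-step work uniform, and `scan` can be unrolled, which reduces the kernel-launch overhead described above. The trade-offs are that accuracy is tied to the grid spacing instead of adaptive error control, and that gradients are taken through the unrolled steps (more memory) rather than via the adjoint method that `jax.experimental.ode.odeint` uses. A minimal sketch (`odeint_rk4` is an illustrative name, not a JAX API):

import jax
import jax.numpy as jnp

def odeint_rk4(f, x0, t, *args):
    """Fixed-step RK4 integration of `dx/dt = f(x, t, *args)` over the grid `t`."""
    def step(x, t_pair):
        t0, t1 = t_pair
        h = t1 - t0
        k1 = f(x, t0, *args)
        k2 = f(x + 0.5 * h * k1, t0 + 0.5 * h, *args)
        k3 = f(x + 0.5 * h * k2, t0 + 0.5 * h, *args)
        k4 = f(x + h * k3, t1, *args)
        x_next = x + (h / 6.) * (k1 + 2 * k2 + 2 * k3 + k4)
        return x_next, x_next  # carry and per-step output

    # Integrate between consecutive grid points and prepend the initial state.
    t_pairs = jnp.stack([t[:-1], t[1:]], axis=1)
    _, xs = jax.lax.scan(step, x0, t_pairs)
    return jnp.concatenate([x0[None], xs], axis=0)

In the MWE this would drop in as `x_hat = odeint_rk4(f_est, x[0], t, params)` inside `loss`; whether it actually beats `odeint` on a given GPU would still need to be measured.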

awav commented 3 years ago

@shoyer is anyone working on improving `while_loop`?

shoyer commented 3 years ago

Yes, there are several ongoing streams of work to improve `while_loop`.

awav commented 3 years ago

@shoyer, apologies for the off-topic question. Could you share links to the PRs or branches for this ongoing work, if it's publicly available?

shoyer commented 3 years ago

Sorry, I don't have any details that I can share at this time.
