patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

vmapping diffeqsolve of a neural ODE increases number of steps (on RTX 3090) #213

Closed jaschau closed 1 year ago

jaschau commented 1 year ago

Hi,

I am encountering some really weird behavior.

I am solving a neural ODE with a randomly initialized 2-layer MLP (100 hidden nodes, fixed seed) acting on 10 components. The 10 components of the solution have the following dynamics; see the minimal example below.

[Figure: neural_ode_solution, the 10 solution components over time, with dots marking the solver steps]

The dots show the steps taken by the solver. I expected the number of steps to be consistent across different GPUs and CPUs. I have tested this on several GPUs: a Tesla P100, a Quadro P5000, a Tesla V100 and an RTX 3090. Indeed, this is what I see in most cases. More concretely, if I solve the neural ODE once without vmapping the diffeqsolve, the number of steps agrees between CPU and GPU; in the example below, I obtain 14 steps. When I vmap the diffeqsolve and solve the same ODE twice, the number of steps stays the same on the CPU and on almost all GPUs. But on the RTX 3090, the number of steps increases to 162, leading to the following plot:

[Figure: neural_ode_solution_RTX_3090, the same solution computed with vmap on the RTX 3090, with many more solver steps]

I have reproduced this behaviour on two different RTX 3090 cards, so it does not seem to be a defect of a single card. I am using the latest diffrax and jax versions at the time of writing, i.e.,

diffrax                      0.2.2
jax                          0.4.1
jaxlib                       0.4.1+cuda11.cudnn86

I am using CUDA 11.1 and CUDNN 8.2.1 on Ubuntu 20 LTS.

So I am currently very puzzled about what is happening here. In general, I could understand floating point accuracy being lower on a consumer-grade GPU like the RTX 3090. But then I would expect the number of steps to be the same irrespective of whether I use vmap on the diffeqsolve. The only things I could think of are that

I know that this is a tough issue to reproduce because it seems to be tied to specific hardware. So if anyone reading this has an RTX 3090 or another consumer-grade GPU, it would be great if you could check whether you can reproduce the issue with the example below. I'd appreciate any insights into what could cause such behaviour.

See below for the reproduction example:

import diffrax
import jax
import jax.random as jrandom
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# %%
cpu_device = jax.devices("cpu")[0]
gpu_device = jax.devices("gpu")[0]

# %% [markdown]
# # Define a simple 2-layer neural ODE with 10 components

# %%
import equinox as eqx
import jax.nn as jnn

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=10,
            out_size=10,
            width_size=100,
            depth=2,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(key=key)

    def __call__(self, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=0.,
            t1=50.,
            dt0=0.02,
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
            saveat=diffrax.SaveAt(steps=True),
        )
        return solution

# %% [markdown]
# # Evaluate w/o batching on CPU

# %%
with jax.default_device(cpu_device):
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)

    solution_cpu = neural_ode(jnp.zeros(10))
    num_steps_cpu = int(solution_cpu.stats["num_steps"])

# outputs 14
print("number of steps on CPU without batching: {:}".format(num_steps_cpu))

# %% [markdown]
# Visualize steps and solution

# %%
nrows = 2
ncols = 5
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 6))
for i in range(10):
    row = i // ncols
    column = i - row * ncols
    axes[row][column].plot(solution_cpu.ts, solution_cpu.ys[:, i], "o-")

# %% [markdown]
# # Evaluate w/o batching on GPU

# %%
with jax.default_device(gpu_device):
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)

    solution_gpu = neural_ode(jnp.zeros(10))
    num_steps_gpu = int(solution_gpu.stats["num_steps"])

# outputs 14
print("number of steps on GPU without batching: {:}".format(num_steps_gpu))

# %% [markdown]
# Visualize steps and solution 

# %%
nrows = 2
ncols = 5
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 6))
for i in range(10):
    row = i // ncols
    column = i - row * ncols
    axes[row][column].plot(solution_gpu.ts, solution_gpu.ys[:, i], "o-")

# %% [markdown]
# ## Finding: same number of steps between GPU and CPU

# %% [markdown]
# # Evaluate w/ batching on CPU

# %%
with jax.default_device(cpu_device):
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)

    solution_cpu_batched = jax.vmap(neural_ode)(jnp.zeros((2, 10)))
    num_steps_cpu_batched = np.array(solution_cpu_batched.stats["num_steps"])

# outputs [14 14], so same as before
print("number of steps on CPU with vmap (batch size 2): {:}".format(num_steps_cpu_batched))

# %% [markdown]
# # Evaluate w/ batching on GPU

# %%
with jax.default_device(gpu_device):
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)

    # use a batch size of 2
    solution_gpu_batched = jax.vmap(neural_ode)(jnp.zeros((2, 10)))
    num_steps_gpu_batched = np.array(solution_gpu_batched.stats["num_steps"])

# outputs [162 162] on the RTX 3090, so much higher number of steps
print("number of steps on GPU with vmap (batch size 2): {:}".format(num_steps_gpu_batched))

# %%
nrows = 2
ncols = 5
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 6))
for i in range(10):
    row = i // ncols
    column = i - row * ncols
    axes[row][column].plot(solution_gpu_batched.ts[0], solution_gpu_batched.ys[0, :, i], "o-")
patrick-kidger commented 1 year ago

I think I might know what's going on here.

Inside an MLP applied to a single batch element, each dense layer performs a matrix-vector product, between weight and data. When vmap'd, this turns into a matrix-matrix product, between weight and data.

This matrix-matrix product doesn't respect any notion of batch-element independence: at the floating-point level it may produce slightly different results than a concatenation of several matrix-vector products. And this is precisely the kind of thing that depends on the choice of hardware, the choice of matmul algorithm used, etc. Moreover, this issue usually tends to go away when switching to float64 (as I believe you've observed), since any discrepancies then happen at a much smaller scale.
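The matrix-vector to matrix-matrix change can be seen directly by inspecting the jaxpr of a single dense layer before and after `jax.vmap`. This is a minimal sketch: `dense`, `W` and the helper `dot_operand_shapes` are hypothetical stand-ins for one layer of the MLP in the reproduction example, not Equinox's implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

# Hypothetical stand-in for one dense layer of the MLP: 10 inputs -> 100 outputs.
W = jrandom.normal(jrandom.PRNGKey(0), (100, 10))

def dense(x):
    # Unbatched call: a matrix-vector product, (100, 10) x (10,) -> (100,).
    return jax.lax.dot(W, x)

def dot_operand_shapes(fn, *args):
    """Operand shapes of each dot_general in the top-level jaxpr of fn."""
    closed = jax.make_jaxpr(fn)(*args)
    return [
        tuple(v.aval.shape for v in eqn.invars)
        for eqn in closed.jaxpr.eqns
        if eqn.primitive.name == "dot_general"
    ]

xs = jrandom.normal(jrandom.PRNGKey(1), (2, 10))

unbatched = dot_operand_shapes(dense, xs[0])              # one matrix-vector product
batched_shapes = dot_operand_shapes(jax.vmap(dense), xs)  # one matrix-matrix product
print(unbatched, batched_shapes)

# The results agree closely, but bitwise equality with a per-example loop is
# not guaranteed in general: the batched matmul may be computed differently.
batched = jax.vmap(dense)(xs)
looped = jnp.stack([dense(x) for x in xs])
print(jnp.max(jnp.abs(batched - looped)))
```

Under vmap the batch dimension of `x` simply becomes an extra free dimension of the same `dot_general`, so XLA sees one matmul and is free to pick a different algorithm (or, on Ampere GPUs, a reduced-precision mode) for it.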

One possible fix might be to go through and rewrite the jaxpr so that all matmuls occur at higher-than-default precision, see here. Let me know whether you're comfortable doing jaxpr rewrites; if not, I can try throwing something together. (No guarantees that this will fix it, though. I've only bumped into this issue in PyTorch before, and there we fixed it by just switching to float64.)
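As a lighter-weight alternative to rewriting the jaxpr, the `precision` argument can be set on individual contractions. The sketch below uses a hypothetical dense layer `dense_highest` (not Equinox's `eqx.nn.Linear`) with `jax.lax.Precision.HIGHEST`; it assumes a jaxlib version in which GPUs honor this flag, which, per the later comments in this thread, required jaxlib 0.4.3 or newer.

```python
import jax
import jax.numpy as jnp
import numpy as np

def dense_highest(W, b, x):
    # Precision.HIGHEST asks for full float32 computation of this contraction,
    # rather than reduced-precision modes such as TF32 (GPU) or bfloat16 (TPU).
    return jnp.dot(W, x, precision=jax.lax.Precision.HIGHEST) + b

# Hypothetical layer sizes matching the first layer of the MLP (10 -> 100).
rng = np.random.default_rng(0)
W = rng.standard_normal((100, 10)).astype(np.float32)
b = rng.standard_normal(100).astype(np.float32)
x = rng.standard_normal(10).astype(np.float32)

out = dense_highest(W, b, x)
# Compare against a float64 NumPy reference.
ref = W.astype(np.float64) @ x.astype(np.float64) + b.astype(np.float64)
max_err = float(np.max(np.abs(np.asarray(out, dtype=np.float64) - ref)))
print(max_err)

# precision is a parameter of the dot_general primitive, so it is kept when
# the layer is vmapped and the batched matrix-matrix product also uses it.
batched = jax.vmap(dense_highest, in_axes=(None, None, 0))(W, b, np.stack([x, x]))
print(batched.shape)
```

The trade-off is that every call site has to opt in, whereas a jaxpr rewrite would also catch matmuls inside library code.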

patrick-kidger commented 1 year ago

That said, it is pretty weird that this is affecting the differential equation numerics as much as it is. Can you check what the error estimates (y_error as returned from Tsit5().step) look like in each case?

jaschau commented 1 year ago

Thanks for the quick response. Your answer pointed me in the right direction, and I think I have been able to resolve this issue. Let's look at the y_error estimates over time on GPU (orange) and CPU (blue), using a constant step size dt0=3.0:

[Figure: neural_ode_error_estimates, |y_error| over time for each of the 10 components, CPU (blue) vs GPU (orange)]

Wow! The error estimate is 3-4 orders of magnitude higher on the GPU than on the CPU. That very much feels like catastrophic cancellation. Indeed, as reviewed, e.g., here, the error estimates made by Runge-Kutta solvers, $y_{err} = \tilde y(t + dt_0) - y(t + dt_0)$, are obtained by subtracting two different estimates $y(t + dt_0)$ and $\tilde y(t + dt_0)$ of the solution at $t + dt_0$. So how come the floating point precision of $y(t + dt_0)$ and $\tilde y(t + dt_0)$ is so much lower on the RTX 3090 than on the CPU?
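A back-of-the-envelope model shows why this subtraction amplifies rounding noise: if $y$ and $\tilde y$ each carry a relative rounding error of size ε in opposite directions, their difference is off by about 2ε|y|, which swamps any true difference much smaller than ε|y|. The sketch below is only this worst-case model with hypothetical numbers (ε set to the unit roundoff of float32, 2^-24, and of a format with a 10-bit mantissa such as TF32, 2^-11), not a measurement from this issue.

```python
def cancellation_error(y, y_tilde, rel):
    """Relative error in y_tilde - y if both carry a relative rounding error `rel`.

    Worst case: the two rounding errors push the estimates in opposite directions.
    """
    true_diff = y_tilde - y
    noisy_diff = y_tilde * (1 + rel) - y * (1 - rel)
    return abs(noisy_diff - true_diff) / abs(true_diff)

# Hypothetical estimates that agree to about six significant digits.
y, y_tilde = 1.0, 1.0 + 3e-6

results = {
    "float32 (eps = 2**-24)": cancellation_error(y, y_tilde, 2.0**-24),
    "TF32 (eps = 2**-11)": cancellation_error(y, y_tilde, 2.0**-11),
}
for name, err in results.items():
    print(f"{name}: relative error of the error estimate = {err:.3g}")
```

With float32-level rounding the estimated error is still within a few percent of the true difference; with 10-bit-mantissa rounding it can be wrong by a factor of several hundred, which is the kind of inflation that makes an adaptive controller shrink the step size.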

It turns out that with the recent Ampere generation of GPUs, such as the RTX 3090, NVIDIA introduced the TensorFloat32 (TF32) format, which, despite its name, is not an IEEE 754 floating point format. Instead, it is a 19-bit format that uses only 10 bits for the mantissa, as opposed to 23 bits in IEEE 754 single precision (float32), see here. Now, according to https://github.com/google/jax/issues/4873, use of TensorFloat32 is automatically enabled in JAX for matrix multiplication. As mentioned in that JAX issue, setting the environment variable NVIDIA_TF32_OVERRIDE=0 disables the use of TensorFloat32. Indeed, if I set this and run the reproduction example, I obtain

number of steps on CPU without batching: 14
number of steps on GPU without batching: 14
number of steps on CPU with vmap (batch size 2): [14 14]
number of steps on GPU with vmap (batch size 2): [14 14]

also on the RTX 3090.
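To get a feel for how much precision TF32 gives up, one can emulate its rounding in NumPy by keeping only the top 10 of float32's 23 mantissa bits. `round_to_tf32` is a hypothetical helper and a sketch rather than a bit-exact model of the hardware: it assumes round-to-nearest-even (the hardware or library path may truncate instead) and the commonly described behavior of TF32 tensor cores, which round matmul inputs to TF32 and accumulate in float32.

```python
import numpy as np

def round_to_tf32(x):
    """Round float32 values to TF32 precision (10 explicit mantissa bits).

    Sketch only: round-to-nearest-even on the 13 dropped mantissa bits,
    with no special handling of NaN/Inf.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Lowest kept mantissa bit, used to break ties towards an even result.
    lsb = (bits >> np.uint32(13)) & np.uint32(1)
    # Add just under half a TF32 ulp (plus the tie-breaker), then truncate.
    rounded = (bits + np.uint32(0x0FFF) + lsb) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal(10_000).astype(np.float32)
a_tf32 = round_to_tf32(a)

# Relative rounding error: bounded by 2**-11 here, versus 2**-24 for float32.
rel_err = np.max(np.abs(a_tf32 - a) / np.abs(a))
print(rel_err, 2.0**-11)

# A dot product with TF32-rounded inputs but float32 accumulation, compared
# with plain float32, both measured against a float64 reference.
w = rng.standard_normal(10_000).astype(np.float32)
exact = np.dot(a.astype(np.float64), w.astype(np.float64))
print(abs(np.dot(a, w) - exact), abs(np.dot(a_tf32, round_to_tf32(w)) - exact))
```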

I guess for many machine learning applications, the reduced precision does not matter. But it looks like for solving neural ODEs with adaptive step size solvers, it hurts performance, because the inflated error estimates force the step size controller to take many more steps. So maybe it's worth adding a warning to the documentation to make users aware of this?

patrick-kidger commented 1 year ago

That's interesting! It's great to have gotten to the bottom of this so swiftly.

Catastrophic cancellation is undoubtedly the reason here, but just in case it's interesting: it is possible to compute the error estimate in a slightly more numerically stable way than by doing $y_{err} = y_1 - \widetilde{y_1}$. (Not that this was enough to help here!)

In a Runge-Kutta solver, the output of a step is given by some linear combination of vector field evaluations, i.e. $y_1 = y_0 + \sum_{i=1}^{k} \alpha_i f_i$. Likewise, the embedded estimate has $\widetilde{y_1} = y_0 + \sum_{i=1}^{k} \widetilde{\alpha_i} f_i$. So rather than computing $y_{err} = y_1 - \widetilde{y_1}$, we instead compute $y_{err} = \sum_{i=1}^{k} (\alpha_i - \widetilde{\alpha_i}) f_i$.
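A toy float32 illustration of why differencing the weights helps: $y_0$ drops out, so the large common part never gets rounded and then cancelled. The three stage values $f_i$, both weight vectors, and the deliberately large $y_0 = 2^{20}$ (where float32 values are spaced 0.125 apart) are all hypothetical, chosen to make the effect visible; they are not Tsit5's coefficients.

```python
import numpy as np

f32 = np.float32
# Hypothetical stage values f_i and two hypothetical weight vectors
# (alpha for y_1, alpha_tilde for the embedded estimate).
f = np.array([0.3, 0.7, 1.1], dtype=f32)
alpha = np.array([0.25, 0.5, 0.25], dtype=f32)
alpha_tilde = np.array([0.3, 0.5, 0.2], dtype=f32)
# Large y0: float32 values near 2**20 are spaced 0.125 apart.
y0 = f32(2.0**20)

# Naive: form both estimates in float32, then subtract (cancellation).
y1 = y0 + np.sum(alpha * f)
y1_tilde = y0 + np.sum(alpha_tilde * f)
err_naive = y1 - y1_tilde

# Reformulated: difference the weights first, so y0 never enters.
err_direct = np.sum((alpha - alpha_tilde) * f)

# Reference in float64, from the same float32 inputs.
err_ref = np.sum(
    (alpha.astype(np.float64) - alpha_tilde.astype(np.float64)) * f.astype(np.float64)
)

print(err_naive, err_direct, err_ref)
```

Here the naive estimate is quantized to the float32 spacing near $y_0$, while the reformulated one matches the float64 reference to float32 accuracy.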

In terms of things to do here: probably we want to force precision=HIGHEST for either (a) the linear combination described above, or (b) all dot_general operations happening inside the diffeq solver (just to be sure). If you get a chance, could you try passing precision=jax.lax.Precision.HIGHEST to the tensordot operation here and see if that also fixes the issue?

jaschau commented 1 year ago

Thanks for the clarification about the error estimate. I was already wondering whether the naive formulation was in fact what people use in practice.

Injecting precision=jax.lax.Precision.HIGHEST does not seem to fix the issue. Moreover, even if I perform the computation within the jax.default_matmul_precision context manager (https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html) in the following form

with jax.default_matmul_precision("float32"):
    ...  # actual computation

then the precision is still low. So I'm not sure whether the precision arguments have any effect in JAX at the moment when not running on a TPU.

Just for completeness, here's the code I used for generating the error plots shown above:

def solve_and_track_errors():
    y_errors = []
    ys = []
    ts = []

    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)
    vmap_mlp = jax.vmap(neural_ode.func.mlp)

    solver = diffrax.Tsit5()
    term = diffrax.ODETerm(lambda t, y, args: vmap_mlp(y))

    args = None
    t0 = 0.
    t1 = 50.
    dt0 = 3.
    y0 = jnp.zeros((2, 10,))

    tprev = t0
    tnext = t0 + dt0
    y = y0
    state = solver.init(term, tprev, tnext, y0, args)

    while tprev < t1:
        y, y_error, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)

        y_errors.append(np.array(y_error))
        ys.append(np.array(y))
        ts.append(tnext)

        tprev = tnext
        tnext = min(tprev + dt0, t1)
    return np.asarray(ts), np.asarray(ys), np.asarray(y_errors)

with jax.default_device(cpu_device):
    ts_cpu, ys_cpu, y_errors_cpu = solve_and_track_errors()
with jax.default_matmul_precision("float32"):
    with jax.default_device(gpu_device):
        ts_gpu, ys_gpu, y_errors_gpu = solve_and_track_errors()
nrows = 2
ncols = 5
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 6), sharey=True)
for i in range(10):
    row = i // ncols
    column = i - row * ncols
    axes[row][column].semilogy(ts_cpu, np.abs(y_errors_cpu[:, 0, i]), "-")
    axes[row][column].semilogy(ts_gpu, np.abs(y_errors_gpu[:, 0, i]), "-", alpha=0.9)
patrick-kidger commented 1 year ago

If I understand the JAX docs correctly, precision should affect both GPU and TPU. If you can reduce this to an MWE, I'd suggest opening a bug report on the JAX issue tracker.

jaschau commented 1 year ago

Sounds good, I have created a bug report here https://github.com/google/jax/issues/14022.

jaschau commented 1 year ago

Hi, I can now give an update. The JAX issue linked above has been fixed in jaxlib 0.4.3, and the precision argument for matmul operations is now respected on GPUs. When I use the context manager with jax.default_matmul_precision("float32"): in the example above:

with jax.default_device(gpu_device):
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)

    # use a batch size of 2
    with jax.default_matmul_precision("float32"):
        solution_gpu_batched = jax.vmap(neural_ode)(jnp.zeros((2, 10)))
        num_steps_gpu_batched = np.array(solution_gpu_batched.stats["num_steps"])

# outputs [14 14]
print("number of steps on GPU with vmap (batch size 2) and float32 default matmul precision: {:}".format(num_steps_gpu_batched))

the number of steps taken is 14 on GPU, as expected. However, if I change vector_tree_dot in solver/base.py to

def vector_tree_dot(a, b):
    return jtu.tree_map(lambda bi: jnp.tensordot(a, bi, axes=1, precision="float32"), b)

I still get 162 steps without the context manager, so it seems that the relevant precision is lost outside the solver, in the neural network.

jaschau commented 1 year ago

Btw, if I only put the context manager in Func.__call__ before calling the neural network, the number of steps also stays at 14. So it seems the reduced precision in the solver is less critical than the reduced precision in the vector field, at least in this toy problem.
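The placement described above (the precision context manager inside the vector field rather than around the whole solve) can be sketched in plain JAX. `init_mlp`, `mlp` and `vector_field` are hypothetical stand-ins for `Func` with the same 10-100-100-10 shape and softplus activation; the `(t, y, args)` signature is the one `diffrax.ODETerm` expects, so `diffrax.ODETerm(vector_field)` with `args=params` should be usable in place of `Func`, assuming the rest of the `diffeqsolve` call is unchanged.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

def init_mlp(key, sizes=(10, 100, 100, 10)):
    # Hypothetical plain-JAX MLP parameters (stand-in for eqx.nn.MLP, depth=2).
    keys = jrandom.split(key, len(sizes) - 1)
    return [
        (jrandom.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp(params, y):
    for W, b in params[:-1]:
        y = jax.nn.softplus(W @ y + b)
    W, b = params[-1]
    return W @ y + b  # no final activation, as in eqx.nn.MLP's default

def vector_field(t, y, params):
    # Only the network's matmuls are traced under full float32 precision;
    # the solver's own linear combinations are not affected by this context.
    with jax.default_matmul_precision("float32"):
        return mlp(params, y)

params = init_mlp(jrandom.PRNGKey(42))
ys = jnp.zeros((2, 10))
# Batched evaluation, mirroring the vmapped solve in the reproduction example.
dy = jax.vmap(vector_field, in_axes=(None, 0, None))(0.0, ys, params)
print(dy.shape)
```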

patrick-kidger commented 1 year ago

Interesting, and thanks for the update. And FWIW, the next release of Diffrax will include precision=jax.lax.Precision.HIGHEST in the vector_tree_dot call, so I think we'll then be doing all we can to ensure the appropriate precision is used!

IIUC I'm going to close this issue as resolved. (Please re-open if I'm wrong.)