I think I might know what's going on here.
Inside an MLP applied to a single batch element, each dense layer performs a matrix-vector product, between weight and data. When vmap'd, this turns into a matrix-matrix product, between weight and data.
This matrix-matrix product doesn't respect any notion of batch-element independence; at the floating-point level it may produce slightly different results than a concatenation of several matrix-vector products. And this is precisely the kind of thing that depends on the choice of hardware, the choice of matmul algorithm used, etc. Moreover this issue usually tends to go away when switching to float64 (as I believe you've observed), since at that point any discrepancies happen at a much smaller scale.
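To make this concrete, here is a small illustrative sketch (the variable names are made up) of how a vmap'd matrix-vector product is lowered to a single matrix-matrix product, whose float32 result can differ slightly from stacking the per-element products:

import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.PRNGKey(0), (100, 100))   # weight of one dense layer
X = jax.random.normal(jax.random.PRNGKey(1), (8, 100))     # a batch of 8 inputs

matvec = lambda x: W @ x                     # per-batch-element matrix-vector product
batched = jax.vmap(matvec)(X)                # lowered to one matrix-matrix product
looped = jnp.stack([matvec(x) for x in X])   # one matvec per batch element

# Mathematically identical, but the float32 results may differ slightly,
# depending on the hardware and the matmul algorithm chosen.
print(jnp.max(jnp.abs(batched - looped)))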
One possible fix might be to go through and rewrite the jaxpr so that all matmuls occur at higher-than-default precision, see here. Let me know whether you're comfortable doing jaxpr rewrites and if not I can try throwing something together. (No guarantees that this will fix it though. I've only bumped into this issue in PyTorch before, and there we fixed it by just switching to float64.)
That said, it is pretty weird that this is affecting the differential equation numerics as much as it is. Can you check what the error estimates (y_error as returned from Tsit5().step) look like in each case?
Thanks for the quick response. Your answer pointed me in the right direction and I think I have been able to resolve this issue. Let's look at the y_error estimates on GPU (orange) and CPU (blue), with a constant step size of dt0=3.0, over time:
Wow! The error is 3-4 orders of magnitude higher on the GPU than on the CPU. That very much feels like catastrophic cancellation. Indeed, as reviewed, e.g., here, the error estimates made by Runge-Kutta solvers, $y_{err} = \tilde y(t + dt0) - y(t + dt0)$, are obtained by subtracting two different estimates $y(t + dt0)$, $\tilde y(t + dt0)$ of the value of the solution at $t + dt0$. So how come the floating point precision of $y(t + dt0)$ and $\tilde y(t + dt0)$ is so much lower on the RTX 3090 than on the CPU?
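As a toy illustration of catastrophic cancellation (with made-up numbers): subtracting two nearly equal float32 values leaves only a few significant bits, so the rounding error of the inputs dominates the result:

import numpy as np

a = np.float32(1.000123)
b = np.float32(1.000121)

# The true difference is 2e-6, but each input carries rounding error of up to ~6e-8,
# which is a few percent of that difference -- far larger than float32's machine epsilon.
print(a - b)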
It turns out that with the recent Ampère generation of GPUs like the RTX 3090, NVIDIA introduced the TensorFloat32 floating point format, which, unlike the name might suggest, is not an IEEE 754 floating point format. Instead, it's a 19-bit format which only uses 10 bits for the mantissa, as opposed to 23 bits in conventional IEEE 754 single precision, see here. Now, according to https://github.com/google/jax/issues/4873, use of TensorFloat32 is automatically enabled in Jax for matrix multiplication.
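A quick way to check whether TF32 is in play (a rough sketch) is to compare a float32 matrix product at default precision against the same product at highest precision:

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (512, 512))
b = jax.random.normal(jax.random.PRNGKey(1), (512, 512))

default = jnp.matmul(a, b)                                    # may use TF32 on Ampere GPUs
full = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)  # full float32 accumulation

# On an Ampere GPU with TF32 enabled this difference is comparatively large;
# on CPU, or with TF32 disabled, it is (close to) zero.
print(jnp.max(jnp.abs(default - full)))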
As mentioned in the Jax issue, there is an environment variable NVIDIA_TF32_OVERRIDE=0 which disables use of TensorFloat32. Indeed, if I set this and run the reproduction example, I obtain
number of steps on CPU without batching: 14
number of steps on GPU without batching: 14
number of steps on CPU with vmap (batch size 2): [14 14]
number of steps on GPU with vmap (batch size 2): [14 14]
also on the RTX 3090.
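For reference, a sketch of setting the override from Python; I'm assuming it has to be in place before JAX initializes its CUDA backend, so setting it in the shell before launching the process is the safest option:

import os

# Disable TensorFloat32 in the CUDA libraries; must be set before any GPU work happens.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import jax  # imported only after the override is in place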
I guess for many machine learning applications, the reduced precision does not matter. But it looks like for solving Neural ODEs with adaptive step solvers, it hurts performance. So maybe it's worth adding some warning to the documentation to make users aware of this?
That's interesting! It's great to have gotten to the bottom of this so swiftly.
Catastrophic cancellation is undoubtedly the reason here, but just in case it's interesting -- it is possible to compute the error estimate a little more numerically stably than by doing $y_{err} = y_1 - \widetilde{y_1}$. (Not that this was enough to help here!)
In a Runge--Kutta solver, the output of a step is given by some linear combination of vector field evaluations, i.e. $y_1 = y_0 + \sum_{i=1,\ldots,k} \alpha_i f_i$. Likewise the embedded estimate has $\widetilde{y_1} = y_0 + \sum_{i=1,\ldots,k} \widetilde{\alpha_i} f_i$. So instead of computing $y_{err} = y_1 - \widetilde{y_1}$, we instead do $y_{err} = \sum_{i=1,\ldots,k} (\alpha_i - \widetilde{\alpha_i}) f_i$.
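A minimal sketch of the two formulations, with made-up coefficients and stage values:

import jax.numpy as jnp

# Hypothetical Runge-Kutta weights for the solution and the embedded estimate,
# plus stand-ins for the k vector field evaluations f_1, ..., f_k.
alpha = jnp.array([0.1, 0.5, 0.4])
alpha_tilde = jnp.array([0.12, 0.48, 0.4])
f = jnp.ones((3, 10))
y0 = jnp.zeros(10)

y1 = y0 + jnp.tensordot(alpha, f, axes=1)
y1_tilde = y0 + jnp.tensordot(alpha_tilde, f, axes=1)

y_err_naive = y1 - y1_tilde                                   # subtracts two nearly equal values
y_err_stable = jnp.tensordot(alpha - alpha_tilde, f, axes=1)  # combines the small coefficient differences first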
In terms of things to do here: probably we want to force precision=HIGHEST for either (a) the linear combination described above, or (b) all dot_general operations happening inside the diffeq solver (just to be sure). If you get a chance, could you try passing precision=jax.lax.Precision.HIGHEST to the tensordot operation here and see if that also fixes the issue?
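For concreteness, a sketch of what that change might look like, modelled on the vector_tree_dot helper in solver/base.py:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

def vector_tree_dot(a, b):
    # Force full-precision accumulation for the linear combination of stages.
    return jtu.tree_map(
        lambda bi: jnp.tensordot(a, bi, axes=1, precision=jax.lax.Precision.HIGHEST),
        b,
    )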
Thanks for the clarification about the error estimate. I was already wondering whether the naive formulation was in fact what people use in implementations.
Injecting precision=jax.lax.Precision.HIGHEST does not seem to fix the issue.
However, even if I perform the computation within the context manager https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html in the following form

with jax.default_matmul_precision("float32"):
    ...  # actual computation

then precision is still low. So I'm not sure if the precision arguments have an effect in jax at the moment when you are not using a TPU.
Just for completeness, here's the code I used for generating the error plots shown above:
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import numpy as np

# NeuralODE, cpu_device and gpu_device are defined in the reproduction example below.

def solve_and_track_errors():
    y_errors = []
    ys = []
    ts = []
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)
    vmap_mlp = jax.vmap(neural_ode.func.mlp)
    solver = diffrax.Tsit5()
    term = diffrax.ODETerm(lambda t, y, args: vmap_mlp(y))
    args = None
    t0 = 0.
    t1 = 50.
    dt0 = 3.
    y0 = jnp.zeros((2, 10))
    tprev = t0
    tnext = t0 + dt0
    y = y0
    state = solver.init(term, tprev, tnext, y0, args)
    while tprev < t1:
        y, y_error, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)
        y_errors.append(np.array(y_error))
        ys.append(np.array(y))
        ts.append(tnext)
        tprev = tnext
        tnext = min(tprev + dt0, t1)
    return np.asarray(ts), np.asarray(ys), np.asarray(y_errors)

with jax.default_device(cpu_device):
    ts_cpu, ys_cpu, y_errors_cpu = solve_and_track_errors()

with jax.default_matmul_precision("float32"):
    with jax.default_device(gpu_device):
        ts_gpu, ys_gpu, y_errors_gpu = solve_and_track_errors()

nrows = 2
ncols = 5
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 6), sharey=True)
for i in range(10):
    row = i // ncols
    column = i - row * ncols
    axes[row][column].semilogy(ts_cpu, np.abs(y_errors_cpu[:, 0, i]), "-")
    axes[row][column].semilogy(ts_gpu, np.abs(y_errors_gpu[:, 0, i]), "-", alpha=0.9)
If I understand the JAX docs correctly then precision should affect both GPU and TPU. If you can reduce to a MWE then I'd suggest opening a bug on the JAX page.
Sounds good, I have created a bug report here https://github.com/google/jax/issues/14022.
Hi, I can now give an update. The jax issue linked above has been fixed in jaxlib 0.4.3 and the precision argument for matmul operations is now respected on GPUs. When I use the context manager jax.default_matmul_precision("float32") in the example above:
with jax.default_device(gpu_device):
    key = jrandom.PRNGKey(seed=42)
    neural_ode = NeuralODE(key=key)
    # use a batch size of 2
    with jax.default_matmul_precision("float32"):
        solution_gpu_batched = jax.vmap(neural_ode)(jnp.zeros((2, 10)))
    num_steps_gpu_batched = np.array(solution_gpu_batched.stats["num_steps"])
    # outputs [14 14]
    print("number of steps on GPU with vmap (batch size 2) and float32 default matmul precision: {:}".format(num_steps_gpu_batched))
the number of steps taken is 14 on the GPU, as expected. However, if I change vector_tree_dot in solver/base.py to

def vector_tree_dot(a, b):
    return jtu.tree_map(lambda bi: jnp.tensordot(a, bi, axes=1, precision="float32"), b)

I still get 162 steps without the context manager, so it seems we lose the relevant precision outside the solver, in the neural network.
Btw, if I only put the context manager in Func.__call__ before calling the neural network, the number of steps also stays at 14. So it seems the reduced precision in the solver is less critical than the reduced precision in the vector field, at least in this toy problem.
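A sketch of that change, assuming the Func module from the (not shown) reproduction example wraps an eqx.nn.MLP:

import equinox as eqx
import jax

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __call__(self, t, y, args):
        # Only the vector field evaluation is forced to full float32 matmuls.
        with jax.default_matmul_precision("float32"):
            return self.mlp(y)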
Interesting, and thanks for the update. And FWIW, the next release of Diffrax will include precision=jax.lax.Precision.HIGHEST in the vector_tree_dot call, so I think we'll then be doing all we can to ensure the appropriate precision is used!
IIUC I'm going to close this issue as resolved. (Please re-open if I'm wrong.)
Hi,
I am encountering a really weird behavior.
I am solving a neural ODE with a randomly initialized 2-layer MLP with 100 hidden nodes and 10 components (with a fixed seed). The 10 components of the solution have the following dynamics; see the minimal example below.
The dots show the steps taken by the solver. I expected the number of steps to be consistent between different GPUs and CPUs. I have tested this on several GPUs: a Tesla P100, a Quadro P5000, a Tesla V100 and an RTX 3090. Indeed, this is what I see - in most cases. More concretely, if I solve the neural ODE once without vmapping the diffeqsolve, the number of steps agrees between CPU and GPU. In the example below, I obtain 14 steps. When I vmap the diffeqsolve and solve the same ODE twice, the number of steps stays the same on CPU and almost all GPUs. But on the RTX 3090, the number of steps increases to 162, leading to the plot shown below
I have reproduced this behaviour on two different RTX 3090 cards, so it seems it's not related to a hardware defect. I use the latest diffrax and jax versions at the time of writing, i.e.,
I am using CUDA 11.1 and CUDNN 8.2.1 on Ubuntu 20 LTS.
So I am currently very puzzled about what is happening here. Generally, I could understand floating point accuracy being lower on a consumer grade GPU like the RTX 3090. But then I would expect the number of steps to be the same irrespective of whether I use vmap on the diffeqsolve. The only things I could think of are that
I know that this is a tough issue to reproduce because it seems to be tied to specific hardware. So if anyone reading this is in possession of an RTX 3090 or any other consumer-grade GPU, it would be great if you could see whether you can reproduce the issue with the example below. I'd appreciate any insights into what could cause such behaviour.
See below for the reproduction example: