Open sschoenholz opened 2 years ago
Thanks - if the gradient issue is with `eigh`, can we reproduce by calling `check_grads` on `transform_to_diagonal_frame`? That would pretty significantly reduce the size of the repro.
Ah, good point, it does! However, the reason I had earlier thought the error didn't show up with `transform_to_diagonal_frame` is that I had been testing on GPU. It seems like this bug might only show up on the CPU backend.
Here's the repro only with `transform_to_diagonal_frame`:
import jax.numpy as jnp
import jax.test_util as jtu
from jax import random
from jax import vmap, jit
from functools import partial
from collections import namedtuple

def moment_of_inertia(points):
    ndim = points.shape[-1]
    I_sphere = 2 / 5
    @vmap
    def per_particle(point):
        diagonal = jnp.linalg.norm(point) ** 2 * jnp.eye(point.shape[-1])
        off_diagonal = point[:, None] * point[None, :]
        return ((diagonal - off_diagonal) + jnp.eye(3) * I_sphere)
    return jnp.sum(per_particle(points), axis=0)

def transform_to_diagonal_frame(shape_points):
    I = moment_of_inertia(shape_points)
    I_diag, U = jnp.linalg.eigh(I)
    shape_points = jnp.einsum('ni,ij->nj', shape_points, U)
    return shape_points

points = jnp.array([[-0.5, -0.5, -0.5],
                    [-0.5, -0.5,  0.5],
                    [ 0.5, -0.5, -0.5],
                    [ 0.5, -0.5,  0.5],
                    [-0.5,  0.5, -0.5],
                    [-0.5,  0.5,  0.5],
                    [ 0.5,  0.5, -0.5],
                    [ 0.5,  0.5,  0.5]])

# If these two lines are swapped then the test passes. The only
# difference between the two is whether or not the moment of inertia
# tensor (that is passed into `eigh`) has eigenvalues that are already ordered
# or whether they are out of order.
points = points * jnp.array([[1.0, 1.1, 1.2]])
# points = points * jnp.array([[1.2, 1.1, 1.0]])

jtu.check_grads(transform_to_diagonal_frame, (points,), 1)
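To make the comment about eigenvalue ordering concrete, here is a rough sketch (not part of the original repro) that rebuilds the inertia tensor in plain NumPy and checks whether its diagonal is already in ascending order for each of the two scalings. The names `cube`, `I_SPHERE`, and `diagonal_is_sorted` are made up for this illustration, and the NumPy version of `moment_of_inertia` is an assumed equivalent of the vmapped one above, not the code that was actually run.

```python
import numpy as np

I_SPHERE = 2 / 5  # same per-particle constant as in the repro

def moment_of_inertia(points):
    # NumPy stand-in for the vmapped version in the repro (assumed equivalent).
    ndim = points.shape[-1]
    r2 = np.sum(points ** 2, axis=-1)                        # |p|^2 per point
    diagonal = r2[:, None, None] * np.eye(ndim)              # |p|^2 * identity
    off_diagonal = points[:, :, None] * points[:, None, :]   # outer products p p^T
    return np.sum(diagonal - off_diagonal + np.eye(ndim) * I_SPHERE, axis=0)

# Corners of the unit cube; the row order differs from the repro, but the
# sum over particles does not depend on it.
cube = np.array([[x, y, z]
                 for x in (-0.5, 0.5)
                 for y in (-0.5, 0.5)
                 for z in (-0.5, 0.5)])

def diagonal_is_sorted(scale):
    inertia = moment_of_inertia(cube * np.array(scale))
    # For this symmetric cube the tensor is already diagonal, so its
    # diagonal entries are the eigenvalues.
    assert np.allclose(inertia, np.diag(np.diag(inertia)))
    diag = np.diag(inertia)
    return bool(np.all(np.diff(diag) > 0)), diag

failing_sorted, failing_diag = diagonal_is_sorted([1.0, 1.1, 1.2])
passing_sorted, passing_diag = diagonal_is_sorted([1.2, 1.1, 1.0])
print("scale [1.0, 1.1, 1.2]: ascending =", failing_sorted, failing_diag)
print("scale [1.2, 1.1, 1.0]: ascending =", passing_sorted, passing_diag)
```

With the `[1.0, 1.1, 1.2]` scaling the diagonal comes out in descending order, so `eigh` (which returns eigenvalues in ascending order) has to permute the axes; with `[1.2, 1.1, 1.0]` it is already ascending. This only illustrates the ordering difference and says nothing about where the gradient goes wrong.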
Hmm, I can't repro on CPU... (with JAX 0.3.13; sorry, I didn't notice that you're using a newer version)
Thanks @sschoenholz for the more concise version! This repros for me on a Colab CPU runtime, as well as on my own macbook.
> Thanks - if the gradient issue is with `eigh`, can we reproduce by calling `check_grads` on `transform_to_diagonal_frame`? That would pretty significantly reduce the size of the repro.
It's likely that in this example the bug is related to `eigh`. I opened a new issue showing that `eigh` gradients can be wrong: https://github.com/google/jax/issues/10877
We've been working on some code to simulate rigid bodies using JAX and came across a gradient bug (in the sense that `jax.check_grads` fails). Unfortunately, it was fairly difficult to minimize the repro so it's a bit long. I'm happy to iterate to try to narrow down the problem if it would be helpful. Thanks very much for any help!
Here's the code, run using the most recent version of JAX (v 0.3.14) both on Colab and Desktop:
This asserts with the error