jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.48k stars 2.8k forks source link

Gradient Bug (Possibly to do with `eigh`). #10844

Open sschoenholz opened 2 years ago

sschoenholz commented 2 years ago

We've been working on some code to simulate rigid bodies using JAX and came across a gradient bug (in the sense that `jax.test_util.check_grads` fails). Unfortunately, it was fairly difficult to minimize the repro, so it's a bit long. I'm happy to iterate to narrow down the problem if that would be helpful.

Thanks very much for any help!

Here's the code, run on both Colab and a desktop machine with the most recent version of JAX (v0.3.14):

import jax.numpy as jnp
import jax.test_util as jtu
from jax import random
from jax import vmap, jit
from functools import partial

from collections import namedtuple

# Pytree container for a batch of rigid bodies. `center` is the translation
# added to each body's points and `orientation` is the (w, x, y, z) quaternion
# used to rotate them (see `transform`).
RigidBody = namedtuple('RigidBody', 'center orientation')

# Alias used only in type annotations below.
Array = jnp.ndarray

@partial(jnp.vectorize, signature='(q),(q)->(q)')
def _quaternion_multiply(lhs: Array, rhs: Array) -> Array:
  """Hamilton product `lhs * rhs` of quaternions stored as (w, x, y, z).

  Wrapped in `jnp.vectorize`, so any leading batch dimensions broadcast and
  the body only ever sees single length-4 quaternions.
  """
  a_w, a_x, a_y, a_z = lhs
  b_w, b_x, b_y, b_z = rhs

  # Each component keeps the exact term order of the reference formula.
  w = -a_x * b_x - a_y * b_y - a_z * b_z + a_w * b_w
  x = a_x * b_w + a_y * b_z - a_z * b_y + a_w * b_x
  y = -a_x * b_z + a_y * b_w + a_z * b_x + a_w * b_y
  z = a_x * b_y - a_y * b_x + a_z * b_w + a_w * b_z
  return jnp.array([w, x, y, z])

@partial(jnp.vectorize, signature='(q)->(q)')
def _quaternion_conjugate(q: Array) -> Array:
  """Conjugate of a (w, x, y, z) quaternion: keep w, negate the vector part.

  The result is built with the same dtype as `q`.
  """
  scalar, i, j, k = q
  vector_part = [-i, -j, -k]
  return jnp.array([scalar, *vector_part], dtype=q.dtype)

@partial(jnp.vectorize, signature='(q),(d)->(d)')
def _quaternion_apply(q: Array, v: Array) -> Array:
  """Rotate the 3-vector `v` by the quaternion `q` via the product q v q*.

  Args:
    q: quaternion stored as (w, x, y, z); core shape (4,).
    v: vector; core shape (3,). Leading dimensions of `q` and `v` broadcast
      against each other through `jnp.vectorize`.

  Returns:
    The (x, y, z) part of q v q*, core shape (3,). `q` is not normalized, so
    this is a pure rotation only for unit `q`; a non-unit `q` additionally
    scales `v` by |q|^2.

  Raises:
    ValueError: if the core shapes are not (4,) for `q` and (3,) for `v`.
  """
  if q.shape != (4,):
    raise ValueError(
        'Quaternion must have shape (4,) ordered (w, x, y, z). '
        'Found {}.'.format(q.shape))
  if v.shape != (3,):
    raise ValueError('Vector must have shape (3,). Found {}.'.format(v.shape))

  # Embed v as the pure quaternion (0, x, y, z) so it can be sandwiched.
  v = jnp.concatenate([jnp.zeros((1,), v.dtype), v])
  q = _quaternion_multiply(q, _quaternion_multiply(v, _quaternion_conjugate(q)))
  # The scalar part of q v q* is (up to rounding) zero; keep the vector part.
  return q[1:]

@partial(vmap, in_axes=(0, None))
def transform(body: RigidBody, points: Array) -> Array:
  """Rotate `points` by each body's orientation and shift by its center.

  Mapped over bodies: every leaf of `body` carries a leading body axis, while
  the same `points` array is shared by all bodies, so the result gains a
  leading body axis.
  """
  center, orientation = body
  rotated = _quaternion_apply(orientation, points)
  return center[None, :] + rotated

def moment_of_inertia(points):
  """Inertia tensor of a collection of equal, sphere-like point particles.

  Each point r contributes |r|^2 * I - r r^T (no mass weighting) plus
  `I_sphere * I` for the particle's own extent; contributions are summed.
  Points are not re-centered, so the tensor is taken about the origin.

  Args:
    points: array of shape (n, d).

  Returns:
    A symmetric (d, d) array. For d == 2 this equals the in-plane block of the
    3-D tensor for points lying in that plane.
  """
  ndim = points.shape[-1]
  # Presumably the 2/5 m r^2 moment of a solid sphere with unit mass and
  # radius -- TODO confirm the intended particle mass/radius.
  I_sphere = 2 / 5
  # Use the same spatial dimension for both terms (the sphere term used to be
  # hard-coded to 3, which failed for any other `d`).
  eye = jnp.eye(ndim)

  @vmap
  def per_particle(point):
    diagonal = jnp.linalg.norm(point) ** 2 * eye
    off_diagonal = point[:, None] * point[None, :]
    return (diagonal - off_diagonal) + eye * I_sphere

  return jnp.sum(per_particle(points), axis=0)

def transform_to_diagonal_frame(shape_points):
  """Express `shape_points` in the eigenbasis of their inertia tensor.

  The columns of U from `jnp.linalg.eigh` are orthonormal eigenvectors of the
  (symmetric) inertia tensor, ordered by ascending eigenvalue. The rows of
  `shape_points @ U` are therefore coordinates in which the inertia tensor
  is diagonal.

  NOTE(review): eigh fixes each eigenvector only up to sign, and U may have
  determinant -1. The output is therefore not unique and may be a mirror
  image of the input. It is also ill-defined when eigenvalues are (nearly)
  degenerate.
  """
  inertia = moment_of_inertia(shape_points)
  # Only the eigenvectors are needed; the eigenvalues are discarded.
  _, U = jnp.linalg.eigh(inertia)
  return jnp.einsum('ni,ij->nj', shape_points, U)

def _diagonal_mask(X: Array) -> Array:
  """Sets the diagonal of a matrix to zero.

  NaNs are first replaced via `jnp.nan_to_num`, which also maps infinities to
  large finite values.

  Args:
    X: array of shape (N, N) or (N, N, C). For rank 3, entries X[i, i, :] are
      zeroed.

  Returns:
    An array of the same shape as `X`, zero wherever the first two indices
    are equal.

  Raises:
    ValueError: if `X` is not rank 2 or 3, or its first two dimensions differ.
  """
  # Check the rank first so rank-0/1 inputs raise the intended ValueError
  # rather than an IndexError from reading `X.shape[1]`.
  if X.ndim not in (2, 3):
    raise ValueError(
        ('Diagonal mask can only mask rank-2 or rank-3 tensors. '
         'Found {}.'.format(X.ndim)))
  if X.shape[0] != X.shape[1]:
    raise ValueError(
        'Diagonal mask can only mask square matrices. Found {}x{}.'.format(
            X.shape[0], X.shape[1]))
  N = X.shape[0]
  # NOTE(schsam): It seems potentially dangerous to set nans to 0 here. However,
  # masking nans also doesn't seem to work. So it also seems necessary. At the
  # very least we should do some @ErrorChecking.
  X = jnp.nan_to_num(X)
  mask = 1.0 - jnp.eye(N, dtype=X.dtype)
  if X.ndim == 3:
    # Broadcast the (N, N) mask over the trailing channel axis.
    mask = jnp.reshape(mask, (N, N, 1))
  return mask * X

def safe_mask(mask, fn, operand, placeholder=0):
  """Apply `fn` to `operand` only where `mask` is true.

  Masked-out entries are replaced by 0 *before* `fn` is called and by
  `placeholder` afterwards. Substituting before the call keeps the gradient
  with respect to `operand` free of NaNs that `fn` might produce there
  (e.g. `jnp.sqrt` at zero or negative inputs). A single
  `jnp.where(mask, fn(operand), placeholder)` would still propagate them.
  Note that `fn` is evaluated at 0 for the masked-out entries.
  """
  safe_operand = jnp.where(mask, operand, 0)
  outputs = fn(safe_operand)
  return jnp.where(mask, outputs, placeholder)

def energy_fn(body, shape_points):
  """Pairwise repulsive energy of all transformed points of all bodies.

  Points from every body are pooled, including points of the same body. Each
  ordered pair (i, j) closer than 1.0 gets 0.5 * (1 - r)^2. Self-pairs are
  zeroed by `_diagonal_mask`, and the final 0.5 compensates for each
  unordered pair appearing twice in the sum.
  """
  positions = transform(body, shape_points)
  positions = jnp.reshape(positions, (-1, positions.shape[-1]))
  displacement = positions[:, None, :] - positions[None, :, :]
  dist_sq = jnp.sum(displacement ** 2, axis=-1)
  # safe_mask keeps the sqrt gradient finite for coincident (r == 0) pairs.
  dist = safe_mask(dist_sq > 0, jnp.sqrt, dist_sq)
  overlap = jnp.where(dist < 1.0, (1.0 - dist) ** 2, 0.0)
  pair_energy = 0.5 * overlap
  return 0.5 * jnp.sum(_diagonal_mask(pair_energy))

def shape_energy_fn(shape_points):
  """Energy of two posed copies of a shape after aligning it to its inertia axes.

  `shape_points` is first expressed in the eigenbasis of its inertia tensor
  by `transform_to_diagonal_frame`. The result is then placed at two
  hard-coded poses and scored with `energy_fn`.
  """
  centers = jnp.array([[0.0, 0.0, 0.0],
                       [0.5, 0.25, 0.15]])
  # NOTE(review): the second quaternion has squared norm 1.01 and
  # `_quaternion_apply` does not normalize. That copy is therefore scaled by
  # 1.01 as well as rotated. Kept unchanged to preserve the repro.
  orientations = jnp.array([[1.0, 0.0, 0.0, 0.0],
                            [1.0, 0.1, 0.0, 0.0]])
  body = RigidBody(centers, orientations)

  aligned = transform_to_diagonal_frame(shape_points)
  return energy_fn(body, aligned)

# Corners of a unit cube centered at the origin, one row per corner.
points = jnp.array([[-0.5, -0.5, -0.5],
                    [-0.5, -0.5,  0.5],
                    [ 0.5, -0.5, -0.5],
                    [ 0.5, -0.5,  0.5],
                    [-0.5,  0.5, -0.5],
                    [-0.5,  0.5,  0.5],
                    [ 0.5,  0.5, -0.5],
                    [ 0.5,  0.5,  0.5]])

# Stretch the cube into a box. With the active scaling below the check fails;
# commenting it out and using the other line instead makes it pass. The only
# difference is whether the (diagonal) moment-of-inertia tensor passed to
# `eigh` has its eigenvalues in descending order along the diagonal (active
# line) or already in ascending order (commented-out line).
points = points * jnp.array([[1.0, 1.1, 1.2]])
# points = points * jnp.array([[1.2, 1.1, 1.0]])

jtu.check_grads(shape_energy_fn, (points,), 1)

This fails with the following assertion error:

Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.5010784
Max relative difference: 1.2183625
 x: array(0.089806, dtype=float32)
 y: array(-0.411272, dtype=float32)
jakevdp commented 2 years ago

Thanks - if the gradient issue is with eigh, can we reproduce by calling check_grads on transform_to_diagonal_frame? That would pretty significantly reduce the size of the repro.

sschoenholz commented 2 years ago

Ah, good point, it does! However, the reason I had earlier thought the error didn't show up with transform_to_diagonal_frame is that I had been testing on GPU. It seems like this bug might only show up on the CPU backend.

Here's the repro only with transform_to_diagonal_frame:

import jax.numpy as jnp
import jax.test_util as jtu
from jax import random
from jax import vmap, jit
from functools import partial

from collections import namedtuple

def moment_of_inertia(points):
  """Inertia tensor of a collection of equal, sphere-like point particles.

  Each point r contributes |r|^2 * I - r r^T (no mass weighting) plus
  `I_sphere * I` for the particle's own extent; contributions are summed.
  Points are not re-centered, so the tensor is taken about the origin.

  Args:
    points: array of shape (n, d).

  Returns:
    A symmetric (d, d) array. For d == 2 this equals the in-plane block of the
    3-D tensor for points lying in that plane.
  """
  ndim = points.shape[-1]
  # Presumably the 2/5 m r^2 moment of a solid sphere with unit mass and
  # radius -- TODO confirm the intended particle mass/radius.
  I_sphere = 2 / 5
  # Use the same spatial dimension for both terms (the sphere term used to be
  # hard-coded to 3, which failed for any other `d`).
  eye = jnp.eye(ndim)

  @vmap
  def per_particle(point):
    diagonal = jnp.linalg.norm(point) ** 2 * eye
    off_diagonal = point[:, None] * point[None, :]
    return (diagonal - off_diagonal) + eye * I_sphere

  return jnp.sum(per_particle(points), axis=0)

def transform_to_diagonal_frame(shape_points):
  """Express `shape_points` in the eigenbasis of their inertia tensor.

  The columns of U from `jnp.linalg.eigh` are orthonormal eigenvectors of the
  (symmetric) inertia tensor, ordered by ascending eigenvalue. The rows of
  `shape_points @ U` are therefore coordinates in which the inertia tensor
  is diagonal.

  NOTE(review): eigh fixes each eigenvector only up to sign, and U may have
  determinant -1. The output is therefore not unique and may be a mirror
  image of the input. It is also ill-defined when eigenvalues are (nearly)
  degenerate.
  """
  inertia = moment_of_inertia(shape_points)
  # Only the eigenvectors are needed; the eigenvalues are discarded.
  _, U = jnp.linalg.eigh(inertia)
  return jnp.einsum('ni,ij->nj', shape_points, U)

# Corners of a unit cube centered at the origin, one row per corner.
points = jnp.array([
    [-0.5, -0.5, -0.5],
    [-0.5, -0.5,  0.5],
    [ 0.5, -0.5, -0.5],
    [ 0.5, -0.5,  0.5],
    [-0.5,  0.5, -0.5],
    [-0.5,  0.5,  0.5],
    [ 0.5,  0.5, -0.5],
    [ 0.5,  0.5,  0.5],
])

# Stretch the cube into a box. The active line is reported to make the check
# fail; using the commented-out line instead makes it pass. The two differ
# only in whether the inertia tensor handed to `eigh` already has its
# eigenvalues in ascending order along the diagonal.
points = points * jnp.array([[1.0, 1.1, 1.2]])  # reported failing case
# points = points * jnp.array([[1.2, 1.1, 1.0]])  # reported passing case

jtu.check_grads(transform_to_diagonal_frame, (points,), 1)
YouJiacheng commented 2 years ago

Hmm, I can't repro on CPU... (with JAX 0.3.13; sorry, I didn't notice that you're using an unreleased version.)

jakevdp commented 2 years ago

Thanks @sschoenholz for the more concise version! This repros for me on a Colab CPU runtime, as well as on my own macbook.

rafael-fuente commented 2 years ago

> Thanks - if the gradient issue is with eigh, can we reproduce by calling check_grads on transform_to_diagonal_frame? That would pretty significantly reduce the size of the repro.

It's likely that in this example the bug is related to eigh. I opened a new issue demonstrating that eigh gradients can be wrong: https://github.com/google/jax/issues/10877