jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Random nan values from value_and_grad under jit #3335

Closed lumip closed 4 years ago

lumip commented 4 years ago

We recently encountered a strange bug in which nan values appear in the results of our computation, specifically in the gradients, under certain circumstances that I'll try to explain below.

General Setting

Our general setting is that we are trying to fit a probabilistic mixture model using stochastic variational inference (with the numpyro framework, which builds on jax), but we found the problem to be unrelated to any specifics of that setup. The snippet below has been simplified to the point where it does not depend on numpyro and performs only a single log-likelihood evaluation, yet it still exhibits the issue.

Code to reproduce

from jax.config import config
config.update("jax_enable_x64", True) # doesn't seem to have any effect

import numpy as onp

import jax
import jax.numpy as np
from jax.scipy.special import logsumexp

k = 5 # nans not for all values (not for 1 and 2, but e.g. for 4, 5, 7)

def normal_loglik(x, loc, log_scale):
    norm_const = -.5 * np.log(2*np.pi)
    loglik = norm_const - log_scale - .5*((x - loc)/np.exp(log_scale))**2
    return loglik

def bernoulli_loglik(x, logits):
    loglik = np.clip(logits, 0) + np.log1p(np.exp(-np.abs(logits))) - logits * x
    return loglik

def model(params, value):
    # we sometimes need per-example gradients, which use vmap,
    # thus need to check if we get a batch or a single instance
    if (value.ndim == 2):
        a_value = value[:, 0, np.newaxis]
        b_value = value[:, 1, np.newaxis]
        c_value = value[:, 2, np.newaxis]
    else:
        a_value = value[0, np.newaxis]
        b_value = value[1, np.newaxis]
        c_value = value[2, np.newaxis]

    a_log_prob = normal_loglik(a_value, params['a_loc'], params['a_scale'])
    b_log_prob = bernoulli_loglik(b_value, params['b_logits'])
    c_log_prob = normal_loglik(c_value, params['c_loc'], params['c_scale'])

    log_phis = a_log_prob + b_log_prob + c_log_prob # this gives nans
    # log_phis = c_log_prob + b_log_prob + a_log_prob # this also gives nans
    # log_phis = a_log_prob + c_log_prob + b_log_prob # this is fine
    # log_phis = c_log_prob + a_log_prob + b_log_prob # this is fine as well
    # any pair of only two is fine

    # temp = log_phis # this is fine
    temp = log_phis - 1. # this gives nans (in practice, we add another log probability here and get nans)
    # temp = log_phis[0:len(log_phis)] - 1. # this is fine again

    #mix_mod_log_prob = np.sum(temp, axis=-1) #this is fine
    mix_mod_log_prob = logsumexp(temp, axis=-1) # this is what we need but gives nans
    return np.mean(mix_mod_log_prob)

## some random data
onp.random.seed(0)
num_data = 300
data_a = onp.random.randn(num_data) * 10.
data_b = onp.random.binomial(1, .5, size=num_data)
data_c = onp.random.randn(num_data) * 10.
data = onp.stack((data_a, data_b, data_c)).T

#######################

params = {
    'a_loc': np.zeros(k),
    'a_scale': np.zeros(k),

    'b_logits': np.zeros(k),

    'c_loc': np.zeros(k),
    'c_scale': np.zeros(k),
}

# currently only computes gradients for whole batch, no real update
@jax.jit # no jit -> no nans
def naive_update_batch(params):
    batch = data
    loss, grads = jax.value_and_grad(model)(params, batch)
    return loss, grads

def per_example_value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
    value_and_grad_fun = jax.value_and_grad(fun, argnums, has_aux, holomorphic)
    return jax.vmap(value_and_grad_fun, in_axes=(None, 0))

# currently only computes gradients per instance in the batch, no real update
@jax.jit # no jit -> no nans
def naive_update_single_example(params):
    px_loss, px_grads = per_example_value_and_grad(model)(params, data)
    return px_loss, px_grads

# both of these are affected
# px_loss, px_grads = naive_update_batch(params)
px_loss, px_grads = naive_update_single_example(params)

# print({k: np.unique(np.where(np.isnan(v))[0]) for k,v in px_grads.items()})
has_nans = np.any([np.any(np.isnan(v)) for v in px_grads.values()])
assert(not has_nans)

Our Observations

As already mentioned, under certain conditions the gradients contain nan values. Most notably, this only occurs when jit is applied. Further comments in the code indicate which variations produce nan values and which do not. They are grouped into blocks of consecutive lines that are slight variations of each other, and the comments mean the following: if, in every block, a variant annotated as producing nans is active, we observe nan values in the gradient. If, in any block, a variant annotated as fine is active, there are no nan values, regardless of the other blocks.
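The jit-only symptom can be checked mechanically. The following is a minimal sketch, assuming a current JAX release, that evaluates the same value_and_grad function once op-by-op and once under jax.jit and compares the results leaf by leaf. The names `loss`, `tree_has_nans` and `compare_eager_and_jit` are hypothetical, and `loss` is a simplified one-feature Gaussian mixture rather than the exact model above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def loss(params, x):
    # Hypothetical simplified model: k-component Gaussian log-mixture over a 1-D batch x.
    z = (x[:, None] - params["loc"]) / jnp.exp(params["log_scale"])
    loglik = -0.5 * jnp.log(2 * jnp.pi) - params["log_scale"] - 0.5 * z**2
    return jnp.mean(logsumexp(loglik - 1.0, axis=-1))


def tree_has_nans(tree):
    # True if any leaf of the pytree (e.g. a dict of gradients) contains a NaN.
    return any(bool(jnp.any(jnp.isnan(leaf))) for leaf in jax.tree_util.tree_leaves(tree))


def compare_eager_and_jit(fun, *args):
    # Evaluate value_and_grad once op-by-op and once compiled, then compare.
    grad_fn = jax.value_and_grad(fun)
    eager_val, eager_grads = grad_fn(*args)
    jit_val, jit_grads = jax.jit(grad_fn)(*args)
    leaf_diffs = [
        float(jnp.max(jnp.abs(a - b)))
        for a, b in zip(jax.tree_util.tree_leaves(eager_grads),
                        jax.tree_util.tree_leaves(jit_grads))
    ]
    return {
        "eager_nans": tree_has_nans(eager_grads),
        "jit_nans": tree_has_nans(jit_grads),
        "value_diff": float(jnp.abs(eager_val - jit_val)),
        "max_grad_diff": max(leaf_diffs),
    }


k = 5
params = {"loc": jnp.zeros(k), "log_scale": jnp.zeros(k)}
x = jax.random.normal(jax.random.PRNGKey(0), (300,))
print(compare_eager_and_jit(loss, params, x))
```

A report where `jit_nans` is true while `eager_nans` is false is the pattern described in this issue; small nonzero differences are expected because compiled code may reorder floating-point operations.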

A summary of the observations

Note that we need per-example gradients, which is where we first noticed the problem, but it also occurs with regular batch gradients. We left both variants in the code above.
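Because the model returns a mean over examples, the two variants can be cross-checked: the batch gradient should equal the mean of the per-example gradients produced by the vmap pattern used in per_example_value_and_grad. This is a minimal sketch assuming a hypothetical one-feature mixture loss, not the model above; `in_axes=(None, 0)` shares the parameters and maps over the leading data axis, as in the report.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def loss(params, x):
    # Hypothetical mixture log-likelihood; accepts a 1-D batch or a single scalar example.
    x = jnp.atleast_1d(x)
    z = (x[:, None] - params["loc"]) / jnp.exp(params["log_scale"])
    loglik = -0.5 * jnp.log(2 * jnp.pi) - params["log_scale"] - 0.5 * z**2
    return jnp.mean(logsumexp(loglik, axis=-1))


def per_example_value_and_grad(fun):
    # Parameters are broadcast (None); the data argument is mapped over axis 0.
    return jax.vmap(jax.value_and_grad(fun), in_axes=(None, 0))


k = 5
params = {"loc": jnp.linspace(-1.0, 1.0, k), "log_scale": jnp.zeros(k)}
x = jax.random.normal(jax.random.PRNGKey(0), (300,))

batch_loss, batch_grads = jax.value_and_grad(loss)(params, x)
px_loss, px_grads = per_example_value_and_grad(loss)(params, x)

# Since the batch loss is a mean, its gradient is the mean of the per-example gradients.
for name in params:
    assert px_grads[name].shape == (300, k)
    assert jnp.allclose(batch_grads[name], px_grads[name].mean(axis=0), atol=1e-5)
```

If one variant produces nans and the other does not on the same inputs, that narrows the problem to the transformation rather than the model.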

We are somewhat baffled by this behaviour. From our observations, we suspect it might be related to shapes/broadcasting or to logsumexp (probably a combination of both), but we have no convincing evidence for that. We therefore wanted to ask whether anything like this has been observed before, or whether you, as the jax experts, could shed some light on what might be going wrong here.
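As general background on why logsumexp is a common suspect for nan gradients: a naive `log(sum(exp(x)))` underflows to `log(0)` for large negative inputs, and its backward pass then multiplies an infinite derivative by zero. The sketch below contrasts that with the usual max-shift formulation, whose gradient is the softmax of the inputs. The function names are hypothetical; library code should keep using `jax.scipy.special.logsumexp`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def naive_logsumexp(x):
    # Direct formula: exp underflows to 0 for very negative x, giving log(0) = -inf.
    return jnp.log(jnp.sum(jnp.exp(x)))


def stable_logsumexp(x):
    # Shift by the max so the largest term is exp(0) = 1. The shift is wrapped in
    # stop_gradient because the result is mathematically independent of it.
    m = jax.lax.stop_gradient(jnp.max(x))
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))


x = jnp.array([-1000.0, -1001.0, -1002.0])
print(naive_logsumexp(x))             # -inf
print(stable_logsumexp(x))            # approx. -999.59, matches logsumexp(x)
print(jax.grad(naive_logsumexp)(x))   # [nan nan nan]
print(jax.grad(stable_logsumexp)(x))  # equals jax.nn.softmax(x)
```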

skye commented 4 years ago

I'm not able to reproduce this -- can you make sure you're using the latest jax and jaxlib versions? Also are you using CPU, GPU, or TPU? (You can determine this via jax.devices())
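A small helper can gather the requested details in one place. This is a sketch assuming a recent JAX release and a pip-installed jaxlib (its version is read from package metadata via the standard library); `environment_report` is a hypothetical name.

```python
import importlib.metadata

import jax


def environment_report():
    # Versions and backend details that are useful in a bug report.
    return {
        "jax": jax.__version__,
        "jaxlib": importlib.metadata.version("jaxlib"),
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in jax.devices()],
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```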

lumip commented 4 years ago

Sorry, I forgot to mention those details. We are running on CPU, and originally this was with

jax==0.1.67
jaxlib==0.1.47

I just did a fresh install in a new environment and still observe the same problem. This is the conda env export:

channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - blas=1.0=mkl
  - ca-certificates=2020.1.1=0
  - certifi=2020.4.5.1=py38_0
  - intel-openmp=2020.1=217
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.3=he6710b0_1
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - mkl=2020.1=217
  - mkl-service=2.3.0=py38he904b0f_0
  - mkl_fft=1.0.15=py38ha843d7b_0
  - mkl_random=1.1.1=py38h0573a6f_0
  - ncurses=6.2=he6710b0_1
  - numpy=1.18.1=py38h4f9e942_0
  - numpy-base=1.18.1=py38hde5b4d6_1
  - openssl=1.1.1g=h7b6447c_0
  - pip=20.0.2=py38_3
  - python=3.8.3=hcff3b4d_0
  - readline=8.0=h7b6447c_0
  - scipy=1.4.1=py38h0b6359f_0
  - setuptools=47.1.1=py38_0
  - six=1.15.0=py_0
  - sqlite=3.31.1=h62c20be_1
  - tk=8.6.8=hbc83047_0
  - wheel=0.34.2=py38_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - pip:
    - absl-py==0.9.0
    - jax==0.1.69
    - jaxlib==0.1.47
    - opt-einsum==3.2.1

I have also confirmed this on several different hardware setups.

skye commented 4 years ago

Thanks, I was able to repro at jax versions 0.1.67 and 0.1.69 (the latest), but this appears to be fixed at head. Lemme spin up a new jax release (or you can pip install directly from github for immediate gratification!).

lumip commented 4 years ago

Okay, that's great news! Any idea what the problem might have been, though (or rather, which change might have fixed it)?

skye commented 4 years ago

Please try jax 0.1.70, hot off the press. Unfortunately I'm not sure which change fixed it. You can look at all the new changes here: https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70

skye commented 4 years ago

I'm gonna close this issue, but please reopen or comment if it's not fixed for you.

lumip commented 4 years ago

I was just able to verify that the problem is resolved in the new version for me, too! I did a bit of digging, and it was the commit for PR #2260 that resolved it (I guess the problem was related to issue #939, which was also mentioned in that PR). Anyway, all is good now, thanks again!

langosco commented 4 years ago

I'm having the same issue on jax version 0.1.76, jaxlib version 0.1.55. grad(f)(x) evaluates just fine, but grad(jit(f))(x) returns NaN.

I'll try to post a minimal reproducible code sample, but it's a bit tricky since f is a neural network loss and the issue only manifests for particular parameter values.

Edit: jit(grad(f))(x) returns NaN as well (as you'd expect, I guess).
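When a NaN only shows up for particular parameter values, JAX's `jax_debug_nans` flag can help locate the primitive that first produces it: with the flag on, a NaN output raises `FloatingPointError` instead of propagating silently. The sketch below uses a hypothetical `f` whose forward pass is finite but whose gradient at 0 is NaN (the untaken `log` branch of `jnp.where` still contributes 0/0 to the backward pass); `find_nan_source` is also a hypothetical helper, and the flag is switched off again because it slows execution.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Finite forward pass, but the backward pass through log computes 0 / 0 at x == 0.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def find_nan_source(fun, x):
    """Return the FloatingPointError message raised under jax_debug_nans, or None."""
    jax.config.update("jax_debug_nans", True)
    try:
        jax.jit(jax.grad(fun))(x)
        return None
    except FloatingPointError as err:
        return str(err)
    finally:
        jax.config.update("jax_debug_nans", False)


print(jax.grad(f)(0.0))         # nan, with no error by default
print(find_nan_source(f, 0.0))  # error message pointing at the offending operation
```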

mattjj commented 4 years ago

@langosco thanks for letting us know; please open a new issue when you find a way to repro!

eromoe commented 3 years ago

I get nan from SVI; the code is:

    def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
        """Initialize the SVI state

        Args:
            X: input data
            lr: learning rate
            kwargs: other keyword arguments for optimizer
        """
        self.optim = self.optim_builder(lr, **kwargs)
        self.svi = SVI(self.model, self.guide, self.optim, self.loss)
        svi_state = self.svi.init(self.rng_key, X)
        if self.svi_state is None:
            self.svi_state = svi_state
        return self

    def _fit(self, X: DeviceArray, n_epochs) -> float:
        @jit
        def train_epochs(svi_state, n_epochs):
            def train_one_epoch(_, val):
                loss, svi_state = val
                svi_state, loss = self.svi.update(svi_state, X)
                return loss, svi_state

            return lax.fori_loop(0, n_epochs, train_one_epoch, (0., svi_state))

        loss, self.svi_state = train_epochs(self.svi_state, n_epochs)
        return float(loss / X.shape[0])
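With a `lax.fori_loop` like the one above, only the last step's loss comes back, so it is hard to tell when a NaN first appeared. A minimal sketch, assuming plain gradient descent on a hypothetical least-squares loss rather than numpyro's SVI, carries a flag that records whether any step produced a non-finite loss. Initializing the loss slot with an array of the same dtype as the loss the body returns keeps the carry types consistent across iterations, which `fori_loop` requires.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss_fn(w, x, y):
    # Hypothetical least-squares objective standing in for the SVI loss.
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train(w, x, y, lr, n_steps):
    grad_fn = jax.value_and_grad(loss_fn)

    def step(_, carry):
        w, _, saw_nonfinite = carry
        loss, g = grad_fn(w, x, y)
        # Remember whether any step produced an inf/nan loss, not just the last one.
        saw_nonfinite = saw_nonfinite | ~jnp.isfinite(loss)
        return w - lr * g, loss, saw_nonfinite

    init = (w, jnp.zeros((), dtype=w.dtype), jnp.array(False))
    return lax.fori_loop(0, n_steps, step, init)


x = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
w, last_loss, saw_nonfinite = train(jnp.zeros(3), x, y, 0.1, 200)
print(w, last_loss, saw_nonfinite)
```

Running the same loop with a much larger learning rate makes it diverge, and the flag then reports the overflow even if the final loss alone would not say when it started.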