I'm not able to reproduce this -- can you make sure you're using the latest jax and jaxlib versions? Also, are you using CPU, GPU, or TPU? (You can determine this via jax.devices().)
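For reference, versions and the backend in use can be checked with something like the following minimal sketch (the exact device repr varies by install):

import jax
from jaxlib import version as jaxlib_version

# Report installed versions and available devices; output varies by install,
# e.g. jax.devices() returns [CpuDevice(id=0)] on a CPU-only machine.
print(jax.__version__)
print(jaxlib_version.__version__)
print(jax.devices())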
Sorry, forgot to mention those details. We are running on CPU and originally this was for jax==0.1.67 and jaxlib==0.1.47. I just did a fresh install in a new environment and still observe the same problem. This is the conda env export:
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- blas=1.0=mkl
- ca-certificates=2020.1.1=0
- certifi=2020.4.5.1=py38_0
- intel-openmp=2020.1=217
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20181209=hc058e9b_0
- libffi=3.3=he6710b0_1
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- mkl=2020.1=217
- mkl-service=2.3.0=py38he904b0f_0
- mkl_fft=1.0.15=py38ha843d7b_0
- mkl_random=1.1.1=py38h0573a6f_0
- ncurses=6.2=he6710b0_1
- numpy=1.18.1=py38h4f9e942_0
- numpy-base=1.18.1=py38hde5b4d6_1
- openssl=1.1.1g=h7b6447c_0
- pip=20.0.2=py38_3
- python=3.8.3=hcff3b4d_0
- readline=8.0=h7b6447c_0
- scipy=1.4.1=py38h0b6359f_0
- setuptools=47.1.1=py38_0
- six=1.15.0=py_0
- sqlite=3.31.1=h62c20be_1
- tk=8.6.8=hbc83047_0
- wheel=0.34.2=py38_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- pip:
- absl-py==0.9.0
- jax==0.1.69
- jaxlib==0.1.47
- opt-einsum==3.2.1
I have also confirmed this on several different sets of hardware
Thanks, I was able to repro at jax versions 0.1.67 and 0.1.69 (the latest), but this appears to be fixed at head. Lemme spin up a new jax release (or you can pip install directly from github for immediate gratification!).
Okay, that's great news! Any idea what the problem might have been, though (or rather, which change might have fixed it)?
Please try jax 0.1.70, hot off the press. I'm not sure which change fixed it, unfortunately. You can look at all the new changes here: https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70
I'm gonna close this issue, but please reopen or comment if it's not fixed for you.
Was just able to verify that the problem is resolved in the new version for me, too! I did a bit of digging and it was the commit relating to PR #2260 that resolved it (and I guess the problem was related to issue #939 which was also mentioned in that PR). Anyway, all is good now, thanks again!
I'm having the same issue on jax version 0.1.76, jaxlib version 0.1.55. grad(f) evaluates just fine, but grad(jit(f))(x) returns NaN.
I'll try to post a minimal reproducible code sample, but it's a bit tricky since f is a neural network loss and the issue only manifests for particular parameter values.
Edit: jit(grad(f))(x) returns NaN as well (as you'd expect, I guess).
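For concreteness, the comparison being described looks roughly like the sketch below, with a stand-in loss since the actual f is not shared (the loss itself is purely illustrative):

import jax.numpy as jnp
from jax import grad, jit

# Stand-in for the (unshared) neural-network loss; purely illustrative.
def f(x):
    return jnp.sum(jnp.log1p(jnp.exp(x)))

x = jnp.array([0.1, -2.0, 3.0])

# The three variants compared in the report; with this toy loss all three agree,
# whereas in the reported bug only the jitted versions produced NaN.
print(grad(f)(x))
print(grad(jit(f))(x))
print(jit(grad(f))(x))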
@langosco thanks for letting us know; please open a new issue when you find a way to repro!
I get nan from SVI; the code is:
def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
    """Initialize the SVI state

    Args:
        X: input data
        lr: learning rate
        kwargs: other keyword arguments for optimizer
    """
    self.optim = self.optim_builder(lr, **kwargs)
    self.svi = SVI(self.model, self.guide, self.optim, self.loss)
    svi_state = self.svi.init(self.rng_key, X)
    if self.svi_state is None:
        self.svi_state = svi_state
    return self

def _fit(self, X: DeviceArray, n_epochs) -> float:
    @jit
    def train_epochs(svi_state, n_epochs):
        def train_one_epoch(_, val):
            loss, svi_state = val
            svi_state, loss = self.svi.update(svi_state, X)
            return loss, svi_state
        return lax.fori_loop(0, n_epochs, train_one_epoch, (0., svi_state))

    loss, self.svi_state = train_epochs(self.svi_state, n_epochs)
    return float(loss / X.shape[0])
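One way to narrow down where the first nan appears (not something suggested in the thread, just a sketch assuming jax's nan-checking flag is available in your version) is to enable nan debugging before running the failing fit:

from jax.config import config

# When enabled, jax re-runs jitted computations op-by-op once a nan is detected
# and raises at the first primitive that produced it, which helps localize the source.
config.update("jax_debug_nans", True)

# ...then run the failing training call as usual, e.g. the _fit call from the snippet above.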
We have recently encountered a strange bug which manifests itself in nan values resulting from the computation - specifically the gradient computation - under specific circumstances which I'll try to explain below.

General Setting

Our general setting is that we are trying to fit a probabilistic mixture model using stochastic variational inference (via the numpyro framework, which builds on jax), but we found our problem to be unrelated to any specifics of that. The following code snippet is simplified to the point where it does not depend on numpyro and performs only a single log-likelihood evaluation, yet it still exhibits the issue.
Code to reproduce
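The actual snippet is not reproduced in this excerpt. The following is only a minimal sketch of the kind of computation described, i.e. per-component log_prob values combined into a log_phis vector, reduced with logsumexp, and differentiated with and without jit; every name, shape, and distribution here is assumed rather than taken from the original code.

import jax.numpy as jnp
from jax import grad, jit
from jax.scipy.special import logsumexp

k = 3  # number of mixture components (assumed)

def log_likelihood(params, x):
    locs, log_scales, log_weights = params                  # each of shape (k,)
    # Per-component Gaussian log_prob of a scalar observation x, shape (k,)
    log_prob = -0.5 * (((x - locs) * jnp.exp(-log_scales)) ** 2
                       + 2.0 * log_scales + jnp.log(2.0 * jnp.pi))
    log_phis = log_prob + log_weights                        # a vector of length k
    temp = log_phis                      # the variant reported to lead to nans (with jit plus arithmetic)
    # temp = log_phis[0:len(log_phis)]   # the variant reported to be fine, although equivalent
    return logsumexp(temp)

params = (jnp.array([-1.0, 0.0, 1.0]), jnp.zeros(k), jnp.log(jnp.ones(k) / k))

print(grad(log_likelihood)(params, 0.5))       # gradient without jit
print(jit(grad(log_likelihood))(params, 0.5))  # gradient under jit, where the nans were observed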
Our Observations

As already mentioned, we have noticed that under certain conditions the gradients for this contain nan values. Most notably, this only occurs when jit is applied. There are further comments in the code that indicate for which variations we observe nan values and for which we do not. These are presented in blocks of subsequent lines which are slight variations of each other, and the comments mean the following: if in all blocks a variant annotated to produce nans is active, we will observe nan values in the gradient. If in any block a variant annotated to be fine is active, there will be no nan values, regardless of the other blocks.

A summary of the observations
- With temp = log_phis, we see no problem, but any arithmetic done with temp before the logsumexp will cause nan values.
- Using temp = log_phis[0:len(log_phis)] instead, everything is fine. log_phis is a vector, so these two variants should be equivalent.
- How the log_prob values are summed up to get log_phis affects whether we observe nan values or not.
- Small values of k, especially k=1, are fine. Some larger values are also fine (sometimes).
- nan values only occur when jit is applied.
- When there are nan values, these only occur for the gradients.
- When nan values occur, they only occur for some batch instances, but then always affect all sites/values of the gradient.

Note that we need per-example gradients and that is where we first noticed the problem, but the problem also occurs with regular batch gradients. We left both variants in the code above; a sketch of the two gradient variants is shown below.
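As a hedged illustration of the two gradient variants mentioned (the toy loss and all names below are assumed, not taken from the original code):

import jax.numpy as jnp
from jax import grad, jit, vmap

# Toy per-example loss, purely for illustrating per-example vs. batch gradients.
def loss(params, x):
    return jnp.sum((params - x) ** 2)

params = jnp.zeros(3)
xs = jnp.arange(12.0).reshape(4, 3)   # a batch of 4 examples

# Per-example gradients: one gradient per batch row, result of shape (4, 3).
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0)))(params, xs)

# Regular batch gradient: differentiate the summed loss, result of shape (3,).
batch_grad = jit(grad(lambda p, xs: jnp.sum(vmap(loss, in_axes=(None, 0))(p, xs))))(params, xs)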
We are somewhat baffled by this behaviour as it seems quite odd indeed. From our observations, we suspect it might be related to shapes/broadcasting or logsumexp (probably the combination of both), but we have no convincing evidence for that. We thus wanted to ask whether anything like this has been observed before, or whether you - as the experts on jax - could shed some light on what might be going wrong here.