tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

VonMisesFisher returns NaN with mean_direction=[1., 0., 0.] with lax backend #1211

Open jatentaki opened 3 years ago

jatentaki commented 3 years ago

As in the title: VonMisesFisher samples NaN with mean_direction=[1., 0., 0.] but works with other unit directions. Example:

import tensorflow_probability.substrates.jax as tfp
from jax import random

vmf_good = tfp.distributions.VonMisesFisher(
    mean_direction=[0., 1., 0.],
    concentration=1.,
)

vmf_bad = tfp.distributions.VonMisesFisher(
    mean_direction=[1., 0., 0.],
    concentration=1.,
)

print(vmf_good.sample(sample_shape=2, seed=random.PRNGKey(0)))
print(vmf_bad.sample(sample_shape=2, seed=random.PRNGKey(0)))

>>>2021-01-04 18:39:35.373535: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
/home/jatentaki/Storage/jatentaki/miniconda3/envs/tfp/lib/python3.8/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
[[-0.9372401  -0.17581224 -0.30111626]
 [-0.93092006 -0.36028326 -0.05986603]]
[[nan nan nan]
 [nan nan nan]]

Originally reported in https://github.com/pyro-ppl/numpyro/issues/859

axch commented 3 years ago

Well, looks like the culprit is here: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/distributions/von_mises_fisher.py#L341.

It seems that in TF, l2_normalize([0., 0., 0.]) returns [0., 0., 0.], which is fine for this purpose, but I guess in JAX it returns NaNs instead. We should probably patch that for consistency's sake.
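A minimal sketch of the failure mode described above, assuming the sampler builds its reflection vector u by normalizing the difference between the first basis vector and mean_direction. NumPy stands in for the backend here, and `naive_normalize`, `safe_normalize`, and `reflect` are hypothetical helpers written for illustration, not TFP functions. The guard shown is one possible patch, not necessarily the one the library would adopt.

```python
import numpy as np

def naive_normalize(v):
    # Plain division by the norm: a zero vector gives 0/0 = NaN.
    with np.errstate(invalid="ignore", divide="ignore"):
        return v / np.linalg.norm(v, axis=-1, keepdims=True)

def safe_normalize(v):
    # Hypothetical guard: divide by 1 where the norm is 0, so zeros stay zeros.
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    return v / np.where(norm > 0.0, norm, 1.0)

def reflect(x, mean_direction, normalize):
    # Householder reflection that maps the first basis vector e1 onto mean_direction.
    m = np.asarray(mean_direction, dtype=float)
    e1 = np.zeros_like(m)
    e1[..., 0] = 1.0
    u = normalize(e1 - m)
    return x - 2.0 * np.sum(x * u, axis=-1, keepdims=True) * u

x = np.array([0.6, 0.8, 0.0])  # an arbitrary unit vector
print(reflect(x, [1., 0., 0.], naive_normalize))  # u is NaN, so the result is NaN
print(reflect(x, [1., 0., 0.], safe_normalize))   # u is zero, so x is returned unchanged
```

With the guard, a mean direction equal to e1 makes the reflection a no-op, which is the desired behavior because samples around e1 need no rotation.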

Is this evidence of an accuracy issue, though? If mean_direction is close to [1., 0., 0.], the 0th dimension of the reflection vector u will be computed very inaccurately due to catastrophic cancellation. Will that cause observably bad sampling?
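One rough way to size the cancellation, sketched in 2D with float64 and only the standard library. It assumes u = normalize(e1 - mean_direction) with mean_direction = [cos t, sin t]. `u0_naive` and `u0_exact` are hypothetical helper names. The reference value uses the identity 1 - cos t = 2 sin^2(t/2), under which the 0th component of u reduces to sin(t/2) for t > 0.

```python
import math

def u0_naive(t):
    # 0th component of normalize(e1 - [cos t, sin t]), computed directly.
    a = 1.0 - math.cos(t)  # cancels badly for small t
    b = -math.sin(t)
    return a / math.hypot(a, b)

def u0_exact(t):
    # Same quantity via 1 - cos t = 2 sin^2(t/2); equals sin(t/2) for t > 0.
    return math.sin(t / 2.0)

# Angles from 1e-1 down to 1e-16 rad, four per decade.
angles = [10.0 ** (-k / 4.0) for k in range(4, 65)]
worst = max(abs(u0_naive(t) - u0_exact(t)) for t in angles)
print(f"worst absolute error in u[0]: {worst:.2e}")
```

The relative error in u[0] can reach 100% for tiny angles, since cos t rounds to exactly 1. The absolute error should stay small, because u[0] is itself tiny there. Whether this matters for sampling in float32, or in higher dimensions, is not checked by this sketch.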

ezhang94 commented 3 years ago

I have also encountered this issue. I performed a sweep across mean directions very close to [1., 0.] and did not observe any bad sampling. Again, I receive NaNs when I sample with a mean direction of exactly [1., 0.].

Results

[image: plots generated by the code below]

Code

import jax
jax.config.update("jax_enable_x64", True)

import jax.random as jr
import numpy as np

import tensorflow_probability.substrates.jax.distributions as tfd

import matplotlib.pyplot as plt

n_sweep = 225
n_samples = 100
kappa = 50
x_dir = np.array([1.,0])

# Sweep across mean direction vectors close to [1.,0] for a fixed concentration
rad_sweep = np.linspace(-(1e-64), 1e-64, n_sweep) # angles [rad], shape (n_sweep,)
dir_sweep = np.stack([np.cos(rad_sweep), np.sin(rad_sweep)], axis=-1) # direction vectors, shape (n_sweep,2)

# Sample. Resulting shape: (n_samples, n_sweep, 2)
xs = tfd.VonMisesFisher(dir_sweep, kappa).sample((n_samples,), jr.PRNGKey(50))

# ----------------------------------------------------------------------------
idx = n_sweep // 2 # Index closest to [1.,0] is in middle of sweep range

print('Values of angles and direction vectors closest to [1.,0]')
for rad, dir in zip(rad_sweep[idx-3:idx+3], dir_sweep[idx-3:idx+3]):
    print('\tAngle: {:+.2e}, direction: [{:+.2e},{:+.2e}]'.format(rad, dir[0], dir[1]))
print()

# Plot average angular distance of samples from direction [1.,0]
dtheta = np.arctan(np.abs(np.cross(xs, x_dir)) / np.dot(xs, x_dir))
fig, axs = plt.subplots(1, 2, figsize=(10,3))
axs[0].plot(rad_sweep, dtheta.mean(axis=0))
axs[0].set_ylabel('Mean angular difference [rad]')

axs[1].plot(rad_sweep[idx-5:idx+5], dtheta[:,idx-5:idx+5].mean(axis=0)) # Zoom in

for ax in axs:
    ax.set_xlabel('Angle [rad]')
plt.show()

# Plot samples with mean_dir closest to [1.,0]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10,5))
axs[0].plot(xs[:, idx, 0], xs[:, idx, 1], 'k.')
axs[0].set_title('theta={:.2e}'.format(rad_sweep[idx]))

axs[1].plot(xs[:, idx+1, 0], xs[:, idx+1, 1], 'k.')
axs[1].set_title('theta={:.2e}'.format(rad_sweep[idx+1]))

for ax in axs:
    ax.set_ylim(-1.5, 1.5)
    ax.set_xlim(-1.5, 1.5)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    ax.set_aspect("equal")
plt.show()