jatentaki opened this issue 3 years ago.
Well, looks like the culprit is here: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/distributions/von_mises_fisher.py#L341.
It seems that in TF, `l2_normalize([0., 0., 0.])` returns `[0., 0., 0.]`, which is fine for this purpose, but I guess in JAX it returns NaNs instead. We should probably patch that for consistency's sake.
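For reference, `tf.math.l2_normalize` computes `x * rsqrt(max(sum(x**2), epsilon))` with a default `epsilon` of `1e-12`. An all-zero input therefore comes back as zeros rather than `0/0`. Below is a minimal sketch of that clamped behaviour next to a plain divide-by-norm. It uses NumPy, and the same calls exist in `jax.numpy`. `safe_l2_normalize` and `naive_l2_normalize` are hypothetical helper names, not TFP functions:

```python
import numpy as np

def safe_l2_normalize(x, axis=-1, epsilon=1e-12):
    # Mirrors the clamping in tf.math.l2_normalize: the squared norm is
    # floored at epsilon, so an all-zero vector maps to zeros instead of NaN.
    x = np.asarray(x, dtype=np.float64)
    square_sum = np.sum(x ** 2, axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(square_sum, epsilon))

def naive_l2_normalize(x, axis=-1):
    # Plain division by the norm: 0 / 0 produces NaN for a zero vector.
    x = np.asarray(x, dtype=np.float64)
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

print(safe_l2_normalize([0.0, 0.0, 0.0]))       # [0. 0. 0.]
print(safe_l2_normalize([3.0, 4.0]))            # [0.6 0.8]
with np.errstate(invalid="ignore"):             # silence NumPy's 0/0 RuntimeWarning
    print(naive_l2_normalize([0.0, 0.0, 0.0]))  # [nan nan nan]
```

The clamp only changes results for vectors whose squared norm is below `epsilon`. Normal inputs are unaffected.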
Is this evidence of an accuracy issue, though? If `mean_direction` is close to `[1., 0., 0.]`, the 0th component of the reflection vector `u` will be computed very inaccurately due to catastrophic cancellation. Will that cause observably bad sampling?
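Here is a small NumPy sketch of that cancellation. It assumes, per the linked line, that the sampler draws around `e1 = [1, 0, ...]` and then applies a Householder reflection along `u = normalize(e1 - mean_direction)`. For a unit `mean_direction` `mu`, the identity `1 - mu[0] = sum(mu[1:]**2) / (1 + mu[0])` gives the 0th entry without cancellation. That rewrite is an alternative formulation, not necessarily what TFP does. `householder_reflect` is a hypothetical helper:

```python
import numpy as np

def householder_reflect(x, u):
    # Reflect x across the hyperplane orthogonal to the unit vector u.
    return x - 2.0 * np.dot(x, u) * u

t = 1e-9                               # tiny angle between mu and e1
e1 = np.array([1.0, 0.0])
mu = np.array([np.cos(t), np.sin(t)])  # np.cos(1e-9) rounds to exactly 1.0

# Naive difference: 1 - mu[0] cancels to 0.0, although the true value is ~5e-19.
d_naive = e1 - mu
# Cancellation-free 0th entry for unit mu: 1 - mu0 = sum(mu[1:]**2) / (1 + mu0).
d_stable = np.concatenate([[np.sum(mu[1:] ** 2) / (1.0 + mu[0])], -mu[1:]])

u_naive = d_naive / np.linalg.norm(d_naive)
u_stable = d_stable / np.linalg.norm(d_stable)

# The reflection should map e1 exactly onto mu; measure how far off each version is.
err_naive = np.linalg.norm(householder_reflect(e1, u_naive) - mu)
err_stable = np.linalg.norm(householder_reflect(e1, u_stable) - mu)
print(d_naive[0], d_stable[0])
print(err_naive, err_stable)
```

The naive version ends up off by roughly `t` (here about `1e-9` rad), so it effectively samples around `e1` instead of `mu`. A concentrated 2-D VMF has an angular spread on the order of `1/sqrt(kappa)`, so an error that size should not be observable. That is a back-of-the-envelope argument rather than a proof.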
I have also encountered this issue. I performed a sweep across direction vectors very close to `[1., 0.]` and did not observe any bad sampling for directions very close to the unit x-direction. Again, I get NaNs when I try to sample with a mean direction of exactly `[1., 0.]`.
```python
import jax
jax.config.update("jax_enable_x64", True)  # float64, so the 1e-64 sine components do not underflow
import jax.random as jr
import numpy as np
import tensorflow_probability.substrates.jax.distributions as tfd
import matplotlib.pyplot as plt

n_sweep = 225
n_samples = 100
kappa = 50
x_dir = np.array([1., 0.])

# Sweep across mean direction vectors close to [1., 0.] for a fixed concentration
rad_sweep = np.linspace(-1e-64, 1e-64, n_sweep)  # angles [rad], shape (n_sweep,)
dir_sweep = np.stack([np.cos(rad_sweep), np.sin(rad_sweep)], axis=-1)  # direction vectors, shape (n_sweep, 2)

# Sample. Resulting shape: (n_samples, n_sweep, 2)
xs = tfd.VonMisesFisher(dir_sweep, kappa).sample((n_samples,), seed=jr.PRNGKey(50))
xs = np.asarray(xs)
# ----------------------------------------------------------------------------
idx = n_sweep // 2  # Index closest to [1., 0.] is in the middle of the sweep range
print('Values of angles and direction vectors closest to [1., 0.]')
for rad, d in zip(rad_sweep[idx-3:idx+3], dir_sweep[idx-3:idx+3]):
    print('\tAngle: {:+.2e}, direction: [{:+.2e}, {:+.2e}]'.format(rad, d[0], d[1]))
print()

# Plot average angular distance of samples from direction [1., 0.].
# The 2-D cross product is written out (np.cross on 2-vectors is deprecated in NumPy 2.0),
# and arctan2 stays correct even when the dot product is negative.
cross = xs[..., 0] * x_dir[1] - xs[..., 1] * x_dir[0]
dtheta = np.arctan2(np.abs(cross), xs @ x_dir)  # shape (n_samples, n_sweep)
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
axs[0].plot(rad_sweep, dtheta.mean(axis=0))
axs[0].set_ylabel('Mean angular difference [rad]')
axs[1].plot(rad_sweep[idx-5:idx+5], dtheta[:, idx-5:idx+5].mean(axis=0))  # Zoom in
for ax in axs:
    ax.set_xlabel('Angle [rad]')
plt.show()

# Plot samples with mean_dir closest to [1., 0.]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5))
axs[0].plot(xs[:, idx, 0], xs[:, idx, 1], 'k.')
axs[0].set_title('theta={:.2e}'.format(rad_sweep[idx]))
axs[1].plot(xs[:, idx+1, 0], xs[:, idx+1, 1], 'k.')
axs[1].set_title('theta={:.2e}'.format(rad_sweep[idx+1]))
for ax in axs:
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    ax.set_aspect("equal")
plt.show()
```
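Since only the exact `[1., 0.]` case fails in the sweep, one way to handle it is a guard in the reflection step. Below is a minimal batched sketch. It assumes the sampler reflects samples drawn around `e1 = [1, 0, ...]` onto `mean_direction` along `u = normalize(e1 - mean_direction)`. `reflect_to_mean` is a hypothetical helper, written with NumPy in place of `jax.numpy`. Where `e1 - mean_direction` has zero norm, the reflection vector is set to zero, which turns the reflection into the identity. The effect is the same as TF's zero-returning `l2_normalize`:

```python
import numpy as np

def reflect_to_mean(samples, mean_direction):
    """Map samples concentrated around e1 onto mean_direction (batched).

    samples:        (..., d) unit vectors drawn around e1 = [1, 0, ..., 0]
    mean_direction: (..., d) unit vectors, broadcastable against samples
    """
    e1 = np.zeros(mean_direction.shape[-1])
    e1[0] = 1.0
    diff = e1 - mean_direction
    norm = np.linalg.norm(diff, axis=-1, keepdims=True)
    # Guard: where diff is exactly zero, use u = 0 (identity reflection)
    # instead of 0/0 = NaN. Dividing by a substituted 1.0 keeps the
    # unselected branch NaN-free too.
    safe_norm = np.where(norm > 0, norm, 1.0)
    u = np.where(norm > 0, diff / safe_norm, 0.0)
    proj = np.sum(samples * u, axis=-1, keepdims=True)
    return samples - 2.0 * proj * u

# Hypothetical batch: the middle mean direction is exactly e1.
angles = np.array([-0.3, 0.0, 0.3])
means = np.stack([np.cos(angles), np.sin(angles)], axis=-1)  # (3, 2)
around_e1 = np.array([[np.cos(0.1), np.sin(0.1)]])            # (1, 2) sample near e1
out = reflect_to_mean(around_e1, means)                      # (3, 2)
print(np.isfinite(out).all())
print(out[1])  # unchanged sample, since mean_direction == e1
```

Under `jax.numpy`, the same "where, then where again" pattern is the usual way to keep gradients free of NaNs as well as forward values.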
As in the title: VMF samples are `nan` with `mean_direction=[1., 0., 0.]` but sampling works with other unit directions. Originally reported in https://github.com/pyro-ppl/numpyro/issues/859