ricardoV94 closed this 1 year ago
Thanks Ricardo!
I dug in a little; just to mention some things that may be obvious:
The `log_prob` is correct; the `prob` is wrong:
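```python
import tensorflow as tf
import tensorflow_probability as tfp
from scipy.integrate import trapezoid
tfd = tfp.distributions
# `scaled_dir` is defined earlier in the thread; integrate both densities over x along x + y = 2.
x = tf.linspace(0., 2., 10_000)
y = 2 - x
pts = tf.concat((x[..., None], y[..., None]), -1)
trapezoid(scaled_dir.prob(pts), x=x), trapezoid(tf.exp(scaled_dir.log_prob(pts)), x=x)
# (0.5, 1.0)
```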
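The check above works because a correctly normalized density on the scaled simplex must integrate to 1 along that segment. Here is a dependency-free sketch of the same check under hypothetical assumptions (the thread's actual `scaled_dir` parameters are not shown here): a 2-component Dirichlet with concentration [2, 3] scaled by c = 2, so x = 2u with u ~ Beta(2, 3) and, by the one-dimensional change of variables, p(x) = Beta(x/2; 2, 3) / 2. With these made-up parameters p(0.2) = 0.486, which happens to match the number printed below. A density carrying one stray factor of 1/2 integrates to 0.5, the same symptom as `prob` above; the sketch does not say where TFP picks up such a factor.

```python
import math


def beta_pdf(u: float, a: float, b: float) -> float:
    """Density of Beta(a, b) at u; zero outside [0, 1]."""
    if u < 0.0 or u > 1.0:
        return 0.0
    norm = math.gamma(a + b) / (math.gamma(a) * math.gamma(b))
    return norm * u ** (a - 1) * (1.0 - u) ** (b - 1)


def scaled_density(x: float, a: float, b: float, c: float) -> float:
    """Density of x = c * u with u ~ Beta(a, b): p(x) = Beta(x / c; a, b) / c."""
    return beta_pdf(x / c, a, b) / c


def trapezoid_rule(f, lo: float, hi: float, n: int = 10_000) -> float:
    """Composite trapezoid rule for f on [lo, hi] using n evenly spaced points."""
    h = (hi - lo) / (n - 1)
    total = 0.5 * (f(lo) + f(hi))
    for i in range(1, n - 1):
        total += f(lo + i * h)
    return total * h


# Hypothetical parameters: concentration [2, 3], scale c = 2.
A, B, C = 2.0, 3.0, 2.0
correct = trapezoid_rule(lambda x: scaled_density(x, A, B, C), 0.0, C)
# The same density with one extra factor of 1 / C integrates to 1 / C instead of 1.
off_by_factor = trapezoid_rule(lambda x: scaled_density(x, A, B, C) / C, 0.0, C)
print(round(correct, 6), round(off_by_factor, 6))  # 1.0 0.5
print(round(scaled_density(0.2, A, B, C), 6))  # 0.486
```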
The absolute hackiest way to fix this as a user is to delete the override:
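```python
import numpy as np
# Hack: removes Dirichlet's `_prob` override for every Dirichlet in this process.
del tfd.Dirichlet._prob
float(scaled_dir.prob([0.2, 1.8])), np.exp(scaled_dir.log_prob([0.2, 1.8]))
# (0.4860000014305115, 0.48600003)
```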
That might be a reasonable quick fix in TFP, too: `Dirichlet._prob` only exists to append a sentence to the docstring.
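Why would deleting a method change the result? The toy below sketches the general Python mechanics. The classes, the `append_note` decorator, and the dispatch rule in `ToyBase.prob` are hypothetical stand-ins chosen for illustration, not TFP's internals: the base class prefers a subclass `_prob` when one exists and otherwise falls back to `exp(_log_prob)`, and the subclass defines `_prob` only to carry an extra docstring note.

```python
import math


def append_note(note):
    """Hypothetical decorator: append `note` to a function's docstring; behavior is unchanged."""
    def decorate(fn):
        fn.__doc__ = (fn.__doc__ or "") + "\n\n" + note
        return fn
    return decorate


class ToyBase:
    def prob(self, x):
        # Assumed dispatch rule (illustration only): use `_prob` if the class has one,
        # otherwise fall back to exp(_log_prob). Returns (value, path_taken).
        if hasattr(self, "_prob"):
            return self._prob(x), "_prob"
        return math.exp(self._log_prob(x)), "exp(_log_prob)"


class ToyDirichlet(ToyBase):
    def _log_prob(self, x):
        # Beta(2, 3) log-density, standing in for a 2-component Dirichlet.
        return math.log(12.0 * x * (1.0 - x) ** 2)

    @append_note("Note: extra documentation appended by the decorator.")
    def _prob(self, x):
        """Probability density."""
        return math.exp(self._log_prob(x))


d = ToyDirichlet()
print(d.prob(0.1)[1])  # _prob
note_present = "extra documentation" in ToyDirichlet._prob.__doc__
del ToyDirichlet._prob  # Class-level deletion: affects every ToyDirichlet instance.
print(d.prob(0.1)[1])  # exp(_log_prob)
```

In this toy both paths return the same number; the point is only which code path runs. For `scaled_dir` in the thread, the `prob` path and the `log_prob` path disagree by a factor of 2 (0.5 vs 1.0 above), which is why removing the override changes the answer.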
I'm trying a fix that just uses `tf.exp(self.log_prob(...))` when there is an injective bijector, but I'm not sure whether that will break a bunch of tests. Otherwise, this might need someone better versed in numerics to look at it.
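A minimal sketch of that idea outside TFP. The `ToyTransformed` class, its constructor arguments, and the Beta(2, 3) base with scale c = 2 are illustration-only assumptions, not TFP's `TransformedDistribution` API. For an injective map, `log_prob` comes from the change-of-variables formula, and `prob` is defined as `exp(log_prob)`, so there is only one numerical path to get wrong.

```python
import math
from typing import Callable


class ToyTransformed:
    """Hypothetical toy: a scalar distribution pushed through an injective map y = f(x)."""

    def __init__(
        self,
        base_log_prob: Callable[[float], float],
        inverse: Callable[[float], float],
        inverse_log_det_jacobian: Callable[[float], float],
    ):
        self.base_log_prob = base_log_prob
        self.inverse = inverse
        self.inverse_log_det_jacobian = inverse_log_det_jacobian

    def log_prob(self, y: float) -> float:
        # Change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|d f^{-1}(y) / dy|.
        return self.base_log_prob(self.inverse(y)) + self.inverse_log_det_jacobian(y)

    def prob(self, y: float) -> float:
        # Single code path: prob is derived from log_prob, so the two cannot disagree.
        return math.exp(self.log_prob(y))


def beta23_log_prob(u: float) -> float:
    # Beta(2, 3) log-density for 0 < u < 1; normalizer Gamma(5) / (Gamma(2) * Gamma(3)) = 12.
    return math.log(12.0) + math.log(u) + 2.0 * math.log1p(-u)


c = 2.0  # hypothetical scale factor: y = c * u, so u = y / c and log|du/dy| = -log(c)
dist = ToyTransformed(beta23_log_prob, lambda y: y / c, lambda y: -math.log(c))

print(round(dist.prob(0.2), 6))  # 0.486
print(all(math.isclose(dist.prob(y), math.exp(dist.log_prob(y))) for y in (0.1, 0.5, 1.0, 1.9)))  # True
```

Because `prob` is derived from `log_prob` here, the final consistency check passes by construction. Run against an implementation that keeps separate `prob` and `log_prob` paths, the same pointwise check is the kind of test that would flag a mismatch like the 0.5 vs 1.0 above.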