kmdalton opened 4 years ago
It looks like there is a rejection sampling approach from 1979, but I haven't read it closely to make sure it's the same/compatible pdf: https://www.researchgate.net/profile/Pandu_Tadikamalla/publication/227309949_Random_sampling_from_the_generalized_gamma_distribution/links/55de04bc08aeaa26af0f20a2.pdf
This could be a nice "first contribution"-sized project and a good way to get some exposure to TFP's rejection sampling lib.
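For context, the basic accept/reject loop looks roughly like this (a generic NumPy sketch, not the Tadikamalla scheme from the paper above and not TFP's batched utilities):
import numpy as np

def rejection_sample(target_pdf, propose, proposal_pdf, m, n_samples, rng=None):
    # Requires target_pdf(x) <= m * proposal_pdf(x) for all x.
    rng = np.random.default_rng() if rng is None else rng
    out = []
    while len(out) < n_samples:
        x = propose(rng)  # draw a candidate from the proposal
        # Accept with probability target_pdf(x) / (m * proposal_pdf(x)).
        if rng.uniform() * m * proposal_pdf(x) <= target_pdf(x):
            out.append(x)
    return np.array(out)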
Sure, done. In terms of sequencing, I'd suggest first getting the rejection sampler written and a PR committed, then building a Distribution class around that.
You could look at some example rejection samplers, e.g. https://cs.opensource.google/tensorflow/probability/+/master:tensorflow_probability/python/distributions/poisson.py;l=351 https://cs.opensource.google/tensorflow/probability/+/master:tensorflow_probability/python/distributions/gamma.py;l=654
The corresponding test files will likely use the statistical_testing library to verify the empirical CDF of a number of samples against the true CDF.
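As a rough stand-in for what such a test does (this uses a plain Kolmogorov-Smirnov test from scipy, not the actual statistical_testing API, which is based on DKWM-style bounds):
import numpy as np
from scipy import stats

# Hypothetical sampler under test: here just NumPy's gamma sampler.
samples = np.random.default_rng(0).gamma(shape=2., scale=1., size=10_000)
# Compare the empirical CDF of the samples against the analytic CDF.
statistic, p_value = stats.kstest(samples, stats.gamma(a=2.).cdf)
assert p_value > 1e-3, 'empirical CDF deviates from the true CDF'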
I found some info on the Amoroso distribution, which adds a fourth parameter to the generalized gamma parameterization linked in my original post.
According to this, Amoroso random samples can be generated from gamma-distributed random samples by a simple transformation.
Maybe we don't have to implement this from scratch. Since TFP already has a reparameterized gamma distribution, could this family be implemented by extending that?
Yep! I believe GeneralizedGamma (as well as Amoroso) should be sampled by taking gamma samples and raising them to the appropriate power.
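Concretely: if Y ~ Gamma(alpha, 1), then a + theta * Y**(1/beta) is Amoroso(a, theta, alpha, beta) distributed. A quick sketch of that transformation (hypothetical parameter values, assuming eager TF):
import tensorflow as tf
from tensorflow_probability import distributions as tfd

a, theta, alpha, beta = 0., 2., 3., 1.5
y = tfd.Gamma(concentration=alpha, rate=1.).sample(100_000, seed=42)
x = a + theta * y**(1. / beta)  # Amoroso(a, theta, alpha, beta) samples
# Sample mean should be near a + theta * Gamma(alpha + 1/beta) / Gamma(alpha).
print(tf.reduce_mean(x))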
So, I think the following is an implementation of the Amoroso distribution using bijectors. At least, I was able to reproduce several figures from the Crooks paper about this distribution. Is there anything wrong with doing an implementation this way? Obviously we'd want to build it up into a nice class with a good constructor and whatnot, but does this do everything that TFP requires in terms of standards?
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
def CrooksAmorosoFactory(a, theta, alpha, beta):
    gamma = tfd.Gamma(alpha, 1.)
    # The forward chain (applied right to left) maps x to ((x - a) / theta)**beta,
    # i.e. back to the underlying standard-gamma variable; Invert turns gamma
    # samples into Amoroso samples.
    chain = tfb.Chain([
        tfb.Exp(),
        tfb.Scale(beta),
        tfb.Shift(-tf.math.log(theta)),
        tfb.Log(),
        tfb.Shift(-a),
    ])
    dist = tfd.TransformedDistribution(
        gamma,
        tfb.Invert(chain),
    )
    return dist
x = np.linspace(0, 3, 1000).astype(np.float32)  # float32 to match the distribution dtype

# Reproduce figure 1 from Crooks
plt.figure()
fignum = 1
dist_name = "Stretched Exponential"
a, theta, alpha, beta = 0., 1., 1., np.array([1., 2., 3., 4., 5.]).astype(np.float32)
labels = [f"$\\beta={int(i)}$" for i in beta]
plt.title(f"Figure {fignum} -- {dist_name}\n$Amoroso(x|{a},{theta},{alpha},\\beta)$")
dist = CrooksAmorosoFactory(a, theta, alpha, beta)
plt.plot(x, dist.prob(x[:, None]).numpy())
plt.legend(labels)

# Reproduce figure 2 from Crooks
plt.figure()
fignum = 2
dist_name = ""
a, theta, alpha, beta = 0., 1., 2., np.array([1., 2., 3., 4.]).astype(np.float32)
labels = [f"$\\beta={int(i)}$" for i in beta]
plt.title(f"Figure {fignum} -- {dist_name}\n$Amoroso(x|{a},{theta},{alpha},\\beta)$")
dist = CrooksAmorosoFactory(a, theta, alpha, beta)
plt.plot(x, dist.prob(x[:, None]).numpy())
plt.legend(labels)

# Reproduce figure 5 from Crooks
plt.figure()
fignum = 5
dist_name = ""
a, theta, alpha, beta = 0., 1., np.array([0.5, 1., 1.5]).astype(np.float32), 2.
labels = [f"$\\alpha={i}$" for i in alpha]  # this figure varies alpha, not beta
plt.title(f"Figure {fignum} -- {dist_name}\n$Amoroso(x|{a},{theta},\\alpha,{beta})$")
dist = CrooksAmorosoFactory(a, theta, alpha, beta)
plt.plot(x, dist.prob(x[:, None]).numpy())
plt.legend(labels)

# Reproduce figure 6 from Crooks
plt.figure()
fignum = 6
dist_name = ""
a, theta, alpha, beta = 0., 1., 2., np.array([-1., -2., -3.]).astype(np.float32)
labels = [f"$\\beta={int(i)}$" for i in beta]
plt.title(f"Figure {fignum} -- {dist_name}\n$Amoroso(x|{a},{theta},{alpha},\\beta)$")
dist = CrooksAmorosoFactory(a, theta, alpha, beta)
plt.plot(x, dist.prob(x[:, None]).numpy())
plt.legend(labels)

plt.show()
I think I can probably handle this implementation by extending tfp.TransformedDistribution. It'd be a great help if someone could point me to a gold-standard distribution that has been implemented this way as an example.
Specifically, I would like to implement:

- tfd.Amoroso, which is the 4-parameter generalization of the gamma distribution.
- tfd.Stacy, which is a 3-parameter generalization and a special case of the Amoroso distribution. I can implement this by extending tfd.Amoroso.

Here is a concrete, minimal example of how I intend to implement tfd.Amoroso:
from tensorflow_probability.python.internal import tensor_util

class Amoroso(tfd.TransformedDistribution):
    def __init__(self,
                 a,
                 theta,
                 alpha,
                 beta,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Amoroso'):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._a = tensor_util.convert_nonref_to_tensor(a)
            self._theta = tensor_util.convert_nonref_to_tensor(theta)
            self._alpha = tensor_util.convert_nonref_to_tensor(alpha)
            self._beta = tensor_util.convert_nonref_to_tensor(beta)
            gamma = tfd.Gamma(alpha, 1.)
            chain = tfb.Invert(tfb.Chain([
                tfb.Exp(),
                tfb.Scale(beta),
                tfb.Shift(-tf.math.log(theta)),
                tfb.Log(),
                tfb.Shift(-a),
            ]))
            super().__init__(
                distribution=gamma,
                bijector=chain,
                validate_args=validate_args,
                parameters=parameters,
                name=name)

    @property
    def a(self):
        return self._a

    @property
    def theta(self):
        return self._theta

    @property
    def alpha(self):
        return self._alpha

    @property
    def beta(self):
        return self._beta
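A quick sanity check I'd use for the class above (relying on the Gamma special case Amoroso(0, theta, alpha, 1) = Gamma(alpha, rate=1/theta); the parameter values are hypothetical):
import numpy as np
x = np.linspace(0.1, 10., 50).astype(np.float32)
dist = Amoroso(a=0., theta=2., alpha=3., beta=1.)
ref = tfd.Gamma(concentration=3., rate=0.5)
np.testing.assert_allclose(dist.log_prob(x), ref.log_prob(x), rtol=1e-5)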
I will implement the following methods:

- tfd.Stacy.kl_divergence(self, other) -- the Stacy distribution has a published analytical KL divergence.
- tfd.{Amoroso,Stacy}.mean, tfd.{Amoroso,Stacy}.variance, and tfd.{Amoroso,Stacy}.stddev -- there are analytical expressions for the mean and variance, although they aren't defined everywhere (see the sketch below).
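For reference, the expressions I have in mind follow from X = a + theta * Y**(1/beta) with Y ~ Gamma(alpha, 1), which gives E[(X - a)**k] = theta**k * Gamma(alpha + k/beta) / Gamma(alpha) (my own derivation, so treat it as an assumption to double-check):
def amoroso_mean(a, theta, alpha, beta):
    # E[X] = a + theta * Gamma(alpha + 1/beta) / Gamma(alpha);
    # undefined when alpha + 1/beta <= 0 (e.g. for some negative beta).
    return a + theta * tf.math.exp(
        tf.math.lgamma(alpha + 1. / beta) - tf.math.lgamma(alpha))

def amoroso_variance(theta, alpha, beta):
    # Var[X] = theta**2 * (Gamma(alpha + 2/beta)/Gamma(alpha)
    #                      - (Gamma(alpha + 1/beta)/Gamma(alpha))**2)
    g1 = tf.math.exp(tf.math.lgamma(alpha + 1. / beta) - tf.math.lgamma(alpha))
    g2 = tf.math.exp(tf.math.lgamma(alpha + 2. / beta) - tf.math.lgamma(alpha))
    return theta**2 * (g2 - g1**2)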
I should be able to test against the pdfs of some special cases which are already implemented in TFP (a sketch follows the list):

- Gamma
- Weibull
- HalfNormal
- Exponential
- Chi2
- ...
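For instance, using the CrooksAmorosoFactory from above, the Gamma and Weibull checks could look like this (parameter identifications per Crooks: Gamma(alpha, rate=1/theta) = Amoroso(0, theta, alpha, 1) and Weibull(k, lambda) = Amoroso(0, lambda, 1, k)):
x = np.linspace(0.1, 5., 100).astype(np.float32)
# Gamma special case: beta = 1.
np.testing.assert_allclose(
    CrooksAmorosoFactory(0., 2., 3., 1.).prob(x),
    tfd.Gamma(concentration=3., rate=0.5).prob(x), rtol=1e-5)
# Weibull special case: alpha = 1.
np.testing.assert_allclose(
    CrooksAmorosoFactory(0., 2., 1., 1.5).prob(x),
    tfd.Weibull(concentration=1.5, scale=2.).prob(x), rtol=1e-5)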
Does that all sound reasonable? Am I missing some obvious reason why implementing by bijection from a standard gamma is a bad plan?
Implementing as a sequence of bijectors, as well as the testing, sounds good. I will note that you might want to write the log_prob for this distribution specifically. The reason is that there might be simplifications / numerically stable changes you can make by writing the log_prob explicitly. For sampling, you probably won't do much better than applying the sequence of transformations, unless you write a specialized sampler.
Good point, @srvasude, I will implement log_prob as well. Thanks for the feedback!
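For reference, the closed form I plan to write out (my transcription of the Amoroso density from Crooks, so double-check the signs):
def amoroso_log_prob(x, a, theta, alpha, beta):
    # log Amoroso(x | a, theta, alpha, beta)
    # = log|beta/theta| + (alpha*beta - 1) * log z - z**beta - lgamma(alpha),
    # where z = (x - a) / theta, valid for z > 0.
    z = (x - a) / theta
    return (tf.math.log(tf.abs(beta / theta))
            + tf.math.xlogy(alpha * beta - 1., z)
            - z**beta
            - tf.math.lgamma(alpha))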
I have working implementations of the Amoroso and Stacy distributions over here and a few basic tests. I am certain my implementation is not up to TFP's style standards. I will have a look through CONTRIBUTING.md. Hopefully I can clean up the code and put together a pull request sometime soon.
Can anyone recommend any particularly clean distribution implementations for me to look at as an example?
Here's my implementation, along with a Weibull based on it:
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.bijectors import softplus_bijector
from tensorflow_probability.python.distributions import (
    distribution, kullback_leibler,
    TransformedDistribution, JointDistributionNamed)
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import tensor_util
from mederrata_stage1.tools.tf import (
    log_normalized_upper_igamma)

tfd = tfp.distributions
tfb = tfp.bijectors
class GeneralizedGamma(distribution.Distribution):
    r"""Generalized Gamma distribution.

    Follows the Wikipedia parameterization
    https://en.wikipedia.org/wiki/Generalized_gamma_distribution

    f(x; a=scale, d=shape, p=exponent) =
        \frac{(p/a^d) x^{d-1} e^{-(x/a)^p}}{\Gamma(d/p)}
    """
    def __init__(self,
                 scale, shape, exponent,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='GeneralizedGamma'):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype(
                [scale, shape, exponent], dtype_hint=tf.float32)
            self._scale = tensor_util.convert_nonref_to_tensor(
                scale, dtype=dtype, name='scale')
            self._shape = tensor_util.convert_nonref_to_tensor(
                shape, dtype=dtype, name='shape')
            self._exponent = tensor_util.convert_nonref_to_tensor(
                exponent, dtype=dtype, name='exponent')
            super(GeneralizedGamma, self).__init__(
                dtype=dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                parameters=parameters,
                name=name)
    def _mean(self):
        # E[X] = a * Gamma((d + 1)/p) / Gamma(d/p), computed in log space.
        return self.scale * tf.math.exp(
            tf.math.lgamma((self.shape + 1.) / self.exponent)
            - tf.math.lgamma(self.shape / self.exponent))

    def _variance(self):
        return self._scale**2 * (
            tf.math.exp(
                tf.math.lgamma((self.shape + 2.) / self.exponent)
                - tf.math.lgamma(self.shape / self.exponent))
            - tf.math.exp(
                2. * (tf.math.lgamma((self.shape + 1.) / self.exponent)
                      - tf.math.lgamma(self.shape / self.exponent))))

    def _cdf(self, x):
        # tf.math.igamma is already the *regularized* lower incomplete gamma
        # P(a, x), so no extra 1/Gamma(d/p) factor is needed.
        return tf.math.igamma(self.shape / self.exponent,
                              (x / self.scale)**self.exponent)
    def _log_prob(self, x, scale=None, shape=None, exponent=None):
        scale = tensor_util.convert_nonref_to_tensor(
            self.scale if scale is None else scale)
        shape = tensor_util.convert_nonref_to_tensor(
            self.shape if shape is None else shape)
        exponent = tensor_util.convert_nonref_to_tensor(
            self.exponent if exponent is None else exponent)
        log_unnormalized_prob = (
            tf.math.xlogy(shape - 1., x) - (x / scale)**exponent)
        log_prefactor = (
            tf.math.log(exponent) - tf.math.xlogy(shape, scale)
            - tf.math.lgamma(shape / exponent))
        return log_unnormalized_prob + log_prefactor
    def _entropy(self):
        scale = tf.convert_to_tensor(self.scale)
        shape = tf.convert_to_tensor(self.shape)
        exponent = tf.convert_to_tensor(self.exponent)
        return (
            tf.math.log(scale) + tf.math.lgamma(shape / exponent)
            - tf.math.log(exponent) + shape / exponent
            + (1. - shape) / exponent * tf.math.digamma(shape / exponent))

    def _stddev(self):
        return tf.math.sqrt(self._variance())

    def _default_event_space_bijector(self):
        return softplus_bijector.Softplus(validate_args=self.validate_args)
    def _sample_control_dependencies(self, x):
        assertions = []
        if not self.validate_args:
            return assertions
        assertions.append(assert_util.assert_non_negative(
            x, message='Sample must be non-negative.'))
        return assertions

    @property
    def scale(self):
        return self._scale

    @property
    def shape(self):
        return self._shape

    @property
    def exponent(self):
        return self._exponent
    def _batch_shape_tensor(self, scale=None, shape=None, exponent=None):
        return prefer_static.broadcast_shape(
            prefer_static.broadcast_shape(
                prefer_static.shape(self.scale if scale is None else scale),
                prefer_static.shape(self.shape if shape is None else shape)),
            prefer_static.shape(self.exponent if exponent is None else exponent))

    def _batch_shape(self):
        # Broadcast over all three parameters, not just scale and shape.
        return tf.broadcast_static_shape(
            tf.broadcast_static_shape(self.scale.shape, self.shape.shape),
            self.exponent.shape)

    def _event_shape_tensor(self):
        return tf.constant([], dtype=tf.int32)
    def _sample_n(self, n, seed=None):
        """Sample by transforming Gamma(d/p, 1) draws: X = a * G**(1/p)."""
        gamma_samples = tf.random.gamma(
            shape=[n],
            alpha=self.shape / self.exponent,
            beta=1.,
            dtype=self.dtype,
            seed=seed)
        # Equivalent to scale * gamma_samples**(1 / exponent).
        ggamma_samples = (
            self.scale * tf.math.exp(
                tf.math.log(gamma_samples) / self.exponent))
        return ggamma_samples

    def _event_shape(self):
        return tf.TensorShape([])
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        assertions = []
        if is_init != tensor_util.is_ref(self.scale):
            assertions.append(assert_util.assert_positive(
                self.scale,
                message='Argument `scale` must be positive.'))
        if is_init != tensor_util.is_ref(self.shape):
            assertions.append(assert_util.assert_positive(
                self.shape,
                message='Argument `shape` must be positive.'))
        if is_init != tensor_util.is_ref(self.exponent):
            assertions.append(assert_util.assert_positive(
                self.exponent,
                message='Argument `exponent` must be positive.'))
        return assertions
class Weibull(GeneralizedGamma):
    def __init__(self,
                 scale, shape,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Weibull'):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype(
                [scale, shape], dtype_hint=tf.float32)
            self._scale = tensor_util.convert_nonref_to_tensor(
                scale, dtype=dtype, name='scale')
            self._shape = tensor_util.convert_nonref_to_tensor(
                shape, dtype=dtype, name='shape')
            # Weibull is the exponent == shape special case.
            self._exponent = tensor_util.convert_nonref_to_tensor(
                shape, dtype=dtype, name='exponent')
            # Deliberately skip GeneralizedGamma.__init__ (which expects three
            # parameters) and call Distribution.__init__ directly.
            super(GeneralizedGamma, self).__init__(
                dtype=dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                parameters=parameters,
                name=name)

    @property
    def scale(self):
        return self._scale

    @property
    def shape(self):
        return self._shape
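A quick smoke test for the class above (hypothetical values, assuming eager execution; the sample mean should approach the analytic mean scale * Gamma((shape+1)/exponent) / Gamma(shape/exponent)):
gg = GeneralizedGamma(scale=2., shape=3., exponent=1.5)
samples = gg.sample(100_000, seed=42)
print(tf.reduce_mean(samples).numpy(), gg.mean().numpy())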
Hi @joshchang. Feel free to open a PR with your change if you are interested!
In X-ray crystallography, the most important prior distributions include two special cases of the generalized gamma distribution. I am very keen to try this parameterization of the variational distribution in my research project. How hard would it be for the TFP devs to implement this distribution? What is the likelihood of it being available in the near future?
Unrelated: TFP is a great package. Nobody else has so many useful reparameterized distributions. Thanks!