Open pawel-czyz opened 3 months ago
I'm stuck on creating mixtures of joint distributions (which, recall, are now tuples/lists rather than single integers). Constructing the mixture raises:
TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([2])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
To reproduce, use the following code:
import tensorflow_probability as tfp
import jax
import jax.numpy as jnp
tfd = tfp.substrates.jax.distributions
tfb = tfp.substrates.jax.bijectors
key = jax.random.PRNGKey(42)
# This works
mix = tfd.Mixture(
    cat=tfd.Categorical(probs=jnp.asarray([0.3, 0.7])),
    components=[tfd.Normal(0.0, 1.0), tfd.Normal(1.0, 2.0)],
)
print("Here it works. This is a sample: ", mix.sample(3, key))
mean = jnp.zeros(5)
covariance_matrix = jnp.eye(5)
dist = tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance_matrix)
split_dist = tfd.TransformedDistribution(
    distribution=dist,
    bijector=tfb.Split((2, 3)),
)
bijectors = [tfb.Exp(), tfb.Sigmoid()]  # Note: this has to be a list; a tuple doesn't work
bij = tfb.JointMap(bijectors)
tr_dist = tfd.TransformedDistribution(distribution=split_dist, bijector=bij)
print("There is an error:")
tfd.Mixture(
    cat=tfd.Categorical(probs=jnp.asarray([0.3, 0.7])),
    components=[split_dist, tr_dist],
)
I'm feeling a bit lost here. I've submitted an issue to the TFP repository.
Perhaps a workaround would be to define a custom mixture distribution that performs fewer checks at initialisation; in our case, the allowed array shapes can be made stricter. Some links that seem useful if we go with this approach:
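To make the workaround idea concrete, here is a minimal sketch of a mixture that skips shape validation entirely and only relies on each component exposing `sample(seed=...)` and `log_prob(x)`. Everything in it is an assumption for illustration: `SimpleJointMixture` and `ToyJointNormal` are hypothetical names, not TFP classes, and `ToyJointNormal` is a pure-JAX stand-in for a joint distribution whose samples are lists of arrays. It handles a single draw with no batch dimensions, and it has not been checked against `split_dist` and `tr_dist` from the snippet above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


class ToyJointNormal:
    """Hypothetical stand-in for a joint distribution with samples [x, y]."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, seed):
        # One draw: a list with an array of shape (2,) and one of shape (3,).
        key_x, key_y = jax.random.split(seed)
        x = self.loc + self.scale * jax.random.normal(key_x, (2,))
        y = self.loc + self.scale * jax.random.normal(key_y, (3,))
        return [x, y]

    def log_prob(self, value):
        # A single scalar log-density for the whole [x, y] structure.
        flat = jnp.concatenate(value)
        return jnp.sum(norm.logpdf(flat, self.loc, self.scale))


class SimpleJointMixture:
    """Hypothetical mixture that performs no shape checks at initialisation."""

    def __init__(self, probs, components):
        self.log_weights = jnp.log(jnp.asarray(probs))
        self.components = list(components)

    def log_prob(self, value):
        # log sum_k w_k p_k(value), computed stably in log space.
        component_lps = jnp.stack([c.log_prob(value) for c in self.components])
        return logsumexp(self.log_weights + component_lps)

    def sample(self, key):
        keys = jax.random.split(key, len(self.components) + 1)
        # Pick the component index according to the mixture weights.
        index = jax.random.categorical(keys[0], self.log_weights)
        # Draw from every component, then select one leaf by leaf.
        # This is wasteful, but it keeps the code free of Python branching.
        draws = [c.sample(seed=k) for c, k in zip(self.components, keys[1:])]
        return jax.tree_util.tree_map(
            lambda *leaves: jnp.stack(leaves)[index], *draws
        )


mixture = SimpleJointMixture(
    probs=[0.3, 0.7],
    components=[ToyJointNormal(0.0, 1.0), ToyJointNormal(1.0, 2.0)],
)
draw = mixture.sample(jax.random.PRNGKey(42))
print("Sample shapes:", [part.shape for part in draw])
print("Log-probability of the sample:", mixture.log_prob(draw))
```

The selection step requires all components to return the same list structure with matching leaf shapes, which is the kind of stricter shape requirement mentioned above.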
This PR aims to resolve #161.
Tasks:
- [...] JointDistribution class.
- BMMSampler has to be adjusted.
- [...] float is no longer needed.
- [...] JointDistribution somewhere. Then, update it.

Help is highly appreciated! :slightly_smiling_face: