cbg-ethz / bmi

Mutual information estimators and benchmark
https://cbg-ethz.github.io/bmi/
MIT License

Resolve Issue 161 #165

Open pawel-czyz opened 3 months ago

pawel-czyz commented 3 months ago

This PR aims to resolve #161.

Tasks:

Help highly appreciated! :slightly_smiling_face:

pawel-czyz commented 3 months ago

I've got stuck on creating mixtures of joint distributions (whose event shapes, recall, are now tuples/lists rather than single integers):

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([2])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'

To reproduce, use the following code:


import tensorflow_probability as tfp
import jax
import jax.numpy as jnp

tfd = tfp.substrates.jax.distributions
tfb = tfp.substrates.jax.bijectors

key = jax.random.PRNGKey(42)

# This works
mix = tfd.Mixture(
    cat=tfd.Categorical(probs=jnp.asarray([0.3, 0.7])),
    components=[tfd.Normal(0.0, 1.0), tfd.Normal(1., 2.)],
)
print("Here it works. This is a sample: ", mix.sample(3, key))

mean = jnp.zeros(5)
covariance_matrix = jnp.eye(5)
dist = tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance_matrix)

split_dist = tfd.TransformedDistribution(
    distribution=dist,
    bijector=tfb.Split((2, 3)),
)

bijectors = [tfb.Exp(), tfb.Sigmoid()]  # Note: this has to be a list; a tuple doesn't work

bij = tfb.JointMap(bijectors)
tr_dist = tfd.TransformedDistribution(distribution=split_dist, bijector=bij)

print("There is an error:")
tfd.Mixture(
    cat=tfd.Categorical(probs=jnp.asarray([0.3, 0.7])),
    components=[split_dist, tr_dist],
)

I'm feeling a bit lost here. I've submitted an issue to the TFP repository.

Perhaps a workaround would be to define a custom mixture distribution that performs fewer checks at initialisation – in our case the allowed array shapes can be made stricter. Some links which seem useful if we are going to use this approach:
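To make the workaround idea more concrete, here is a rough sketch of what such a custom mixture could look like in plain JAX. It is only an illustration under assumptions: `SimpleJointMixture` and `IndependentNormalPair` are hypothetical names (neither is part of TFP or of this repository), every component is assumed to expose `sample(n, key)` returning a tuple of arrays with a leading batch axis and `log_prob(value)` returning one value per row, and no shape validation is done at initialisation. Whether TFP joint distributions can be plugged in directly has not been checked here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


class IndependentNormalPair:
    """Hypothetical stand-in for a joint distribution over a pair (x, y).

    All coordinates are independent normals sharing one mean and scale.
    """

    def __init__(self, dim_x, dim_y, loc=0.0, scale=1.0):
        self.dims = (dim_x, dim_y)
        self.loc = loc
        self.scale = scale

    def sample(self, n, key):
        key_x, key_y = jax.random.split(key)
        x = self.loc + self.scale * jax.random.normal(key_x, (n, self.dims[0]))
        y = self.loc + self.scale * jax.random.normal(key_y, (n, self.dims[1]))
        return (x, y)

    def log_prob(self, value):
        x, y = value
        z = (jnp.concatenate([x, y], axis=-1) - self.loc) / self.scale
        per_coord = -0.5 * z**2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)
        return jnp.sum(per_coord, axis=-1)


class SimpleJointMixture:
    """Hypothetical mixture over tuple-valued components, with no shape checks."""

    def __init__(self, probs, components):
        self.log_weights = jnp.log(jnp.asarray(probs))
        self.components = list(components)

    def sample(self, n, key):
        key_cat, *keys = jax.random.split(key, len(self.components) + 1)
        # Component index for each of the n draws.
        idx = jax.random.categorical(key_cat, self.log_weights, shape=(n,))
        # Draw n samples from every component, then keep the selected rows.
        draws = [c.sample(n, k) for c, k in zip(self.components, keys)]
        rows = jnp.arange(n)
        return tuple(
            jnp.stack(parts, axis=0)[idx, rows]  # (K, n, dim) -> (n, dim)
            for parts in zip(*draws)
        )

    def log_prob(self, value):
        # Shape (n, K): log-density of each row under each component.
        lps = jnp.stack([c.log_prob(value) for c in self.components], axis=-1)
        return logsumexp(lps + self.log_weights, axis=-1)


key = jax.random.PRNGKey(42)
mix = SimpleJointMixture(
    probs=[0.3, 0.7],
    components=[
        IndependentNormalPair(2, 3),
        IndependentNormalPair(2, 3, loc=1.0, scale=2.0),
    ],
)
x, y = mix.sample(4, key)
log_p = mix.log_prob((x, y))
print(x.shape, y.shape, log_p.shape)
```

Sampling from every component and then selecting rows is wasteful, but it keeps the code short and avoids data-dependent control flow; a real implementation would probably want something smarter, plus the stricter shape checks mentioned above.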