cbg-ethz / bmi

Mutual information estimators and benchmark
https://cbg-ethz.github.io/bmi/
MIT License

JointDistribution wraps and unwraps X and Y #161

Open pawel-czyz opened 2 weeks ago

pawel-czyz commented 2 weeks ago

Currently, JointDistribution wraps and unwraps the X and Y samples into a single array XY via concatenation and slicing.

This is suboptimal: for example, X and Y need to have the same dtype, and working with a mix of continuous and categorical variables requires manual casting.
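To illustrate the dtype problem, here is a minimal sketch of the concatenate-and-slice pattern in plain JAX. The helpers `wrap_xy` and `unwrap_xy` are hypothetical stand-ins written for this example, not the actual bmi internals. It shows that joining a float X with an integer Y promotes everything to a common dtype, so Y has to be cast back by hand.

```python
import jax.numpy as jnp


def wrap_xy(x, y):
    """Hypothetical helper: join X and Y samples into one XY array."""
    return jnp.concatenate([x, y], axis=-1)


def unwrap_xy(xy, dim_x):
    """Hypothetical helper: recover X and Y by slicing the last axis."""
    return xy[..., :dim_x], xy[..., dim_x:]


# Continuous X (float32) and categorical Y (int32), 4 samples each.
x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32).reshape(4, 2)
y = jnp.array([[0], [1], [2], [1]], dtype=jnp.int32)

xy = wrap_xy(x, y)                       # shape (4, 3)
x_back, y_back = unwrap_xy(xy, dim_x=2)

# Concatenation promoted int32 to float32, so Y comes back as floats...
print(xy.dtype)      # float32
print(y_back.dtype)  # float32

# ...and has to be cast back to integers manually.
y_fixed = y_back.astype(jnp.int32)
```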

Instead, we can use the JointDistribution classes from TensorFlow Probability (TFP) on JAX, which keep X and Y as separate components.

pawel-czyz commented 2 weeks ago

This has to be implemented after #143 has been resolved.

It should then be a minor change, as TFP on JAX provides a Split bijector.