Open · pawel-czyz opened 5 months ago
Currently the `JointDistribution` wraps and unwraps `X` and `Y` samples into one array `XY` by concatenation and slicing.
This is suboptimal: for example, `X` and `Y` need to have the same `dtype`, and working with continuous and categorical variables requires manual casting.
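A minimal sketch of the wrap/unwrap pattern described above, written in NumPy for illustration. The helpers `wrap` and `unwrap` and the constant `DIM_X` are hypothetical and are not the package's actual code. The sketch shows how concatenating a continuous `X` with integer category labels `Y` promotes everything to a floating dtype, so the labels have to be cast back by hand.

```python
import numpy as np

# Hypothetical layout: X has DIM_X continuous coordinates, Y is one categorical label.
DIM_X = 2


def wrap(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Concatenate X and Y samples along the last axis into a single XY array."""
    return np.concatenate([x, y], axis=-1)


def unwrap(xy: np.ndarray) -> tuple:
    """Slice an XY array back into its X and Y parts."""
    return xy[..., :DIM_X], xy[..., DIM_X:]


x = np.array([[0.1, -1.3], [2.0, 0.5]], dtype=np.float64)  # continuous samples
y = np.array([[0], [2]], dtype=np.int64)                    # categorical labels

xy = wrap(x, y)
x_back, y_back = unwrap(xy)
print(xy.dtype, y_back.dtype)  # float64 float64: the integer labels were promoted

# Manual cast needed to recover the categories.
y_labels = y_back.astype(np.int64)
```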
Instead, we can use `JointDistribution` from TFP on JAX.
This has to be implemented after #143 has been resolved.
It should then be a minor change, as TFP on JAX provides a `Split` bijector.