Kevin-Chen0 opened 3 years ago
I've modified the MultivariateCoupledNormal class so that its sample_n method works with tensors. I've put this version in a separate module, multivariate_coupled_normal_tf.py, so we still have the original to compare against.
Let me know if you'd rather I overwrite the original instead.
I have not yet looked at tensor support in the class's other methods, so I don't know whether they will work with tensor inputs.
When doing the integration, can you also test this class with the VAE?
In the tensor branch of nsc, the sample_n function is the highest priority, as we are very likely to use the tensor version of sample_n rather than the NumPy version.