Photrek / Nonlinear-Statistical-Coupling

Apache License 2.0

Apply tfp's sample_n function into MultivariateCoupledNormal class #29

Open Kevin-Chen0 opened 3 years ago

Kevin-Chen0 commented 3 years ago

In nsc's tensor branch, the sample_n function is the highest priority, since we are very likely to use the tensor version of sample_n rather than the NumPy version.

hxyue1 commented 3 years ago

I've modified the MultivariateCoupledNormal class so that the sample_n method works with tensors. However, I've put this in a separate module, multivariate_coupled_normal_tf.py, so that we keep the original for comparison.

Let me know if you'd rather I overwrite the original module instead.

I have not checked tensor integration for the other methods of the class, so I don't know whether they will work with tensor inputs.
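When comparing the NumPy and tensor modules, a small reference sampler can help check that the tensor `sample_n` produces the right distribution. The following is a minimal sketch, not the repo's implementation. It assumes that for coupling κ > 0 the multivariate coupled normal coincides with a multivariate Student's t with 1/κ degrees of freedom (location μ, scale matrix Σ), and that κ = 0 reduces to an ordinary Gaussian. The helper name `sample_coupled_normal` and its parameters are hypothetical. Each step (standard normal draws, a chi-squared mixing variable, a Cholesky factor) has a TensorFlow counterpart such as `tf.random.normal`, `tf.random.gamma`, and `tf.linalg.cholesky`, so a tensor version could follow the same structure.

```python
import numpy as np


def sample_coupled_normal(loc, scale_matrix, kappa, n, rng=None):
    """Draw n samples from a multivariate coupled normal (hypothetical helper).

    Assumption: for kappa > 0 the coupled normal equals a multivariate
    Student's t with df = 1/kappa, location `loc`, and scale matrix
    `scale_matrix`; kappa == 0 falls back to an ordinary Gaussian.
    Returns an array of shape (n, d).
    """
    rng = np.random.default_rng() if rng is None else rng
    loc = np.asarray(loc, dtype=float)
    scale_matrix = np.asarray(scale_matrix, dtype=float)
    d = loc.shape[-1]

    # Lower-triangular factor L with L @ L.T == scale_matrix.
    chol = np.linalg.cholesky(scale_matrix)

    # All n standard normal draws at once, shape (n, d); each row becomes L @ z.
    z = rng.standard_normal((n, d))
    gaussian = z @ chol.T

    if kappa == 0:
        return loc + gaussian

    # Gaussian scale mixture: divide by sqrt(chi2(df) / df), one draw per sample.
    df = 1.0 / kappa
    w = rng.chisquare(df, size=(n, 1)) / df
    return loc + gaussian / np.sqrt(w)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    samples = sample_coupled_normal(
        [0.0, 0.0], [[1.0, 0.3], [0.3, 2.0]], kappa=0.1, n=5, rng=rng
    )
    print(samples.shape)  # (5, 2)
```

Drawing all n samples as one (n, d) array rather than looping per sample keeps the construction close to how it would be written with batched tensor ops.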

Kevin-Chen0 commented 3 years ago

When doing the integration, can you also test this class with the VAE?