Status: Open. henrypinkard opened this issue 1 year ago.
@henrypinkard,
The `reinterpreted_batch_ndims` parameter controls how many of the base distribution's rightmost batch dimensions are absorbed into the event shape; it must satisfy `reinterpreted_batch_ndims <= len(batch_shape)`.
For example, the `_log_prob` method performs a `reduce_sum` over the rightmost `reinterpreted_batch_ndims` dimensions after calling the base distribution's `log_prob`. Because the batch dimensions index independent distributions, the resulting multivariate distribution has independent components.
Also, `reinterpreted_batch_ndims` is part of TensorFlow Probability (TFP), so I'd suggest raising this in the TFP repository for further assistance. Thank you!
Thanks for the explanation. I've also opened an issue on TFP (https://github.com/tensorflow/probability/issues/1679), though there's no response yet.
@henrypinkard, could you please close this issue, since it is already being tracked there? Thank you!
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
Still no activity on the TFP issue, and no clarity on how this might be fixed.
I'm trying to make some dense layers that output the parameters of a Gaussian mixture (a mixture density network). I want to run a batch of data through the network (for speed), get out a batch of distributions, and then slice it to work with only some elements of the batch at a time. If I do this with just TFP, I call:
This works as expected, giving:
However, when I try the same thing with a mixture density network in Keras, I get an error:
It gives me this cryptic error:
Versions: TensorFlow 2.11.0 (v2.11.0-rc2-17-gd5b57ca93e5), tensorflow-probability 0.19.0, Python 3.10.6, Ubuntu 18.04
Any ideas why this is happening and how to fix it?