Not sure if this is the right name for this distribution (Concrete / GumbelSoftmax are other ideas), but this is what TensorFlow calls it. This PR uses the transforms machinery :)
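For reference, here is roughly what the relaxed (Concrete / Gumbel-Softmax) categorical does, written as a pure-Python sketch rather than the PR's actual code. It samples in log-space (`log_softmax((logits + gumbel_noise) / temperature)`) and then applies an elementwise `exp`. That split is one way to fit the distribution into the transforms machinery, with the `exp` as the transform. The helper names (`gumbel`, `sample_log_relaxed`, `sample_relaxed_one_hot`) are hypothetical.

```python
import math
import random

def gumbel(rng):
    # Standard Gumbel(0, 1) noise by inverse CDF: g = -log(-log(u)).
    # u is clamped away from 0 and 1 so both logs stay finite.
    eps = 1e-12
    u = min(max(rng.random(), eps), 1.0 - eps)
    return -math.log(-math.log(u))

def sample_log_relaxed(logits, temperature, rng):
    # Log-space sample: log_softmax((logits + gumbel_noise) / temperature),
    # computed with the max-subtraction trick for numerical stability.
    z = [(logit + gumbel(rng)) / temperature for logit in logits]
    m = max(z)
    log_norm = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_norm for v in z]

def sample_relaxed_one_hot(logits, temperature, rng):
    # Applying exp elementwise maps the log-space sample onto the simplex.
    return [math.exp(v) for v in sample_log_relaxed(logits, temperature, rng)]

rng = random.Random(0)
y = sample_relaxed_one_hot([0.0, 1.0, 2.0], temperature=0.5, rng=rng)
print(y)  # a point on the probability simplex; lower temperature pushes it toward one-hot
```

Because the `exp` acts elementwise while a whole vector is a single event, its log-det-Jacobian comes out per element. That is exactly where `event_shape` starts to matter in `log_prob`.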
I had to edit `TransformedDistribution`'s `log_prob` method to take `event_shape` into account. This is probably not the right way to do it, but it's a quick first try that makes it work.
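To show why `event_shape` matters here: the base distribution's `log_prob` returns one value per event, while an elementwise transform's log-abs-det-Jacobian returns one value per element. The Jacobian term therefore has to be summed over the event dimension(s) before it is subtracted. Below is a minimal pure-Python sketch, not the PR's actual change. It assumes a base of k independent standard normals treated as one event of shape `(k,)` and an elementwise `exp` transform, and all function names are hypothetical.

```python
import math

def normal_log_prob(x, loc=0.0, scale=1.0):
    # Elementwise log-density of a univariate Normal(loc, scale).
    return -((x - loc) ** 2) / (2 * scale ** 2) - math.log(scale) - 0.5 * math.log(2 * math.pi)

def base_event_log_prob(xs):
    # Base distribution with event_shape (k,): k independent standard normals,
    # so the log-prob of one event is the sum over its k components.
    return sum(normal_log_prob(x) for x in xs)

def exp_log_abs_det_jacobian(xs):
    # For y = exp(x) elementwise, |dy/dx| = exp(x), so log|dy/dx| = x, per element.
    return list(xs)

def transformed_event_log_prob(ys):
    # Change of variables for Y = exp(X), with x = log(y):
    #   log p_Y(y) = log p_X(x) - sum_i log|dy_i/dx_i|
    # The Jacobian term is per element, so it is summed over the event
    # dimension to match the per-event shape of the base log-prob.
    xs = [math.log(y) for y in ys]
    return base_event_log_prob(xs) - sum(exp_log_abs_det_jacobian(xs))

lp = transformed_event_log_prob([0.5, 1.0, 2.0])
print(lp)  # a single float: one log-prob for the whole event
```

Without the sum over the Jacobian term, the result would have one entry per element instead of one per event. The summed version matches the sum of the individual log-normal log-densities.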
cc @fritzo