Open alicanb opened 6 years ago
You can get `OneHotCategorical` for free from `Categorical`. There is also `Multinomial`, but I believe that is intractable.
I will take this up for the time being. @fritzo @alicanb
It might be a good idea to build a KL table to see which pairs we are missing. I tried to do it today, but couldn't do it in GitHub. Maybe in the design doc?
On Mon, Jan 29, 2018, 10:58 PM Vishwak Srinivasan wrote:
> I will take this up for the time being. @fritzo @alicanb
@alicanb Yeah, having a table would be a great idea. I don't know if it should be added in the Design doc. I think it will be better to have it in Markdown too.
I think I will look at this later. Sorry @alicanb
@vishwakftw Do you think it's easy to compute `KL(Binomial(n,p), Poisson(lambda))`? @alicanb and I were looking at this the other day but got stuck.
At a glance, I think there might be issues with the finite sum of the exponential term in the Poisson's pmf. Furthermore, the `nCk` term could also cause issues.
I will have a detailed look at it and get back to you tomorrow if that is fine.
We can always calculate this numerically since binomial has `enumerate_support`. In fact, I remember discussing adding a default KL for distributions with `enumerate_support` with somebody (@rachtsingh was it you?)
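As a rough sketch of the numerical route mentioned here, the KL divergence can be summed directly over a finite support. The version below is plain Python using only `math` rather than `torch.distributions`, and the helper names (`kl_over_support`, `binomial_log_pmf`, `poisson_log_pmf`, `kl_binomial_poisson`) are hypothetical, chosen only for illustration. It assumes `0 < p < 1` and `lam > 0`.

```python
import math

def binomial_log_pmf(k, n, p):
    """log P(K = k) for K ~ Binomial(n, p), assuming 0 < p < 1."""
    log_binom = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_binom + k * math.log(p) + (n - k) * math.log1p(-p)

def poisson_log_pmf(k, lam):
    """log P(K = k) for K ~ Poisson(lam), assuming lam > 0."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)

def kl_over_support(support, log_p, log_q):
    """KL(p || q) as a sum of p(k) * (log p(k) - log q(k)) over p's finite support."""
    total = 0.0
    for k in support:
        lp = log_p(k)
        total += math.exp(lp) * (lp - log_q(k))
    return total

def kl_binomial_poisson(n, p, lam):
    # The binomial support {0, ..., n} is finite, so the sum is exact up to rounding.
    return kl_over_support(
        range(n + 1),
        lambda k: binomial_log_pmf(k, n, p),
        lambda k: poisson_log_pmf(k, lam),
    )
```

The cost is linear in `n`, and working with log-probabilities avoids overflowing `n!` for moderately large `n`.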
@fritzo I am stuck at one term in the summation. This is the one:
@alicanb It should be possible for finite support with some extra ops, but getting a closed-form solution might be hard.
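The summation should actually not be hard to implement in PyTorch. Something like this should do the trick:

    import torch

    def this_sum(n, p):
        # Sum over k = 0..n of C(n, k) p^k (1 - p)^(n - k) * log((n - k)!); n is an int, p a float tensor in (0, 1)
        k = torch.arange(0, n + 1, dtype=p.dtype)
        log_binom = torch.lgamma(torch.tensor(n + 1.0, dtype=p.dtype)) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)
        log_pmf = log_binom + k * p.log() + (n - k) * (1 - p).log()
        return (log_pmf.exp() * torch.lgamma(n - k + 1)).sum(-1)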
@fritzo @alicanb We could use Ramanujan's approximation for computing this sum. The asymptotic error is `O(1/n^3)`. Let me know what you think.
EDIT: it might be hard to denote moments using this approximation, so we can also settle for Stirling's approximation.
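For a sense of the trade-off between the two approximations, here is a small sketch comparing them against `math.lgamma` for `log n!`. The formulas are the standard textbook forms of Stirling's and Ramanujan's approximations; the function names are hypothetical, and both assume `n >= 1`.

```python
import math

def log_factorial_stirling(n):
    """Stirling: log n! ~ n log n - n + 0.5 * log(2 * pi * n), for n >= 1."""
    return n * math.log(n) - n + 0.5 * math.log(2 * math.pi * n)

def log_factorial_ramanujan(n):
    """Ramanujan: log n! ~ n log n - n + log(8n^3 + 4n^2 + n + 1/30) / 6 + 0.5 * log(pi), for n >= 1."""
    poly = 8 * n**3 + 4 * n**2 + n + 1 / 30
    return n * math.log(n) - n + math.log(poly) / 6 + 0.5 * math.log(math.pi)

# Print the absolute error of each approximation against the exact log n!.
for n in (1, 5, 20):
    exact = math.lgamma(n + 1)
    stirling_err = abs(log_factorial_stirling(n) - exact)
    ramanujan_err = abs(log_factorial_ramanujan(n) - exact)
    print(n, stirling_err, ramanujan_err)
```

Stirling's error shrinks roughly like `1/(12n)`, so Ramanujan's form is noticeably tighter at the small `n` typical for a binomial.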
@vishwakftw I like the idea of implementing the exact sum. It should be cheap on GPUs, and I've seldom seen binomial used with large `n`.
@fritzo Sure, we could do that. I will send in a PR soon.
Give me a few hours, I already have something that tackles Binom-Binom, Binom-Geo, and Binom-Poi KLs
Amazing!! Thanks @alicanb .
@alicanb @fritzo I have prepared a script to show existing KL-div pairs in Markdown format. This is the script.
@vishwakftw That's so cool, is there a way we can put that in docs?
It should be possible, but the issue is that the table is too big.
- Bernoulli
- Binomial
- Categorical
- Geometric
- Multinomial (intractable)
- Poisson (after merge with `master`)
- OneHotCategorical
Are there more?