Hi!
I have a use case where I am doing projected gradient descent on a Multinomial distribution, whose parameter array is a categorical distribution $p$.
However, due to what I am assuming are approximations in the pmf of the multinomial distribution, the gradient of this distribution is incorrect. Specifically, values that are not sampled still have a nonzero corresponding entry in the gradient. As an example, consider the "4 choose 2" problem: we have 4 parameters $p_1, \ldots, p_4$, and our sample $x$ only contains draws of object 4. We would then expect the gradient of the log pmf with respect to $p$ to be zero everywhere except at index $4$ (or $3$ in Python terms). The same issue can be replicated with code along the lines of the sketch below.
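The snippet I have in mind is essentially the following sketch (I am paraphrasing it here against NumPyro's `Multinomial` purely for concreteness; the library and exact constructor are stand-ins for whichever JAX-backed Multinomial implementation is in use):

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist  # stand-in for the Multinomial implementation in question

# "4 choose 2": 4 categories, 2 trials, and the sample only contains object 4.
p = jnp.array([0.1, 0.2, 0.3, 0.4])
x = jnp.array([0.0, 0.0, 0.0, 2.0])

def log_pmf(p):
    return dist.Multinomial(total_count=2, probs=p).log_prob(x)

print(jax.grad(log_pmf)(p))
# Expected: zero everywhere except index 3, where d/dp_4 of x_4 * log(p_4) is x_4 / p_4 = 5.
# What I actually observe in my setup: a nonzero entry at every index.
```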
I believe this can be remedied by writing a custom JVP, but I'm honestly not 100% sure how that works. For now I will implement the gradient by hand (sketched below), but I'd appreciate it if this could be worked on!
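The by-hand gradient I have in mind is just $x_i / p_i$ with zeros at unsampled indices (the function name below is only illustrative, and it assumes $p$ already lies on the simplex, which projected gradient descent maintains):

```python
import jax.numpy as jnp

def manual_log_pmf_grad(p, x):
    # d/dp_i of sum_j x_j * log(p_j) is x_i / p_i, which is exactly zero
    # wherever x_i == 0; the multinomial coefficient does not depend on p.
    return jnp.where(x > 0, x / p, 0.0)
```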
Thanks! And please let me know if there is something I am missing conceptually.
A friend pointed out that this could be due to the re-normalization that is added to the computational graph when the Multinomial distribution is initialized.
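If that is the case, i.e. if the log pmf is effectively computed from $p_i / \sum_j p_j$, then differentiating gives

$$\frac{\partial}{\partial p_k}\sum_i x_i \log\frac{p_i}{\sum_j p_j} = \frac{x_k}{p_k} - \frac{n}{\sum_j p_j},$$

so every coordinate picks up the $-n / \sum_j p_j$ term even when $x_k = 0$, which would match the nonzero entries I am seeing at unsampled indices.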