google-deepmind / distrax

Differentiating the Log PMF of a Multinomial Distribution #277

Closed by nmonette 2 months ago

nmonette commented 2 months ago

Hi! I have a use case where I am doing projected gradient descent on a Multinomial distribution, whose parameter array is a categorical distribution $p$.

However, due to what I assume are approximations in the pmf of the multinomial distribution, the gradient of this distribution is incorrect. Specifically, entries that were never sampled still receive a nonzero value in the gradient. As an example, consider the "4 choose 2" problem: with 4 parameters $p_1, \ldots, p_4$, if our sample $x$ only contains counts of object 4, we would expect the gradient of the log pmf with respect to $p$ to be zero everywhere except at index 4 (or 3 in Python's zero-based terms); the short derivation below spells this out. The same issue can be reproduced with the code that follows.
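For reference, treating the $p_i$ as free parameters (i.e., ignoring the sum-to-one constraint), the multinomial log pmf and its gradient are

$$\log P(x \mid n, p) = \log n! - \sum_i \log x_i! + \sum_i x_i \log p_i, \qquad \frac{\partial \log P(x \mid n, p)}{\partial p_i} = \frac{x_i}{p_i},$$

which vanishes at every index where $x_i = 0$.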

import jax
import jax.numpy as jnp
import distrax
from functools import partial

@partial(jax.grad, has_aux=True)
def sample_fn(p, rng):
    # Differentiate the log prob w.r.t. p; return the sampled counts as aux.
    dist = distrax.Multinomial(32, probs=p)
    idxs, lp = dist.sample_and_log_prob(seed=rng)
    return lp, idxs

rng = jax.random.key(0)
x = jnp.full(4000, 1 / 4000)  # uniform probs over 4000 categories
x_lp, x_counts = sample_fn(x, rng)  # x_lp is the gradient of the log prob w.r.t. x
jax.debug.print("num nonzero in sample: {}", jnp.count_nonzero(x_counts))
jax.debug.print("num nonzero in grad(lp): {}", jnp.count_nonzero(x_lp))

I believe this can be remedied by writing a custom JVP, but to be completely honest I'm not 100% sure how that works. For now I will implement the gradient by hand, but I'd appreciate it if this could be worked on! A sketch of what I have in mind is below.
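A minimal sketch of the hand-written gradient via jax.custom_jvp (the helper name multinomial_log_prob is mine; it assumes p is already normalized and treats its entries as free parameters):

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

@jax.custom_jvp
def multinomial_log_prob(p, counts):
    # Log pmf with p treated as free parameters (no re-normalization).
    log_coeff = gammaln(counts.sum() + 1) - gammaln(counts + 1).sum()
    # Guard against log(0) at indices with zero counts.
    return log_coeff + jnp.sum(counts * jnp.log(jnp.where(counts > 0, p, 1.0)))

@multinomial_log_prob.defjvp
def multinomial_log_prob_jvp(primals, tangents):
    p, counts = primals
    dp, _ = tangents  # counts are not differentiated
    primal_out = multinomial_log_prob(p, counts)
    # d/dp_i log pmf = x_i / p_i, which is exactly zero wherever x_i == 0.
    grad_p = jnp.where(counts > 0, counts / p, 0.0)
    return primal_out, jnp.sum(grad_p * dp)

With this, jax.grad(multinomial_log_prob)(x, x_counts) is zero exactly where x_counts is zero.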

Thanks! And please let me know if there is something I am missing conceptually.

nmonette commented 2 months ago

A friend pointed out that this could be due to the re-normalization of probs being added to the computational graph when the Multinomial distribution is initialized.
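A quick way to check that hypothesis (just a sketch; distrax's exact internals may differ) is to compare gradients of the $\sum_i x_i \log p_i$ term with and without the re-normalization in the graph:

import jax
import jax.numpy as jnp

counts = jnp.array([0.0, 0.0, 0.0, 2.0])  # only object 4 was sampled
p = jnp.full(4, 0.25)

# p treated as free parameters: gradient is x_i / p_i.
unnorm = jax.grad(lambda q: jnp.sum(counts * jnp.log(q)))(p)
# Re-normalization q / q.sum() in the graph, as when probs are passed in:
norm = jax.grad(lambda q: jnp.sum(counts * jnp.log(q / q.sum())))(p)

print(unnorm)  # [0. 0. 0. 8.]    -> zero wherever counts are zero
print(norm)    # [-2. -2. -2. 6.] -> x_i/p_i - n/sum(p), nonzero everywhere

The extra $-n/\sum_k p_k$ term from differentiating through the normalization is what makes every entry of the gradient nonzero.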