karpathy / deep-vector-quantization

VQVAEs, GumbelSoftmaxes and friends
MIT License

About the input of F.gumbel_softmax #7

Open ZikangZhou opened 3 years ago

ZikangZhou commented 3 years ago

From my understanding, the input of F.gumbel_softmax (i.e., the logits parameter) should be the $\log$ of a discrete distribution. However, I didn't see any softmax or log_softmax before the gumbel_softmax. It seems you're treating the output of self.proj as log-probabilities with a range of $(-\infty, \infty)$, which would imply that the probabilities of the discrete distribution lie in $(0, \infty)$.

I'm curious why you don't use softmax to normalize things into (0, 1) and make them sum to 1. Does the mathematics still make sense without normalizing?
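For reference, here is a minimal sketch of the pattern I'm referring to (the layer sizes and variable names are illustrative, not the repo's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative only: a linear projection produces unconstrained scores in
# (-inf, inf), which are passed straight to F.gumbel_softmax without any
# softmax/log_softmax in between.
proj = nn.Linear(64, 512)           # hidden dim -> codebook size (made-up sizes)
logits = proj(torch.randn(8, 64))   # shape (batch, 512), unnormalized
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)  # (batch, 512), one-hot rows
```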

function2-llx commented 1 year ago

@ZikangZhou Hi, although this is an old thread, I want to share my thoughts here since there's still an open issue pointing to it. I hope this doesn't cause any inconvenience.

According to the documentation of F.gumbel_softmax, the logits parameter represents "unnormalized log probabilities". The term "unnormalized" here likely means that the logits have not been uniformly shifted so that their exponentials sum to 1 (e.g., by subtracting the logsumexp of all components). This normalization step doesn't affect the result of softmax, because any uniform shift cancels out in the softmax calculation by definition: if we add $a$ to logits $(x, y)$, the softmax result is unchanged, $\frac{e^{x+a}}{e^{x+a}+e^{y+a}} = \frac{e^x}{e^x+e^y}$. Therefore, I believe the authors' usage of F.gumbel_softmax is actually appropriate.
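To make this concrete, here's a small check (illustrative code, not from the repo) showing that log_softmax is just a uniform per-row shift of the logits, and that softmax, with or without Gumbel noise added, is unaffected by that shift:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)

# "Normalizing" with log_softmax subtracts logsumexp, a per-row constant shift.
normalized = F.log_softmax(logits, dim=-1)
shift = torch.logsumexp(logits, dim=-1, keepdim=True)
print(torch.allclose(normalized, logits - shift))           # True

# Softmax is invariant to that shift, so both give the same distribution.
print(torch.allclose(F.softmax(logits, dim=-1),
                     F.softmax(normalized, dim=-1)))         # True

# Gumbel-softmax adds noise and then applies softmax, so the same cancellation
# holds there too (using the same noise sample for both inputs).
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
print(torch.allclose(F.softmax(logits + gumbel, dim=-1),
                     F.softmax(normalized + gumbel, dim=-1)))  # True
```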

You may also check PyTorch's implementation of F.gumbel_softmax to confirm.
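For illustration, the soft case essentially amounts to adding Gumbel noise to the logits and applying a temperature-scaled softmax. Below is a simplified sketch (not PyTorch's exact internals, which also handle the hard/straight-through case), which makes clear why an unnormalized input is fine:

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sketch(logits, tau=1.0):
    """Simplified sketch of the soft Gumbel-softmax (not PyTorch's exact code)."""
    # Sample standard Gumbel noise: equivalent to -log(-log(U)) with U ~ Uniform(0, 1).
    gumbels = -torch.empty_like(logits).exponential_().log()
    # Add the noise and apply a temperature-scaled softmax.
    # Any constant added uniformly to `logits` cancels in the softmax,
    # which is why unnormalized logits are acceptable as input.
    return F.softmax((logits + gumbels) / tau, dim=-1)

logits = torch.randn(2, 5)
print(gumbel_softmax_sketch(logits).sum(dim=-1))  # each row sums to 1
```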