Closed: GRIGORR closed this issue 4 years ago
If we use torch.multinomial, how can we backpropagate gradients from the loss through the multinomial sample back to the logits?
Your sampling is the following:

```python
# Gumbel-softmax sampling: perturb the log-probabilities with Gumbel noise,
# then take a hard arg-max with a straight-through gradient.
gumbels = -torch.empty_like(xins).exponential_().log()  # Gumbel(0, 1) noise
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]  # hard arg-max index
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)  # one-hot sample
hardwts = one_h - probs.detach() + probs  # straight-through estimator
```
I wrote it this way:

```python
# Plain categorical sampling with the same straight-through trick.
logits = xins / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = torch.multinomial(probs, 1)  # sample an index from the categorical
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)  # one-hot sample
hardwts = one_h - probs.detach() + probs  # straight-through estimator
```
In both cases the gradient function of hardwts is the gradient of probs, so sampling is used only to determine the index. I am not saying that sampling by Gumbel is wrong; it's just interesting whether my approach has some flaw that I haven't noticed.
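The claim above can be checked directly: with the straight-through construction, the forward value of hardwts equals the one-hot sample, while its gradient w.r.t. the inputs equals the gradient of the soft probs. A minimal sketch (the temperature value and the downstream weights `w` are hypothetical, just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
xins = torch.randn(2, 4, requires_grad=True)
tau = 1.0  # hypothetical temperature

# Multinomial variant with the straight-through trick.
probs = nn.functional.softmax(xins / tau, dim=1)
index = torch.multinomial(probs, 1)
one_h = torch.zeros_like(probs).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs

# Forward: numerically equal to the one-hot sample.
assert torch.allclose(hardwts, one_h)

# Backward: gradient through hardwts equals gradient through probs alone.
w = torch.randn(2, 4)  # arbitrary downstream weights
(hardwts * w).sum().backward()
grad_hard = xins.grad.clone()

xins2 = xins.detach().clone().requires_grad_(True)
probs2 = nn.functional.softmax(xins2 / tau, dim=1)
(probs2 * w).sum().backward()
assert torch.allclose(grad_hard, xins2.grad)
```

The two assertions hold because `one_h - probs.detach()` contributes nothing to the gradient, so only the `+ probs` term carries gradient back to `xins`.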
Thanks for sharing the code. It is interesting. From an engineering point of view, they achieve the same goal: backpropagating gradients while sampling. However, the gradients w.r.t. the logits differ between the two strategies. Gumbel softmax has some theoretical support, while I'm not sure about the theoretical guarantees of your strategy. It is possible that they achieve similar empirical results.
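The point that the gradients w.r.t. the logits differ can be seen numerically: the Gumbel parameterization softmaxes `(log_softmax(x) + g) / tau`, while the plain variant softmaxes `x / tau`, so the Jacobians generally disagree even at the same input. A hedged sketch (the temperature and input are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tau = 5.0  # hypothetical temperature
x = torch.randn(1, 4)

# Strategy A (Gumbel): probs_a = softmax((log_softmax(x) + g) / tau)
xa = x.clone().requires_grad_(True)
g = -torch.empty_like(xa).exponential_().log()  # Gumbel(0, 1) noise
probs_a = nn.functional.softmax((xa.log_softmax(dim=1) + g) / tau, dim=1)
probs_a[0, 0].backward()

# Strategy B (plain): probs_b = softmax(x / tau)
xb = x.clone().requires_grad_(True)
probs_b = nn.functional.softmax(xb / tau, dim=1)
probs_b[0, 0].backward()

# The gradients w.r.t. the logits generally differ: the Gumbel noise and
# the extra log_softmax change the mapping from logits to probabilities.
assert not torch.allclose(xa.grad, xb.grad)
```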
OK thanks very much
You are welcome!
And one more question: in the paper, drop-path is used, but there is no usage of it in the code; it is marked as a TO-DO. Did it help, and if yes, how was it implemented? Was it just randomly masking some layers of feature maps before concatenating? Thanks
Have you tried the latter way and gotten the same or different results? I am also very interested in that.
Yes, I tried the latter way and did not see any significant difference.
Hi, a small question regarding sampling with Gumbel in GDAS: can't we directly softmax the logits and sample with torch.multinomial()? I have done it and there is nearly no difference in the results. P.S. If I am not mistaken, torch.multinomial() just divides [0, 1] into bins whose widths equal the probabilities, samples a number from [0, 1], and returns the bin the sampled number falls into. Thanks.
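The "bins" intuition in the P.S. is inverse-CDF sampling, and it can be sketched in a few lines. This is an illustrative stand-in for `torch.multinomial(probs, 1)` (the helper name is made up), not its actual implementation:

```python
import torch

def multinomial_via_bins(probs: torch.Tensor) -> torch.Tensor:
    """Sample one index per row by inverse-CDF ('bin') sampling.

    Split [0, 1] into bins whose widths equal the probabilities, draw
    u ~ U[0, 1), and return the bin that u falls into.
    """
    cdf = probs.cumsum(dim=-1)
    u = torch.rand(probs.shape[:-1] + (1,))
    # searchsorted finds the first bin whose upper edge reaches u;
    # clamp guards against floating-point round-off in the last edge.
    return torch.searchsorted(cdf, u).clamp_max(probs.shape[-1] - 1)
```

For degenerate rows such as `[0, 0, 1]` this almost surely returns index 2, matching what `torch.multinomial` would sample.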