D-X-Y / AutoDL-Projects

Automated deep learning algorithms implemented in PyTorch.
MIT License

Question about gumbel-sampling. #74

Closed · GRIGORR closed this issue 4 years ago

GRIGORR commented 4 years ago

Hi, a small question regarding sampling with Gumbel in GDAS. Can't we directly softmax the logits and sample with torch.multinomial()? I have done this and there is nearly no difference in the results. P.S. If I am not mistaken, torch.multinomial() just divides [0, 1] into bins whose widths equal the probabilities, samples a number from [0, 1], and returns the bin the sampled number falls into. Thanks.
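
For reference, a minimal sketch of that "bins" (inverse-CDF) view of multinomial sampling; the names are illustrative, not from the repository:

import torch

probs = torch.tensor([0.1, 0.6, 0.3])              # example categorical probabilities
u = torch.rand(1)                                  # uniform draw from [0, 1)
bins = probs.cumsum(dim=0)                         # bin edges: [0.1, 0.7, 1.0]
index = torch.searchsorted(bins, u).clamp(max=2)   # bin that u falls into (clamp guards against round-off)
index_mn = torch.multinomial(probs, 1)             # samples from the same categorical distribution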

D-X-Y commented 4 years ago

If we use torch.multinomial, how do we backpropagate the gradients from the loss through multinomial to the logits?

GRIGORR commented 4 years ago

Your sampling is the following:

gumbels = -torch.empty_like(xins).exponential_().log()     # standard Gumbel noise via -log(Exp(1))
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau    # Gumbel-perturbed log-probabilities, scaled by temperature
probs = nn.functional.softmax(logits, dim=1)               # relaxed (soft) weights
index = probs.max(-1, keepdim=True)[1]                     # argmax = sampled operation index
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)  # one-hot encoding of the index
hardwts = one_h - probs.detach() + probs                   # straight-through: one-hot in value, gradient of probs

I wrote it this way:

logits = xins / self.tau                                   # temperature-scaled logits
probs = nn.functional.softmax(logits, dim=1)               # softmax probabilities
index = torch.multinomial(probs, 1)                        # sample an index directly from probs
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)  # one-hot encoding of the index
hardwts = one_h - probs.detach() + probs                   # straight-through: one-hot in value, gradient of probs

In both versions the gradient function of hardwts is the gradient of probs, so the sampling is only used to determine the index. I am not saying that sampling with Gumbel is wrong; I am just curious whether my approach has some flaw that I haven't noticed.
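
A quick self-contained way to see that hardwts indeed carries the gradient of probs in the multinomial variant (the Gumbel variant uses the identical straight-through line); the names here are illustrative, not the repository's code:

import torch
import torch.nn as nn

torch.manual_seed(0)
xins = torch.randn(1, 4, requires_grad=True)                   # stand-in for the architecture logits
tau = 1.0

probs = nn.functional.softmax(xins / tau, dim=1)
index = torch.multinomial(probs, 1)                            # sampling: no gradient flows through this
one_h = torch.zeros_like(probs).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs                       # one-hot in value, gradient of probs

loss = (hardwts * torch.tensor([[1.0, 2.0, 3.0, 4.0]])).sum()  # arbitrary downstream loss
loss.backward()
print(xins.grad)                                               # non-zero: gradients reach the logits via probs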

D-X-Y commented 4 years ago

Thanks for sharing the code, it is interesting. From an engineering point of view, the two achieve the same goal: sampling an index and backpropagating gradients. However, the gradients w.r.t. the logits differ between the two strategies. Gumbel-softmax has theoretical support, while I'm not sure about the theoretical guarantees of your strategy. It is possible that they achieve similar empirical results.
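
For reference, the theoretical support is the Gumbel-max trick: adding standard Gumbel noise to the log-probabilities and taking the argmax yields exact samples from the softmax distribution, and the temperature-scaled softmax is its differentiable relaxation. A small self-contained check (illustrative, not the repository's code):

import torch

torch.manual_seed(0)
logits = torch.tensor([0.5, 1.5, -1.0])
probs = torch.softmax(logits, dim=0)

n = 100000
gumbels = -torch.empty(n, 3).exponential_().log()              # standard Gumbel noise
samples = (logits.log_softmax(dim=0) + gumbels).argmax(dim=1)  # Gumbel-max sampling
freq = torch.bincount(samples, minlength=3).float() / n

print(probs)   # target categorical distribution
print(freq)    # empirical frequencies, should be close to probs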

GRIGORR commented 4 years ago

OK thanks very much

D-X-Y commented 4 years ago

You are welcome!

GRIGORR commented 4 years ago

And one more question: in the paper, drop-path is used, but there is no usage of it in the code, and it is marked as a TODO. Did it help, and if yes, how was it implemented? Was it just randomly masking some layers of the feature maps before concatenating? Thanks.
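
For concreteness, the usual DARTS/NASNet-style drop-path looks roughly like the sketch below (my own sketch, not this repository's code): an entire sample's path is zeroed with some probability and the survivors are rescaled.

import torch

def drop_path(x, drop_prob, training):
    # Randomly zero a whole sample's path with probability drop_prob and
    # rescale the survivors so the expected value is unchanged.
    if drop_prob <= 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
    return x * mask / keep_prob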

sbl1996 commented 4 years ago

Have you tried the multinomial way and gotten the same or different results? I am also very interested in that.

GRIGORR commented 4 years ago

Yes, I tried the multinomial way and did not see any significant difference.