Closed peterzpy closed 4 years ago
你好,感谢提问。代码对应部分在util/gumbel.py下的Line33-Line43。由于gumbel softmax有随机变量的存在,反向传播时我们只根据响应最高的点回传梯度
首先感谢您的回答。之前看论文的时候主要就是这个正向传播和梯度反传时的不一致性的地方不是特别理解,您论文中不是提到回传梯度时会考虑全部的像素点吗?我看您的代码中好像只使用了topk生成mask的方法,这样的话正向传播与反向传播是一致的,这里是您论文中的表述有问题,还是我理解错了呢?
是这样的,前向的时候挑选的点是经由gumbel softmax采样而来的,gumbel softmax是个采样手段(包含随机变量),所以反向传播的时候我们会从最高响应的topk个点(即lambda,注意lambdai,j和pi,j并不一致,见文中公式2)中回传回去(也就是论文中提到的KeepTopK),所以正向和反向实际是不一样的
您的论文中提到的对于mask在正向传播与反向传播时的处理,我好像没有在代码中找到对应的部分。