Closed zplizzi closed 2 years ago
Hey, I tried Gumbel softmax back then but found it required careful temperature tuning, and even temperature annealing, to match the performance of straight-through gradients, which are easier and free of hyperparameters. It's possible that Gumbel can be made to work better though, I'm not sure.
I'm curious if you considered trying the Gumbel softmax as an alternative to the way you implemented straight-through gradients in this paper/code. It seems like it might be a less-biased way of backpropagating through the operation of sampling from a categorical distribution. The "hard" variant allows you to retain a purely discrete one-hot output in the forward pass, as you did here.
As I understand it, you implemented:
Forward: one_hot(draw(logits))
Backward: softmax(logits, temp=1)
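For concreteness, a minimal NumPy sketch of that straight-through scheme (function names like `straight_through_sample` are mine, not from the repo; in an autograd framework the same effect is usually written as `one_hot + probs - probs.detach()`):

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def straight_through_sample(logits, rng):
    """Forward: a one-hot draw from Categorical(softmax(logits)).
    Backward (in an autograd framework): gradients would flow as if
    the output were softmax(logits, temp=1); returned here alongside
    the sample to make that explicit."""
    probs = softmax(logits)
    idx = rng.choice(len(probs), p=probs)
    one_hot = np.zeros_like(probs)
    one_hot[idx] = 1.0
    return one_hot, probs

rng = np.random.default_rng(0)
logits = np.array([1.0, 2.0, 0.5])
sample, probs = straight_through_sample(logits, rng)
```

In PyTorch the standard one-liner for this is `y = one_hot + probs - probs.detach()`, so the forward value is the discrete sample while the gradient is that of the softmax.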
And the (hard version of the) Gumbel softmax is:
Forward: one_hot(arg_max(log_softmax(logits) + sample_from_gumbel_dist))
Backward: softmax(log_softmax(logits) + sample_from_gumbel_dist, temp=temp_hyperparam)
The forward passes in both versions are equivalent - the second is just a reparameterization of the first. By altering the temperature hyperparameter, you can trade off bias and variance.
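That forward equivalence (the Gumbel-max trick) is easy to check empirically. A small NumPy sketch (names are mine; a real implementation would use something like `torch.nn.functional.gumbel_softmax` with `hard=True`):

```python
import numpy as np

def softmax(logits, temp=1.0):
    z = (logits - logits.max(axis=-1, keepdims=True)) / temp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gumbel_softmax_hard(logits, temp, rng):
    """Hard Gumbel softmax.
    Forward: one_hot(argmax(log_softmax(logits) + g)), g ~ Gumbel(0, 1).
    The soft relaxation softmax((log_softmax(logits) + g) / temp) is
    what gradients would flow through in an autograd framework."""
    g = rng.gumbel(size=logits.shape)
    perturbed = np.log(softmax(logits)) + g
    hard = np.zeros_like(logits)
    hard[np.argmax(perturbed)] = 1.0
    soft = softmax(perturbed, temp=temp)
    return hard, soft

# Forward equivalence: argmax over Gumbel-perturbed log-probs samples
# the same categorical distribution as drawing from softmax(logits).
rng = np.random.default_rng(0)
logits = np.array([1.0, 2.0, 0.5])
counts = np.zeros(3)
for _ in range(20000):
    hard, soft = gumbel_softmax_hard(logits, temp=0.5, rng=rng)
    counts += hard
empirical = counts / counts.sum()  # ≈ softmax(logits)
```

As the temperature goes to 0 the soft relaxation approaches the one-hot sample (low bias, high gradient variance); as it grows the relaxation flattens toward uniform (high bias, low variance), which is the trade-off mentioned above.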