danijar / dreamerv2

Mastering Atari with Discrete World Models
https://danijar.com/dreamerv2
MIT License

Straight-thru gradients vs Gumbel Softmax #37

Closed (zplizzi closed this 2 years ago)

zplizzi commented 2 years ago

I'm curious whether you considered trying the Gumbel softmax as an alternative to the way you implemented straight-thru gradients in this paper/code. It seems like it might be a less-biased way of backpropagating through the operation of sampling from a categorical distribution. The "hard" variant lets you retain a purely discrete one-hot output in the forward pass, as you did here.

As I understand it, you implemented:
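
A minimal JAX sketch of that straight-through estimator, following the description in the DreamerV2 paper rather than the repository's TensorFlow code. The function name `straight_through_sample` and the use of JAX are assumptions made for illustration:

```python
import jax
import jax.numpy as jnp


def straight_through_sample(key, logits):
    """One-hot categorical sample with straight-through gradients.

    Hypothetical helper: the forward value is a one-hot sample, while the
    backward pass uses the gradient of the softmax probabilities.
    """
    probs = jax.nn.softmax(logits, axis=-1)
    # Draw a class index from Categorical(softmax(logits)).
    index = jax.random.categorical(key, logits, axis=-1)
    one_hot = jax.nn.one_hot(index, logits.shape[-1], dtype=probs.dtype)
    # `probs - stop_gradient(probs)` is zero in value, so the output equals
    # `one_hot`, but the gradient w.r.t. `logits` flows through `probs`.
    return one_hot + probs - jax.lax.stop_gradient(probs)
```

Differentiating a loss computed from this sample gives the same gradient as if the softmax probabilities had been used in place of the sample, which is where the bias of the estimator comes from.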

And the (hard version of the) Gumbel softmax is:
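
A minimal JAX sketch of the hard (straight-through) Gumbel softmax, assuming the standard formulation with i.i.d. Gumbel(0, 1) noise added to the logits. The function name `hard_gumbel_softmax` and the default temperature are hypothetical choices for illustration:

```python
import jax
import jax.numpy as jnp


def hard_gumbel_softmax(key, logits, temperature=1.0):
    """One-hot sample whose gradient comes from the relaxed Gumbel softmax.

    Hypothetical helper; `temperature` only affects the backward pass here.
    """
    # Perturb the logits with Gumbel(0, 1) noise (Gumbel-max trick).
    gumbel = jax.random.gumbel(key, logits.shape)
    perturbed = logits + gumbel
    # Relaxed, differentiable sample used for the backward pass.
    soft = jax.nn.softmax(perturbed / temperature, axis=-1)
    # Discrete sample used for the forward pass.
    index = jnp.argmax(perturbed, axis=-1)
    hard = jax.nn.one_hot(index, logits.shape[-1], dtype=soft.dtype)
    # Forward value equals `hard`; gradient flows through `soft`.
    return hard + soft - jax.lax.stop_gradient(soft)
```

Because the argmax is taken over the perturbed logits, the forward sample does not depend on the temperature; only the gradient does.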

The forward passes of the two versions are equivalent in distribution: the second is just a reparameterization of the first. By altering the temperature hyperparameter, you can trade off bias and variance.
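
One way to probe the temperature trade-off empirically is to estimate how much the gradient varies across noise draws at different temperatures. The sketch below is an illustration under assumptions, not a measurement from the paper: the helper names, the toy loss (a weighted sum of the sample), and the sample count are all hypothetical.

```python
import jax
import jax.numpy as jnp


def hard_gumbel_softmax(key, logits, temperature):
    """Hard Gumbel softmax (hypothetical helper, as sketched for this thread)."""
    perturbed = logits + jax.random.gumbel(key, logits.shape)
    soft = jax.nn.softmax(perturbed / temperature, axis=-1)
    hard = jax.nn.one_hot(jnp.argmax(perturbed, axis=-1), logits.shape[-1])
    return hard + soft - jax.lax.stop_gradient(soft)


def gradient_variance(logits, values, temperature, num_samples=2000, seed=0):
    """Summed per-coordinate variance of d(loss)/d(logits) over noise draws."""

    def loss(l, key):
        # Toy loss: the value assigned to whichever class was sampled.
        return jnp.sum(hard_gumbel_softmax(key, l, temperature) * values)

    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)
    # One gradient per noise draw; only `logits` is differentiated.
    grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(logits, keys)
    return jnp.sum(jnp.var(grads, axis=0))


logits = jnp.array([1.0, 0.0, -1.0])
values = jnp.array([1.0, 2.0, 3.0])
for temperature in (0.1, 1.0, 5.0):
    print(temperature, float(gradient_variance(logits, values, temperature)))
```

Lower temperatures make the relaxed sample closer to one-hot, which typically raises gradient variance, while higher temperatures smooth the gradient at the cost of a larger mismatch with the discrete forward sample.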

danijar commented 2 years ago

Hey, I tried Gumbel softmax back then but found it required careful temperature tuning and even temperature annealing to match the performance of straight-through gradients, which are easier and free of hparams. It's possible that Gumbel softmax can be made to work better though, I'm not sure.