Straight-Through Gumbel Softmax Fix

Credit to @IanTempleGH for pointing out this bug and the available trick.
The previous Gumbel-Softmax implementation was missing the Straight-Through Gumbel-Softmax trick, which yields a one-hot encoding that is still reparametrizable. The trick is straightforward:
Let `rout` be the reparametrized output, which is differentiable, and let `out` be the one-hot output obtained through `argmax`, which is not differentiable. We wish to obtain a differentiable one-hot output:

`(out - rout).detach() + rout`

This way, the output has the desired one-hot magnitudes, while the gradient is propagated through the `rout` term, which is what we want.
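As a minimal, self-contained sketch in PyTorch (variable names are illustrative, not the repository's actual code), the trick looks like this:

```python
import torch
import torch.nn.functional as F
from torch.distributions import RelaxedOneHotCategorical

# Hypothetical sketch of the straight-through trick; names are illustrative.
logits = torch.randn(4, 2, requires_grad=True)
dist = RelaxedOneHotCategorical(temperature=torch.tensor(1.0), logits=logits)

rout = dist.rsample()  # soft, differentiable sample, e.g. [0.9, 0.1]
# Hard one-hot via argmax; this path is not differentiable.
out = F.one_hot(rout.argmax(dim=-1), num_classes=rout.shape[-1]).float()

# Forward pass sees the hard one-hot; backward pass flows through the soft sample.
st_out = (out - rout).detach() + rout

st_out.sum().backward()
assert logits.grad is not None      # gradients reach the logits
assert torch.allclose(st_out, out)  # values are (numerically) one-hot
```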
The previous version simply used `rout` obtained directly from the distribution's `rsample`. This means that when calculating the loss for SAC, the input to the Q-function was not one-hot like `[1.0, 0.0]` (hard input), but rather the underlying distribution, for example `[0.9, 0.1]` (soft input). This is a small bug and does not impact the algorithm's performance, as we will show below.
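For reference, PyTorch's built-in `torch.nn.functional.gumbel_softmax` exposes exactly this soft-vs-hard distinction through its `hard` flag; a small sketch of the two kinds of inputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 2, requires_grad=True)

# hard=False returns the soft relaxed sample (what the buggy version fed to the Q-function);
# hard=True applies the straight-through trick and returns a one-hot (the fixed behaviour).
soft = F.gumbel_softmax(logits, tau=1.0, hard=False)
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)

assert torch.allclose(soft.sum(dim=-1), torch.ones(1))        # soft sample sums to 1
assert hard.max().item() > 0.99 and hard.min().item() < 0.01  # hard sample is (numerically) one-hot
assert hard.requires_grad                                     # and still differentiable
```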
Experiment results
Even with this small bug, discrete SAC is able to train well, and with the bug fixed, the result is comparable to the previous results. A comparison on LunarLander is given below; the only difference between the two runs is the git diff in this Pull Request.
- before fix: git SHA ae9b82cdd9dd9a64e9be767ec1237a04e8276804
- after fix: git SHA b8ef1737aa93661c8a45b6cce0dcb970a2a08262
The gradients are always propagated correctly, since both versions backpropagate only through `rout`. The only effective difference is the magnitude of the values input to the Q-function: `Q(state, action=[0.9, 0.1])` before the fix vs. `Q(state, action=[1.0, 0.0])` after it. The Q-function appears able to learn from both the soft and hard inputs and yield a similar output Q-value, so learning is not impaired.
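This identity-gradient behaviour is easy to check directly; a small sketch with illustrative values:

```python
import torch

rout = torch.tensor([0.9, 0.1], requires_grad=True)  # soft, differentiable sample
out = torch.tensor([1.0, 0.0])                       # hard argmax one-hot

st = (out - rout).detach() + rout        # straight-through output
st.backward(gradient=torch.ones_like(st))

# Upstream gradients pass through unchanged, exactly as if rout had been used directly.
assert torch.equal(rout.grad, torch.ones_like(rout))
```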
Conclusion
Overall, the results are similar, allowing for some variation (RL results tend to have high variance). We also observe no significant difference in the performance of SAC on Atari Pong, results which will be published in the future. In conclusion, this small bug does not break training or results, but it is good to have it fixed.