kengz / SLM-Lab

Modular Deep Reinforcement Learning framework in PyTorch. Companion library of the book "Foundations of Deep Reinforcement Learning".
https://slm-lab.gitbook.io/slm-lab/
MIT License

use Straight-Through Gumbel Softmax #422

Closed kengz closed 4 years ago

kengz commented 4 years ago

Straight-Through Gumbel Softmax Fix

credit to @IanTempleGH for pointing out this bug and the available trick.

The previous Gumbel-Softmax implementation was missing the Straight-Through Gumbel-Softmax trick, which allows a reparametrizable (differentiable) one-hot encoding. The trick is straightforward:

Let rout be the reparametrized output, which is differentiable, and let out be the one-hot output obtained through argmax, which is not differentiable. We wish to obtain a differentiable one-hot output:

(out - rout).detach() + rout

This way, the forward pass yields the desired one-hot magnitudes, while the gradient propagates through the trailing rout term, which is exactly what we want.
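A minimal PyTorch sketch of this trick is given below; the helper name and arguments are illustrative, not SLM-Lab's actual API:

```python
import torch
import torch.nn.functional as F

def st_gumbel_softmax(logits, tau=1.0):
    """Minimal sketch of the Straight-Through Gumbel-Softmax trick.
    The function name and signature are illustrative, not SLM-Lab's API."""
    # rout: soft, reparametrized sample; differentiable w.r.t. logits
    rout = F.gumbel_softmax(logits, tau=tau, hard=False)
    # out: hard one-hot via argmax; not differentiable
    out = F.one_hot(rout.argmax(dim=-1), num_classes=logits.shape[-1]).float()
    # forward pass returns the one-hot out; backward pass flows through rout
    return (out - rout).detach() + rout
```

Note that torch.nn.functional.gumbel_softmax(logits, hard=True) applies the same straight-through estimator internally.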

The previous version simply used rout obtained directly from the distribution's rsample. This means that when calculating the loss for SAC, the input to the Q-function was not a one-hot vector like [1.0, 0.0] (hard input), but rather the underlying distribution, for example [0.9, 0.1] (soft input). This is a small bug and does not impact the algorithm's performance, as we will show below.
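For illustration, here is a small sketch of the soft vs hard Q-function input; q_net, state, and the dimensions are hypothetical placeholders, not SLM-Lab's code:

```python
import torch
import torch.nn.functional as F

# Hypothetical Q(state, action) network for a 4-dim state and 2-dim discrete action
state = torch.randn(1, 4)
q_net = torch.nn.Linear(4 + 2, 1)

logits = torch.tensor([[2.0, -0.2]], requires_grad=True)
rout = F.gumbel_softmax(logits, tau=1.0, hard=False)                     # soft, e.g. ~[0.9, 0.1]
out = (F.one_hot(rout.argmax(-1), 2).float() - rout).detach() + rout     # hard one-hot, [1.0, 0.0]

q_soft = q_net(torch.cat([state, rout], dim=-1))  # before the fix: soft input
q_hard = q_net(torch.cat([state, out], dim=-1))   # after the fix: hard one-hot input
# Both Q values backpropagate to `logits` through `rout`; only the magnitude of
# the action input to the Q-function differs.
```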

Experiment results

Even with this small concession, discrete SAC was able to train well; with the fix, the results are comparable to the previous ones. A comparison on LunarLander is given below, the only difference being the git diff in this pull request.

| | before fix | after fix |
|---|---|---|
| trial graph | sac_lunar_t0_trial_graph_mean_returns_ma_vs_frames | sac_lunar_t0_trial_graph_mean_returns_ma_vs_frames |
| git SHA | ae9b82cdd9dd9a64e9be767ec1237a04e8276804 | b8ef1737aa93661c8a45b6cce0dcb970a2a08262 |

Gradients are always propagated correctly, since only rout carries the gradient in both cases. The only effective difference is the magnitude of the values input to the Q-function, e.g. Q(state, action=[0.9, 0.1]) before the fix vs Q(state, action=[1.0, 0.0]) after. The Q-function appears able to learn from both soft and hard inputs and yield similar Q-values, so learning is not impaired.

Conclusion

Overall, the results are similar, allowing for some variation (RL results tend to have high variance). We also observe no significant difference in SAC's performance on Atari Pong, which will be published in the future. In conclusion, this small bug does not break the results, but it is good to have it fixed.