quantumiracle / Popular-RL-Algorithms

PyTorch implementation of Soft Actor-Critic (SAC), Twin Delayed DDPG (TD3), Actor-Critic (AC/A2C), Proximal Policy Optimization (PPO), QT-Opt, PointNet.
Apache License 2.0

Variable length episodes #20

Closed · alanmackey closed this issue 3 years ago

alanmackey commented 3 years ago

I tried to use your code on a custom environment that can have variable-length episodes. Maybe I have set it up wrong, but I can't figure out how it can work. The replay buffer is filled with complete episodes (not just single state transitions), but `update` in `td3_lstm` samples those episodes and uses `torch.FloatTensor(state).to(device)` to move them to the GPU. This won't work, because the batch can contain episodes of varying length and PyTorch won't build a tensor from a ragged batch. It would possibly work with a batch size of 1. A toy reproduction of the shape problem is sketched below.
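For concreteness, a minimal sketch of the failure (the episodes below are made up, not the repo's actual buffer contents):

```python
import torch

# Two toy "episodes" of different lengths, each a list of 2-dim states.
episodes = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # length 3
    [[0.7, 0.8]],                          # length 1
]

# Converting the ragged batch directly raises an error,
# because a tensor must be rectangular.
try:
    torch.FloatTensor(episodes)
except ValueError as e:
    print("Cannot build a rectangular tensor:", e)
```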

quantumiracle commented 3 years ago

Hi, I think what you said is correct. For the LSTM version, the whole episode of transitions is sent to train the policy as a single sample. Variable-length episodes can be handled in two ways:

1. Pad episodes with zero transitions up to the maximum episode length in the batch; however, the zero padding may not make sense in some cases.
2. Change the update manner of the LSTM policy to take one transition at a time, but keep the gradients of the hidden states (do not detach them), so that gradients can still flow along the episode.

A sketch of the padding approach is shown after this list. Contributions of other implementations for achieving this are also welcome.
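For option 1, a minimal sketch of padding a sampled batch of episodes before converting it to a single tensor, assuming each sampled episode is already a float tensor of shape `(episode_len, state_dim)` (the variable names and shapes are illustrative, not the repo's buffer API):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical sampled batch of episodes with different lengths (state dim 4).
episodes = [
    torch.randn(7, 4),   # episode of length 7
    torch.randn(12, 4),  # episode of length 12
    torch.randn(3, 4),   # episode of length 3
]
lengths = torch.tensor([ep.shape[0] for ep in episodes])

# Zero-pad to the longest episode so the batch can be stacked.
# Resulting shape: (batch, max_len, state_dim) with batch_first=True.
padded = pad_sequence(episodes, batch_first=True, padding_value=0.0)

# Boolean mask marking the real (non-padded) steps, so padded steps
# can be excluded from the loss instead of contributing spurious targets.
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]

print(padded.shape)     # torch.Size([3, 12, 4])
print(mask.sum(dim=1))  # tensor([ 7, 12,  3])
```

The mask can be used to zero out loss terms at padded steps; alternatively, `torch.nn.utils.rnn.pack_padded_sequence` can feed the padded batch to the LSTM so the padded steps are skipped by the recurrence entirely.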