hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Adding Additional Observations and Actions to Buffer per TimeStep #834

Open lukepolson opened 4 years ago

lukepolson commented 4 years ago

I'm currently training in a SnakeGame environment where the observation space is square (n x n) and at any point the snake can move up, down, right, or left. For each step, I can generate additional (state, action) pairs by rotating the board by 90, 180, and 270 degrees, and also by mirroring it horizontally and vertically (choosing the corresponding action, of course). For example, if the snake moved up in the regular state, an equivalent valid (state, action) pair is the board mirrored over the horizontal line through its center with the snake moving down. (For anyone interested in the symmetries of the square, see https://proofwiki.org/wiki/Definition:Symmetry_Group_of_Square). Note that there are 8 unique combinations of rotating and flipping a square, so this generates 8 (state, action) pairs per time step (and 10 hours of training instead of 80 would be quite a nice reduction in time).
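To make the idea concrete, here is a minimal sketch of generating the 8 pairs with NumPy. The function name `symmetric_pairs`, the 2D array board layout (row 0 at the top), and the action encoding (0=up, 1=right, 2=down, 3=left) are assumptions for illustration and would need to be adapted to the actual environment.

```python
import numpy as np

# Hypothetical action encoding, in clockwise order: 0=up, 1=right, 2=down, 3=left.
N_ACTIONS = 4


def symmetric_pairs(board, action):
    """Return the 8 (board, action) pairs given by the symmetries of the square."""
    pairs = []
    for flip in (False, True):
        flipped = np.fliplr(board) if flip else board
        # A left-right mirror swaps right and left and keeps up and down.
        base_action = (-int(action)) % N_ACTIONS if flip else int(action)
        for k in range(4):
            # np.rot90 rotates counterclockwise, so each quarter turn maps
            # up -> left, left -> down, down -> right, right -> up.
            pairs.append((np.rot90(flipped, k).copy(), (base_action - k) % N_ACTIONS))
    return pairs
```

The first returned pair is the original one (no flip, no rotation), so only the remaining seven are new data.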

What's the best way to add these additional (state, action) pairs to the buffer? I was considering using some sort of modified VecEnv, but perhaps someone can point me to a more straightforward approach.

Miffyli commented 4 years ago

I do not think there is any viable option other than modifying the stable-baselines code a bit. You could modify e.g. this line in DQN to do all the augmentations and put them into the replay buffer. That should be enough, albeit a bit dirty.

araffin commented 4 years ago

Hello, it sounds like a callback would be the right solution (once #787 is merged): since you have access to self.model, you can call self.model.replay_buffer.add() inside the callback. In fact, you don't even need #787, as you can retrieve the last entry in the replay buffer (a little bit dirty, but it should work). Otherwise, you can create a replay wrapper, as is done for HER, and pass it to the learn() method (maybe the cleanest way).
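A rough sketch of the replay-wrapper option might look like the following. It is not taken from stable-baselines: the class name `SymmetryReplayWrapper`, the helper `transform`, the action encoding, and the 2D board layout are all hypothetical, and the sketch assumes the wrapped buffer exposes `add(obs_t, action, reward, obs_tp1, done)`, `sample()`, and `can_sample()`, which should be checked against the installed version.

```python
import numpy as np

N_ACTIONS = 4  # hypothetical encoding: 0=up, 1=right, 2=down, 3=left


def transform(board, action, flip, k):
    """Apply one symmetry of the square to a 2D board and, if given, an action."""
    if flip:
        board = np.fliplr(board)
        if action is not None:
            action = (-int(action)) % N_ACTIONS  # mirror swaps left and right
    board = np.rot90(board, k).copy()  # k counterclockwise quarter turns
    if action is not None:
        action = (int(action) - k) % N_ACTIONS
    return board, action


class SymmetryReplayWrapper:
    """Hypothetical wrapper that stores all 8 symmetric copies of a transition."""

    def __init__(self, replay_buffer):
        # Assumed to provide add(obs_t, action, reward, obs_tp1, done).
        self.replay_buffer = replay_buffer

    def add(self, obs_t, action, reward, obs_tp1, done):
        for flip in (False, True):
            for k in range(4):
                # Observation and next observation get the same symmetry.
                obs, act = transform(obs_t, action, flip, k)
                next_obs, _ = transform(obs_tp1, None, flip, k)
                self.replay_buffer.add(obs, act, reward, next_obs, done)

    def sample(self, *args, **kwargs):
        return self.replay_buffer.sample(*args, **kwargs)

    def can_sample(self, n_samples):
        return self.replay_buffer.can_sample(n_samples)

    def __len__(self):
        return len(self.replay_buffer)
```

If learn() accepts the wrapper the same way it does for HER, the class itself would be passed in so that the model can wrap its own buffer. Note that the buffer fills 8 times faster this way, so the buffer size may need to be increased accordingly.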

lukepolson commented 4 years ago

Cheers. Once #787 is merged I will give it a try and close this issue. For now I'll attempt to use the replay wrapper.