RemiG3 opened this issue 1 year ago.
@araffin has already mentioned that he created it and will make it public (https://github.com/DLR-RM/stable-baselines3/issues/177#issuecomment-703268927).
I meant the SB3 contrib repo.
For GTrXL, are you willing to contribute that algorithm? If you decide to, please read the contributing guide carefully.
Sorry for the misunderstanding.
I'm not sure yet; I will first try to implement it for my experiments.
Also related: https://github.com/maohangyu/TIT_open_source
@RemiG3 hey, have you started implementing it? Maybe I can lend a hand with it :)
Yes, I have implemented it, but I haven't tested it properly yet. I'm currently having some trouble with my custom environment that I'm trying to resolve.
@araffin is it possible to create a new branch for this feature (to share the code)? If it is possible, I'll clean up the code and push it to this new branch soon.
yes, that's what a fork and pull request are meant for
I came across Transformers-RL, which is quite modular and easy to tune; the only drawback is that it has only been implemented for a Gaussian policy.
Hey, I finally opened PR #176 to share the code.
It should work, but I'm not sure about its performance. It would be nice if someone could compare it with other methods (RLlib's attention net, for example). I won't have time in the next few days.
RemiG3, thank you for adding the attention net to contrib. What should the input shape look like, for example if I want to use the CartPole environment? Thanks again.
Thank you, @eric000888, for reporting this (feel free to provide the code you tested as you did in your first edits).
I have updated the branch to fix a bug in the minibatch dimensions.
But I still get an exception when batch_size = 1 or n_steps = 1, and I found the same exception with RecurrentPPO. So it should now work for batch_size > 1 and n_steps > 1 (as with RecurrentPPO).
EDIT: I also added assertions for these cases, as in the original PPO.
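A minimal sketch of guards for the two failing cases above. The helper names check_rollout_sizes and normalize are hypothetical and not the PR's code. For context, SB3's own PPO asserts batch_size > 1 when advantages are normalized, because per-minibatch normalization divides by a standard deviation that is undefined for a single sample (torch.std returns NaN, Python's statistics.stdev raises). Whether that is the cause of the exception reported here is not established in this thread.

```python
import statistics
from typing import List


def check_rollout_sizes(batch_size: int, n_steps: int) -> None:
    # Hypothetical helper: reject the sizes reported above as failing.
    assert batch_size > 1, f"`batch_size` must be greater than 1, got {batch_size}"
    assert n_steps > 1, f"`n_steps` must be greater than 1, got {n_steps}"


def normalize(advantages: List[float]) -> List[float]:
    # Per-minibatch advantage normalization, as PPO-style code commonly does it.
    # statistics.stdev (sample standard deviation) needs at least two values,
    # which is why a minibatch of size 1 cannot be normalized this way.
    mean = statistics.mean(advantages)
    std = statistics.stdev(advantages)
    return [(a - mean) / (std + 1e-8) for a in advantages]
```

With the CartPole settings posted below (batch_size=12, n_steps=240) both checks pass.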
RemiG3, sorry for the late response. Here is the code from my first post:
import gym
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO  # from the PR #176 branch
from sb3_contrib.ppo_attention.policies import MlpAttnPolicy  # policy is selected by name below

VE = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = AttentionPPO(
    "MlpAttnPolicy",
    VE,
    n_steps=240,
    learning_rate=0.0003,
    verbose=1,
    batch_size=12,
    ent_coef=0.03,
    vf_coef=0.5,
    seed=1,
    n_epochs=10,
    max_grad_norm=1,
    gae_lambda=0.95,
    gamma=0.99,
    device="cpu",
    policy_kwargs=dict(
        net_arch=dict(pi=[64, 32], vf=[64, 32]),
    ),
)
First I create a vectorized environment, then set up the model as I would for the LSTM RecurrentPPO, then run model.learn(). I traced the code and found that the internal calculations return valid numbers at the beginning, but after a few loops they start returning NaN and training stops. I saw that some other implementations use stacked frames with a sliding window as the input format, so I'm a bit confused about what the correct input format should be. From your code, though, I think the input should be just one record at a time, with no need to stack records.
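For comparison with the stacked-frame input format mentioned above, here is a minimal sketch of a per-environment sliding window. SlidingWindow is a hypothetical helper, not part of SB3 or the PR; SB3 itself ships VecFrameStack, which instead concatenates the last n_stack observations along the last axis into one flat vector. With CartPole's 4-dimensional observation and k=3, each stacked input here has shape (3, 4).

```python
from collections import deque

import numpy as np


class SlidingWindow:
    """Hypothetical helper: keep the last `k` observations of one environment.

    At the start of an episode the window is filled by repeating the first
    observation, so the stacked output always has shape (k, obs_dim).
    """

    def __init__(self, k: int):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, obs: np.ndarray) -> np.ndarray:
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(obs)
        return self.stacked()

    def step(self, obs: np.ndarray) -> np.ndarray:
        self.frames.append(obs)  # the deque drops the oldest frame automatically
        return self.stacked()

    def stacked(self) -> np.ndarray:
        # Oldest observation first, newest last.
        return np.stack(list(self.frames), axis=0)
```

A memory-based model such as GTrXL instead caches past hidden states, so each step can feed a single observation rather than a window.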
Following the code, I saw that you concatenate the input tensor with the memory tensor. However, SB3 passes one record at a time at first, and after the first full loop the input becomes a batch of records; that throws an error because the memory is still a single tensor instead of a batch.
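The batch mismatch described above can be illustrated with NumPy, assuming Transformer-XL-style memory stored with a leading batch dimension (one memory per environment or sampled sequence). The helpers extend_with_memory, update_memory and reset_memory are hypothetical and not taken from the PR; a real GTrXL keeps one memory tensor per layer and stops gradients through it.

```python
import numpy as np


def extend_with_memory(x: np.ndarray, memory: np.ndarray) -> np.ndarray:
    """Prepend cached memory to the current inputs along the time axis.

    x:      (batch, seq_len, d_model)  current inputs
    memory: (batch, mem_len, d_model)  one memory per sequence in the batch
    """
    if memory.shape[0] != x.shape[0]:
        raise ValueError(
            f"memory batch {memory.shape[0]} != input batch {x.shape[0]}; "
            "keep one memory per environment/sequence"
        )
    return np.concatenate([memory, x], axis=1)


def update_memory(memory: np.ndarray, hidden: np.ndarray, mem_len: int) -> np.ndarray:
    # Keep only the most recent `mem_len` hidden states of each sequence.
    # (In a real model these would also be detached from the autograd graph.)
    return np.concatenate([memory, hidden], axis=1)[:, -mem_len:]


def reset_memory(memory: np.ndarray, episode_starts: np.ndarray) -> np.ndarray:
    # Zero the memory of environments whose episode has just (re)started.
    memory = memory.copy()
    memory[episode_starts.astype(bool)] = 0.0
    return memory
```

During the PPO update, each minibatch sequence would need the memory stored for it during the rollout, rather than the single live memory used while collecting data.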
Thank you for the update, I will try it this weekend.
Another question: if you just use GTrXL as a feature extractor in the PPO model, would that give the same results? The LSTM RecurrentPPO has a flag to use the LSTM layer or not, similar to a feature extractor layer.
Another thing: GTrXL demands more computation, and PPO is like aiming at a moving target, so I found training a GTrXL PPO to be a daunting task, especially with multiple layers. But if you can compute the gradient over the whole trajectory, you may speed up learning: that means collecting all actions/observations first and then doing a single pass of backpropagation.
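One way to prepare data for the whole-trajectory update suggested above is to split a single-environment rollout at episode boundaries and pad the pieces into one batch with a mask, so a single forward and backward pass covers every trajectory. This is a sketch under the assumption that the buffer records an episode_starts flag per step (SB3's rollout buffers do store episode starts); split_into_episodes and pad_segments are hypothetical helpers.

```python
from typing import List, Tuple

import numpy as np


def split_into_episodes(episode_starts: np.ndarray) -> List[np.ndarray]:
    """Return index arrays of contiguous episode segments in a single-env rollout.

    episode_starts[t] == 1 means step t is the first step of a new episode.
    The first segment may be the tail of an episode begun in a previous rollout.
    """
    starts = np.flatnonzero(episode_starts)
    boundaries = np.concatenate(([0], starts, [len(episode_starts)]))
    boundaries = np.unique(boundaries)  # drops a duplicate 0 if step 0 is a start
    return [np.arange(a, b) for a, b in zip(boundaries[:-1], boundaries[1:])]


def pad_segments(values: np.ndarray, segments: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Stack segments into a (n_segments, max_len, ...) batch plus a validity mask."""
    max_len = max(len(s) for s in segments)
    padded = np.zeros((len(segments), max_len) + values.shape[1:], dtype=values.dtype)
    mask = np.zeros((len(segments), max_len), dtype=bool)
    for i, seg in enumerate(segments):
        padded[i, : len(seg)] = values[seg]
        mask[i, : len(seg)] = True
    return padded, mask
```

Padded positions should be excluded from the loss via the mask. Whether this actually speeds up learning compared with PPO's usual minibatch epochs is not established in this thread.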
🚀 Feature Request
This feature request is a duplicate of one in stable-baselines3 (see https://github.com/DLR-RM/stable-baselines3/issues/177).
The idea is to add the GTrXL model to the contrib repo, from the paper Stabilizing Transformers for Reinforcement Learning, as done in RLlib: https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py.
@araffin has already mentioned that he created it and will make it public (https://github.com/DLR-RM/stable-baselines3/issues/177#issuecomment-703268927). I wonder if this is still relevant?
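For reference, the central change GTrXL makes to a Transformer-XL block is to replace each residual connection x + y with a GRU-type gate g(x, y). Below is a minimal NumPy sketch of that gate as described in Stabilizing Transformers for Reinforcement Learning. It is not the RLlib or PR code: the weights are random rather than trained, and the default gate_bias=2.0 is an assumption (a positive bias keeps the gate close to the identity map at initialization).

```python
import numpy as np


def sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))


class GRUGate:
    """GRU-type gating g(x, y), used in GTrXL instead of the residual x + y.

    x: residual-stream input to the sub-layer, y: (ReLU'd) sub-layer output.
    """

    def __init__(self, d_model: int, gate_bias: float = 2.0, seed: int = 0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)
        # Six weight matrices: W_* act on y, U_* act on x (row-vector convention).
        self.W_r, self.U_r, self.W_z, self.U_z, self.W_g, self.U_g = (
            rng.normal(0.0, scale, (d_model, d_model)) for _ in range(6)
        )
        # Subtracted inside the update gate: pushes z towards 0, so g(x, y) ~ x.
        self.b_g = gate_bias

    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        r = sigmoid(y @ self.W_r + x @ self.U_r)            # reset gate
        z = sigmoid(y @ self.W_z + x @ self.U_z - self.b_g)  # update gate
        h = np.tanh(y @ self.W_g + (r * x) @ self.U_g)       # candidate state
        return (1.0 - z) * x + z * h
```

In the full GTrXL layer, y is the ReLU of the attention or MLP output computed on layer-normalized inputs (the paper's identity-map reordering), and one such gate follows each of the two sub-layers.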