Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

PPO attention net (GTrXLNet) #176

Open RemiG3 opened 1 year ago

RemiG3 commented 1 year ago

Description

Add a PPO attention network (GTrXLNet, from the paper "Stabilizing Transformers for Reinforcement Learning"). Comparisons still need to be made (against the RLlib implementation, for example).
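
For context, the core of GTrXLNet is a GRU-style gating layer that replaces the residual connections of a standard transformer block. Here is a minimal PyTorch sketch of that gate, following the paper's equations (names and defaults are illustrative, not the API of this PR):

import torch
import torch.nn as nn

class GRUGate(nn.Module):
    """Gated residual: replaces `x + sublayer(x)` with a GRU-style gate."""

    def __init__(self, dim: int, bias_init: float = 2.0):
        super().__init__()
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.U_r = nn.Linear(dim, dim, bias=False)
        self.W_z = nn.Linear(dim, dim, bias=False)
        self.U_z = nn.Linear(dim, dim, bias=False)
        self.W_g = nn.Linear(dim, dim, bias=False)
        self.U_g = nn.Linear(dim, dim, bias=False)
        # A positive bias b_g keeps the gate close to the identity map at
        # initialization, which is what stabilizes early training in the paper.
        self.bias = nn.Parameter(torch.full((dim,), bias_init))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: residual stream input, y: sublayer output (attention or MLP)
        r = torch.sigmoid(self.W_r(y) + self.U_r(x))
        z = torch.sigmoid(self.W_z(y) + self.U_z(x) - self.bias)
        h = torch.tanh(self.W_g(y) + self.U_g(r * x))
        return (1.0 - z) * x + z * h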

closes https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/165

Note: I have cleaned up most of the code, but it's still under development.
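
On the RLlib comparison: RLlib's PPO can already wrap its default model in a GTrXL via the `use_attention` model flag, so it could serve as the reference implementation. A rough baseline run might look like this (the config keys are from RLlib's model catalog; the exact config API varies across Ray versions, so treat this as a sketch):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(
        model={
            # Wrap the default model in a GTrXL attention network
            "use_attention": True,
            "attention_num_transformer_units": 1,
            "attention_dim": 64,
            "attention_num_heads": 2,
            "attention_memory_inference": 50,  # memory length at inference
            "attention_memory_training": 50,   # memory length during training
        }
    )
)
algo = config.build()
results = algo.train()  # one training iteration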


rrfaria commented 1 year ago

Hey @RemiG3

I hope everything is going well. 👋 I've been following the development of the attention PPO feature, and I'm really excited about the progress being made!

Could you provide an update on the current status of this feature? I'd love to know where it stands and if there's anything new to be excited about since the last time you commented.

I came across this example you shared:

import gym  # the snippet targets the classic `gym` API; newer SB3 versions use `gymnasium`
from stable_baselines3.common.vec_env import DummyVecEnv

from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO
from sb3_contrib.ppo_attention.policies import MlpAttnPolicy

VE = DummyVecEnv([lambda: gym.make("CartPole-v1")])

model = AttentionPPO(
    MlpAttnPolicy,  # pass the imported class; the string "MlpAttnPolicy" only works if the alias is registered
    VE,
    n_steps=240,  # rollout length per environment
    learning_rate=0.0003,
    verbose=1,
    batch_size=12,
    ent_coef=0.03,  # entropy bonus coefficient
    vf_coef=0.5,  # value-function loss coefficient
    seed=1,
    n_epochs=10,
    max_grad_norm=1,
    gae_lambda=0.95,
    gamma=0.99,
    device='cpu',
    policy_kwargs=dict(
        net_arch=dict(pi=[64, 32], vf=[64, 32]),  # separate MLP heads for policy and value
    ),
)

Does it still work like this?
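
If it does, a short smoke test would help confirm it. Assuming the branch keeps the standard SB3 learn()/predict() interface and threads the attention memory through predict() the way RecurrentPPO threads its LSTM states (an assumption on my part, not confirmed against this PR), something like:

import numpy as np

model.learn(total_timesteps=10_000)

obs = VE.reset()
attn_state = None  # hypothetical: attention memory, analogous to RecurrentPPO's lstm_states
episode_starts = np.ones((VE.num_envs,), dtype=bool)
for _ in range(500):
    action, attn_state = model.predict(obs, state=attn_state, episode_start=episode_starts)
    obs, rewards, dones, infos = VE.step(action)
    episode_starts = dones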

If there's an example available that shows how this feature is implemented, or if it's already possible to test a prototype, I'd be incredibly grateful for any information.

Thank you so much for the hard work you're putting into this.

Many thanks, and I'm eagerly looking forward to your response. 🚀

LeZheng-x commented 4 months ago

In iGibson, I compared three algorithms: PPO, Recurrent_PPO, and Attention_PPO. Unfortunately, even when I change the network parameters of GTrXL, it performs poorly and requires more training time.