ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Bug] [rllib] Attention and FrameStackingModel work poorly #20827

Closed stefanbschneider closed 2 years ago

stefanbschneider commented 2 years ago

Search before asking

Ray Component

RLlib

What happened + What you expected to happen

I've been experimenting with Ray RLlib's StatelessCartPole environment, where some observations are hidden, and with different options for handling these partial observations. I noticed two problems; see the reproduction scripts below for details.

There's also a discussion on discourse: https://discuss.ray.io/t/lstm-and-attention-on-stateless-cartpole/4293/3

Here's my blog post with a notebook and more details: https://stefanbschneider.github.io/blog/rl-partial-observability

Versions / Dependencies

Python 3.8 on Windows 10
Gym 0.21.0 (had errors with frame stacking on lower versions!)
Ray 2.0.0.dev0: latest wheels from Dec 1, 2021; commit on master 0467bc9

Reproduction script

Details for reproducibility: I'm experimenting with PPO, the default config, 10 training iterations, and 3 CPUs.

Default PPO on StatelessCartPole: Reward 51

By default (no frame stacking, no attention), PPO does not work so well on StatelessCartPole - as expected. For me, it leads to a reward of 51 after 10 train iterations; code below:

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG
config["env"] = "StatelessCartPole"

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)

PPO + Attention on StatelessCartPole: Reward 39 -> Bug?

Now, the same thing with attention enabled and otherwise default params. I expected a much higher reward than without attention, but it's almost the same - even slightly worse! Why?

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG
config["env"] = "StatelessCartPole"
config["model"] = {
    # attention
    "_use_default_native_models": True,
    "use_attention": True,
}

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)

Even adding extra model params from the attention net example doesn't help:

    # extra model config from the attention_net.py example
    # (these go inside config["model"] above):
    "max_seq_len": 10,
    "attention_num_transformer_units": 1,
    "attention_dim": 32,
    "attention_memory_inference": 10,
    "attention_memory_training": 10,
    "attention_num_heads": 1,
    "attention_head_dim": 32,
    "attention_position_wise_mlp_dim": 32,

Stacked Frames in Env: Reward 202

Now, stacking frames in the environment with gym.wrappers.FrameStack works really well - without attention. For me, the reward is at 202 after 10 train iters (with 10 stacked frames). Code is identical except for creating a StackedStatelessCartPole env with stacked frames:

from gym.wrappers import FrameStack

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry

NUM_FRAMES = 10

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())
registry.register_env("StackedStatelessCartPole",
                      lambda _: FrameStack(StatelessCartPole(), NUM_FRAMES))

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG
config["env"] = "StackedStatelessCartPole"

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)

Stacked Frames in Model: Reward 105 --> Bug?

Now, the same thing - stacking 10 frames - but within the model rather than the environment. Surprisingly, this leads to a reward that's only about half as high (~105 after 10 iterations). Why? Isn't this supposed to do the same thing?

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.trajectory_view_utilizing_models import FrameStackingCartPoleModel

NUM_FRAMES = 10

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ModelCatalog.register_custom_model("stacking_model", FrameStackingCartPoleModel)

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG
config["env"] = "StatelessCartPole"
config["model"] = {
     "custom_model": "stacking_model",
     "custom_model_config": {
         "num_frames": NUM_FRAMES,
     },
}

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)
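
For context, the FrameStackingCartPoleModel from the examples uses RLlib's trajectory view API to request the last num_frames observations as extra model input, which is why stacking in the model should in principle be equivalent to stacking in the env. A simplified, hypothetical sketch of that pattern (not the exact example model) looks roughly like this:

import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement


class SimpleFrameStackModel(TorchModelV2, nn.Module):
    """Hypothetical minimal model: concatenate the last N observations."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, num_frames=10):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.num_frames = num_frames
        in_size = int(np.prod(obs_space.shape)) * num_frames
        self.hidden = nn.Sequential(nn.Linear(in_size, 64), nn.ReLU())
        self.logits = nn.Linear(64, num_outputs)
        self.value_branch = nn.Linear(64, 1)
        self._features = None
        # Trajectory view API: ask RLlib to pass the last `num_frames`
        # observations (shift "-(N-1):0") to the model as "prev_n_obs".
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(num_frames - 1),
            space=obs_space)

    def forward(self, input_dict, state, seq_lens):
        # [B, num_frames, obs_dim] -> [B, num_frames * obs_dim]
        obs_stack = input_dict["prev_n_obs"].float()
        obs_stack = obs_stack.reshape(obs_stack.shape[0], -1)
        self._features = self.hidden(obs_stack)
        return self.logits(self._features), state

    def value_function(self):
        return self.value_branch(self._features).squeeze(1)

Such a model would be registered and configured exactly like the stacking_model above (ModelCatalog.register_custom_model plus custom_model_config).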

TL;DR: Why does attention not help? Why does frame stacking within the model lead to much worse results than within the env?

Anything else

I think these are really two separate issues (or misunderstandings/misconfigurations on my side). Maybe it's just an issue with the default configuration in this scenario...

I'm also willing to submit a PR, but am not sure where to start looking. I'll post any new insights in the comments.

Are you willing to submit a PR?

stefanbschneider commented 2 years ago

Or should I be stacking frames inside the model differently as mentioned here? https://github.com/ray-project/ray/issues/10882

mickelliu commented 2 years ago

Hi @stefanbschneider, any update on this? Thanks.

stefanbschneider commented 2 years ago

@mickelliu Unfortunately no. I have not heard back from anyone and haven't had time to look into this myself again. I hope to find some time and work on this again soon though.

I'll let you know if I have any updates.

mickelliu commented 2 years ago

I ran into a similar situation where frame stacking with the trajectory view API degrades the performance of PPO models, but I haven't tried whether Gym's FrameStack wrapper helps. LSTM and attention also perform worse (without frame stacking) in our custom environment.

stefanbschneider commented 2 years ago

Hm, I wanted to test this with the latest Ray version, but now I can't even init Ray... https://github.com/ray-project/ray/issues/19835

sven1977 commented 2 years ago

Hey @stefanbschneider and @mickelliu, I'll take a look at this today with high priority. ...

mickelliu commented 2 years ago

Thanks @sven1977. It could be because attention nets and LSTMs are notoriously harder to train. In our custom CV environment in particular, we found that the LSTM's stagnant performance was mainly attributable to the shared layers between the value and actor networks; separating the two branches helped the LSTM models converge... As for GTrXL, the original paper used 12 transformer layers and trained for billions of timesteps to make it work in their environments...
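
In RLlib config terms, separating the value branch roughly corresponds to setting "vf_share_layers": False in the model config. A minimal sketch (use_lstm and lstm_cell_size are standard RLlib model options; "CartPole-v0" is just a placeholder env):

import ray
from ray import tune
from ray.rllib.agents import ppo

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "CartPole-v0"
config["model"] = {
    "use_lstm": True,           # wrap the default model in an LSTM
    "lstm_cell_size": 64,
    "vf_share_layers": False,   # separate actor and value branches (no shared layers)
}

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)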

sven1977 commented 2 years ago

On the attention-net experiment:

The following script learns pretty well using attention (reaching reward ~190.0 after 300k timesteps):

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ray.init(num_cpus=9)

config = ppo.DEFAULT_CONFIG
config["env"] = "StatelessCartPole"
config["observation_filter"] = "MeanStdFilter"
config["num_sgd_iter"] = 6
config["vf_loss_coeff"] = 0.0001    # <- IMPORTANT to tune this right such that vf_loss and policy_loss are roughly of the same magnitude
config["num_workers"] = 8
config["model"] = {
    "_use_default_native_models": True,
    "fcnet_hiddens": [32],
    "fcnet_activation": "linear",
    "vf_share_layers": True,
    "use_attention": True,
    "max_seq_len": 20,
    "attention_num_transformer_units": 1,
    "attention_dim": 32,
    "attention_memory_inference": 20,
    "attention_memory_training": 20,
    "attention_num_heads": 1,
    "attention_head_dim": 32,
    "attention_position_wise_mlp_dim": 32,
}
stop = {
    "training_iteration": 100,
}

results = tune.run("PPO", config=config, stop=stop)
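
To follow the vf_loss_coeff hint in the comment above, one way is to compare the magnitudes of the reported policy and value losses, e.g. in TensorBoard or roughly like this (a sketch; the exact key nesting under info/learner may differ between Ray versions):

# Rough sketch: compare policy_loss and vf_loss after training to
# sanity-check that vf_loss_coeff puts them on a similar scale.
learner_stats = (results.trials[0].last_result
                 ["info"]["learner"]["default_policy"]["learner_stats"])
print("policy_loss:", learner_stats["policy_loss"])
print("vf_loss:", learner_stats["vf_loss"])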

sven1977 commented 2 years ago

On the framestacking experiment:

This script works perfectly well (reaching reward ~200.0 after 80k timesteps), using frame stacking via the trajectory view API inside a custom model:

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.trajectory_view_utilizing_models import FrameStackingCartPoleModel

NUM_FRAMES = 10

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ModelCatalog.register_custom_model("stacking_model", FrameStackingCartPoleModel)

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG
config["env"] = "StatelessCartPole"
config["vf_loss_coeff"] = 0.0001  # <- IMPORTANT to tune this right such that vf_loss and policy_loss are roughly of the same magnitude
config["model"] = {
     "custom_model": "stacking_model",
     "custom_model_config": {
         "num_frames": NUM_FRAMES,
     },
}

stop = {"training_iteration": 100}
results = tune.run("PPO", config=config, stop=stop)

sven1977 commented 2 years ago

I'm closing this issue. Feel free to re-open should you think that the above answers are not sufficient :)

@stefanbschneider @mickelliu