ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Bug] [rllib] RNN sequencing is incorrect #19976

Closed: smorad closed this issue 2 years ago

smorad commented 2 years ago

Search before asking

Ray Component

RLlib

What happened + What you expected to happen

I would expect that, given two sequences A and B packed as [A, A, A, B, B] with seq_lens=[3, 2] and obs.shape = [5, 1], padding would produce [A, A, A, B, B, *] with seq_lens=[3, 2] and obs.shape = [2, 3, 1].
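
For concreteness, the expected padding looks like this (a minimal standalone sketch, not RLlib code):

import torch

# Pack [A, A, A, B, B] with seq_lens=[3, 2] into a [2, 3, 1] tensor,
# zero-padding only the tail of the shorter sequence.
flat = torch.tensor([[1.0], [1.0], [1.0], [2.0], [2.0]])  # [5, 1]; A=1.0, B=2.0
seq_lens = torch.tensor([3, 2])
T = int(seq_lens.max())                                   # 3
padded = torch.zeros(len(seq_lens), T, flat.shape[-1])    # [2, 3, 1]
offset = 0
for i, n in enumerate(seq_lens):
    padded[i, :int(n)] = flat[offset:offset + int(n)]
    offset += int(n)
# padded[1, 2] is the single zero pad ("*" above); everything else is real data.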

This does not appear to be the case. For some reason, RLlib zero-pads obs to something other than seq_lens.max(). Even more worrisome, calling torch.nonzero() on the input_dict reveals that the observations are front-padded with zeros. For example, printing input_dict['obs'].reshape(B, T, -1) == 0 gives:

(PPO pid=74137)         [[ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True]],

The zero-padding is clearly wrong: the first five observations have been zero-padded, and the real observations are offset by five.

Versions / Dependencies

Linux, Ray 1.7.0

Reproduction script

Feel free to play with the USE_CORRECT_SHAPE flag.

import torch
import numpy as np
import gym
from typing import Union, Dict, List, Tuple, Any
import ray
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.tune import register_env
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

# Pad to the correct size and crash
# or follow the rnn_sequencing code and don't crash
USE_CORRECT_SHAPE = False

class TestRNN(TorchModelV2, torch.nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        torch.nn.Module.__init__(self)
        self.num_outputs = num_outputs
        self.input_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_space = action_space
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        self.cur_val = None

        self.policy = torch.nn.Linear(self.input_dim, self.act_dim)
        self.vf = torch.nn.Linear(self.input_dim, 1)

    def get_initial_state(self):
        return [torch.zeros(0)]

    def value_function(self):
        assert self.cur_val is not None, "must call forward() first"
        return self.cur_val

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:

        flat = input_dict["obs_flat"]

        if USE_CORRECT_SHAPE:
            max_seq_len = seq_lens.max()
        else:
            # max_seq_len here is copied from rllib RNN code
            # see https://github.com/ray-project/ray/blob/2d24ef0d3234867ac329b10ae3a11b9b7119d17b/rllib/models/torch/recurrent_net.py#L75
            # but it doesn't make sense...
            # it should be max_seq_len = seq_lens.max()
            max_seq_len = flat.shape[0] // seq_lens.shape[0]

        padded = add_time_dimension(
            flat,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=False
        )

        B = padded.shape[0]
        T = padded.shape[1]

        # If this fails, then we have "extra" padding in the RNN
        # We shouldn't need to pad the time dimension more than the longest
        # sequence
        if seq_lens.max() != T:
            print(f'seq_lens.max() is {seq_lens.max()} but input temporal dim is {T}')
            print(flat.reshape(B, T, -1) == 0)
            raise Exception('seq_len mismatch')

        flattened = padded.reshape(-1, padded.shape[-1])
        logits = self.policy(flattened)
        self.cur_val = self.vf(flattened).squeeze(1)
        state = state

        return logits, state

register_env(StatelessCartPole.__name__, StatelessCartPole)
MAX_SEQ_LEN = 200
CFG = {
    "env_config": {},
    "framework": "torch",
    "model": {
        "custom_model": TestRNN,
        "max_seq_len": MAX_SEQ_LEN,
    },
    "num_workers": 0,
    "num_gpus": 0,
    "env": StatelessCartPole,
    "horizon": MAX_SEQ_LEN,
}
ray.init(object_store_memory=3e10)
analysis = ray.tune.run(
    PPOTrainer,
    config=CFG,
)

Anything else

Every train step

Are you willing to submit a PR?

smorad commented 2 years ago

I just checked and this is not an issue with A2C. Maybe it is just PPO-related?

gjoliver commented 2 years ago

input_dict["obs_flat"] is already padded, so shouldn't flat.shape[0] // seq_lens.shape[0] basically be seq_lens.max(), assuming the padding is correct? Let me give your script a try.
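
For example (a tiny standalone illustration of that arithmetic, not RLlib code):

# If obs_flat really were padded out to num_seqs * max_len rows,
# the integer division would recover max_len exactly.
seq_lens = [3, 2]
num_seqs, max_len = len(seq_lens), max(seq_lens)
flat_rows = num_seqs * max_len            # 6 rows after correct padding
assert flat_rows // num_seqs == max_len   # 6 // 2 == 3 == max(seq_lens)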

sven1977 commented 2 years ago

Hey @smorad, when I run your example as-is, I get a train_batch with batch dim=0 inside the PPO loss function. Something is definitely not right here, but I'm not sure it's a generic RLlib problem.

smorad commented 2 years ago

@gjoliver You're probably right in general. But in my example, if you set USE_CORRECT_SHAPE=True you'll see that the two values are not the same.

@sven1977 If you set USE_CORRECT_SHAPE to True, you should see a crash and hopefully the issue will be more clear. I'm using ray-1.7.0 which might differ from master.

mvindiola1 commented 2 years ago

Hi @sven1977, @gjoliver and @smorad,

I was able to reproduce both Sven's and Steven's errors in master.

The error, or at least part of it, seems to arise as follows:

simple_list_collector creates SampleBatches with max_seq_len set to the configured value: https://github.com/ray-project/ray/blob/9dba5e0eadd3a065023d6cc7cafff631355c980a/rllib/evaluation/collectors/simple_list_collector.py#L329-L337

During training, the SampleBatch uses this value to construct indices into the batch: https://github.com/ray-project/ray/blob/9dba5e0eadd3a065023d6cc7cafff631355c980a/rllib/policy/sample_batch.py#L871-L874

But if rnn_sequencing:pad_batch_to_sequences_of_same_size sets dynamic_max to True, which it does in this repro script, then the indices will be off whenever none of the rollout episodes lasts max_seq_len (200) steps. Hint: at the beginning of training, none of them do.

https://github.com/ray-project/ray/blob/9dba5e0eadd3a065023d6cc7cafff631355c980a/rllib/policy/rnn_sequencing.py#L265-L268

This leads to one of two error conditions:

  1. If the adjusted start and stop indices are within the bounds of the SampleBatch, then you get a batch, but it will likely not be aligned to the actual start of the batch because max_seq_len is too large (@smorad's error).

  2. If the adjusted start or stop of the minibatch exceeds the bounds of the SampleBatch, then you get an empty SampleBatch (@sven1977's error); see the schematic sketch below.
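
To make that concrete, here is a schematic illustration of both conditions with made-up numbers (this is not the actual SampleBatch slicing code):

# Dynamic padding packs each sequence at the *observed* max length, while the
# slicer steps through the batch using the *configured* max_seq_len.
configured_max_seq_len = 200    # model config value used to build slice indices
dynamic_max_len = 90            # longest sequence actually collected this round
num_seqs = 4
padded_rows = num_seqs * dynamic_max_len     # only 360 rows really exist

start_of_seq_1 = 1 * configured_max_seq_len  # slicer expects row 200
print(start_of_seq_1 < padded_rows)          # True: slice exists but is misaligned (error 1)

start_of_seq_2 = 2 * configured_max_seq_len  # slicer expects row 400
print(start_of_seq_2 < padded_rows)          # False: slice falls off the batch -> empty (error 2)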

mvindiola1 commented 2 years ago

P.S. @smorad If you are using 1 or fewer GPUs, then adding "simple_optimizer": True to the config seems to avoid this issue.

gjoliver commented 2 years ago

Thanks @mvindiola1, I can reproduce these pretty reliably too. I guess the start or stop indices shouldn't be conditioned on max_seq_len, right? I need to read SampleBatch a bit more, but we should clean this up.

mvindiola1 commented 2 years ago

@gjoliver

I changed this:

https://github.com/ray-project/ray/blob/9dba5e0eadd3a065023d6cc7cafff631355c980a/rllib/policy/rnn_sequencing.py#L116-L135

To the snippet below and I think it fixes the misalignment issue. I did not see it anymore in my quick test.

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=batch.get(SampleBatch.SEQ_LENS),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch[SampleBatch.SEQ_LENS] = np.array(seq_lens)

    if dynamic_max:
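        # New part of this snippet: keep the batch's max_seq_len in sync with
        # the dynamically padded length instead of the configured maximum.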
        batch.max_seq_len = max(seq_lens)

As for smorad's original issue, where the max of seq_lens does not match the padded seq len: you will still see that. The reason is as follows. The padding is done according to the longest sequence in the entire SampleBatch collected during rollouts, but the model's forward function is given sub-sequences based on sgd_minibatch_size. So if the max sequence length is, let's say, 90, it is possible that a randomly sampled mini-batch will have a shorter max sequence. Take seq_lens = [20, 32, 15, 25, 90, 45, 15]: if we sample two contiguous sequences from that list, there are 2 ways of drawing a minibatch that contains the 90-step sequence and 4 ways of drawing one that doesn't, which is what the repro script is detecting.
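
A quick standalone sketch of that counting argument (not RLlib code):

# Count contiguous pairs of sequences that do / do not contain the longest one.
seq_lens = [20, 32, 15, 25, 90, 45, 15]
pairs = [seq_lens[i:i + 2] for i in range(len(seq_lens) - 1)]
with_max = [p for p in pairs if max(seq_lens) in p]
without_max = [p for p in pairs if max(seq_lens) not in p]
print(len(with_max), len(without_max))  # 2 and 4: most minibatches miss the 90-step sequence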

smorad commented 2 years ago

@mvindiola1 with respect to simple_optimizer: If I use the following config, I get intermittent crashes at around 50 iters:

CFG = {
    "env_config": {},
    "framework": "torch",
    "model": {
        "custom_model": TestRNN,
        "max_seq_len": MAX_SEQ_LEN,
    },
    "num_workers": 0,
    "num_gpus": 0,
    "env": StatelessCartPole,
    "horizon": MAX_SEQ_LEN,
    "simple_optimizer": True,
}

Is this the second error condition?

2021-11-06 13:33:58,307 ERROR trial_runner.py:924 -- Trial PPO_StatelessCartPole_71918_00000: Error processing event.
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trial_runner.py", line 890, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/ray_trial_executor.py", line 788, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/usr/local/lib/python3.8/dist-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/ray/worker.py", line 1625, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::PPO.train_buffered() (pid=45224, ip=172.17.0.2, repr=PPO)
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable.py", line 224, in train_buffered
    result = self.train()
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/agents/trainer.py", line 682, in train
    raise e
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/agents/trainer.py", line 668, in train
    result = Trainable.train(self)
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable.py", line 283, in train
    result = self.step()
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/agents/trainer_template.py", line 206, in step
    step_results = next(self.train_exec_impl)
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 783, in apply_foreach
    for item in it:
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 783, in apply_foreach
    for item in it:
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 843, in apply_filter
    for item in it:
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 843, in apply_filter
    for item in it:
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 783, in apply_foreach
    for item in it:
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 783, in apply_foreach
    for item in it:
  File "/usr/local/lib/python3.8/dist-packages/ray/util/iter.py", line 791, in apply_foreach
    result = fn(item)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/execution/train_ops.py", line 64, in __call__
    learner_info = do_minibatch_sgd(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/utils/sgd.py", line 104, in do_minibatch_sgd
    for minibatch in minibatches(batch, sgd_minibatch_size):
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/utils/sgd.py", line 53, in minibatches
    all_slices = samples._get_slice_indices(sgd_minibatch_size)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/utils/annotations.py", line 101, in _ctor
    return obj(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/policy/sample_batch.py", line 904, in _get_slice_indices
    assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), \
AssertionError: ERROR: `slice_size` must be larger than the max. seq-len in the batch!

mvindiola1 commented 2 years ago

@smorad

That is unfortunate. I only ran for 10 iterations in my test.

Did you try adding the snippet mentioned above with the standard optimizer?

if dynamic_max:
    batch.max_seq_len = max(seq_lens)

smorad commented 2 years ago

@mvindiola1 adding that line doesn't seem to stop it from crashing. This time it crashed at iter 88.

gjoliver commented 2 years ago

OK, it seems there are 2 bugs here.

  1. chop_into_sequences doesn't work when dynamic_max is True. I think I need to repro with a unit test, and just fix the behavior.
  2. the assert in _get_slice_indices(), which you guys hit. I feel like it should be <= instead of <? The reason Steven only hits that crash in later iterations is that the policy only gets good enough to run to the full max_seq_len toward the later part of training. One quick test to verify this theory is to set MAX_SEQ_LEN to 250 in your test, assuming the env only outputs up to 200 time steps.

what do you guys think?

mvindiola1 commented 2 years ago

@gjoliver,

I think you are right that it should be <=, but I don't think the <= is going to fix that error. I just eyeballed the code so I could be wrong, but look at where it is called:

https://github.com/ray-project/ray/blob/ac3371a148385027cf034e152d50f528acb853de/rllib/utils/sgd.py#L53

It is checking that the seq lengths are shorter than the sgd_minibatch_size (128), but in this example the horizon and max_seq_len are 200. When the policy gets good enough, it will generate sequences longer than 127 timesteps and trigger the exception.
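
To spell that out, a standalone illustration of the assert using the values from this thread:

import numpy as np

slice_size = 128                      # == sgd_minibatch_size
seq_lens = np.array([45, 90, 127])    # early training: every sequence < 128
print(np.all(seq_lens < slice_size))  # True -> no crash
seq_lens = np.array([45, 150, 200])   # later: episodes approach max_seq_len = 200
print(np.all(seq_lens < slice_size))  # False -> the AssertionError above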

@smorad you could try running with a sgd_minibatch_size > 200 and see if the error goes away.

gjoliver commented 2 years ago

Hmm, yeah, but this doesn't make any sense. Why should sgd_minibatch_size have anything to do with how long the seqs from the environment are??

mvindiola1 commented 2 years ago

@gjoliver I think max_seq_len is there to truncate BPTT, but if that seq len is larger than the mini-batch size, you could not actually backprop through the sequence length the user requested.

My thought on the solution here is to put a check in validate_config: if the policy is recurrent, make sure the mini-batch size is >= the max_seq_len.
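
Something along these lines, as a rough sketch only (the function name and call site are assumptions, not the actual RLlib validate_config or the eventual PR):

def check_recurrent_minibatch_size(config: dict, policy_is_recurrent: bool) -> None:
    # Hypothetical check: a minibatch must be able to hold at least one full
    # (padded) sequence, otherwise BPTT cannot cover the requested max_seq_len.
    max_seq_len = config.get("model", {}).get("max_seq_len", 20)
    if policy_is_recurrent and config["sgd_minibatch_size"] < max_seq_len:
        raise ValueError(
            f"sgd_minibatch_size ({config['sgd_minibatch_size']}) must be >= "
            f"model.max_seq_len ({max_seq_len}) for recurrent policies.")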

smorad commented 2 years ago

@mvindiola1 that sounds like a good solution. It still allows for long or infinite-length episodes while chunking trajectories into manageable sizes.

mvindiola1 commented 2 years ago

@gjoliver @sven1977,

What is the status on this? I think this should get fixed before 1.9 is cut. I am happy to start working on a PR if it is not already in progress.

mvindiola1 commented 2 years ago

I went ahead and prepared a PR just in case.

smorad commented 2 years ago

For those reading now, you need to do the following in your config for PPO to work correctly with recurrent models:

max_seq_len = some_value

config = {
  "use_simple_optimizer": True, 
  "horizon": max_seq_len - 1,
  "model": {
    "max_seq_len": max_seq_len
  }
}

mvindiola1 commented 2 years ago

@smorad,

Did the way horizon works change? My understanding from looking at it in the past was that horizon would terminate the episode. What if the max_seq_len is shorter than the episode length?

DarrellDai commented 1 year ago

For those reading now, you need to do the following in your config for PPO to work correctly with recurrent models:

max_seq_len = some_value

config = {
  "use_simple_optimizer": True, 
  "horizon": max_seq_len - 1,
  "model": {
    "max_seq_len": max_seq_len
  }
}

@smorad I tried this for PPO, but it still doesn't work for me. I did some testing and found that if you set sgd_minibatch_size >= num_gpus * max_seq_len, it works.

config = {
  "sgd_minibatch_size": sgd_minibatch_size
  "model": {
    "use_lstm": true,
    "max_seq_len": max_seq_len
  }
}
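
For concreteness (the numbers here are only an illustration of that inequality, not values from the thread):

num_gpus = 2
max_seq_len = 64
config = {
  "num_gpus": num_gpus,
  "sgd_minibatch_size": num_gpus * max_seq_len,  # 128 >= 2 * 64
  "model": {
    "use_lstm": True,
    "max_seq_len": max_seq_len,
  },
}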

amadou1998 commented 1 year ago

(quoting @smorad's workaround and @DarrellDai's reply above)

Hello, is this also true for attention models? I seem to get the same error. I would like to use a shorter sequence length; should it be a divisor of horizon + 1?

man2machine commented 1 month ago

I am getting the same error (AssertionError: ERROR: slice_size must be larger than the max. seq-len in the batch!) when I use batch_mode="complete_episodes" with PPO multi-agent training. This happens even when sgd_minibatch_size == num_gpus * max_seq_len.

This is with Ray 2.32.0. There is no horizon parameter anymore to set to make PPO work correctly with recurrent models. Can anyone help with what config should be used in the newer versions of Ray?