
RLLib: Pytorch + PPO + RNN: KeyError: 'seq_lens' in the batch dictionary #7206

Closed (ludns closed this issue 4 years ago)

ludns commented 4 years ago

What is the problem?

Installed ray with the nightly wheel. I wrote a custom env, model, and action distribution. I attempt to train it with PPO, but there is a KeyError in one of the internal objects used by RLlib (the batch dict with "seq_lens" that is used for masking recurrent models when backpropagating):

2020-02-18 16:11:55,261 ERROR trial_runner.py:513 -- Trial PPO_test_49f2c33a: Error processing event.
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 459, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 377, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/worker.py", line 1522, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=13094, ip=10.0.2.217)
  File "python/ray/_raylet.pyx", line 447, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 425, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 477, in train
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 463, in train
    result = Trainable.train(self)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trainable.py", line 254, in train
    result = self._train()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 122, in _train
    fetches = self.optimizer.step()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/sgd.py", line 111, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 619, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 100, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 112, in ppo_surrogate_loss
    print(train_batch["seq_lens"])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
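
For context, a minimal sketch (not RLlib's actual implementation) of the kind of masking that "seq_lens" enables when backpropagating through a recurrent model:

import torch

# Hypothetical illustration: zero out loss terms at padded timesteps and
# average only over the valid ones. RLlib's loss does something similar
# with the "seq_lens" column of the train batch.
def masked_mean(loss, seq_lens, max_seq_len):
    # mask[b, t] is True for the valid (non-padded) timesteps of sequence b
    t = torch.arange(max_seq_len).unsqueeze(0)   # (1, T)
    mask = (t < seq_lens.unsqueeze(1)).float()   # (B, T)
    loss = loss.reshape(mask.shape)
    return torch.sum(loss * mask) / torch.sum(mask)

# e.g. two sequences of lengths 3 and 1, padded to T=4:
print(masked_mean(torch.ones(2, 4), torch.tensor([3, 1]), 4))  # -> tensor(1.)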

Reproduction (REQUIRED)

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.policy import TupleActions
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.tune.registry import register_env
import gym
from gym.spaces import Discrete, Box, Dict, MultiDiscrete
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor

def _make_f32_array(number):
    return np.array(number, dtype="float32")

class TorchMultiCategorical(ActionDistribution):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        input_lens = model.dist_input_lens
        inputs_splitted = inputs.split(input_lens, dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_splitted
        ]

    @override(ActionDistribution)
    def sample(self):
        arr = [cat.sample() for cat in self.cats]
        ret = torch.stack(arr, dim=1)
        return ret

    @override(ActionDistribution)
    def logp(self, actions):
        # If a tensor is provided, unstack it into a list
        if isinstance(actions, torch.Tensor):
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self):
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(ActionDistribution)
    def entropy(self):
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other):
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1,
        )

    @override(ActionDistribution)
    def kl(self, other):
        return torch.sum(self.multi_kl(other), dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.sum(action_space.nvec)
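
# A quick standalone sanity check of the distribution above (hypothetical
# usage; RLlib normally constructs it internally from the model's outputs).
class _FakeModel:
    # Stand-in for the custom model; only dist_input_lens is needed here.
    dist_input_lens = [3, 2, 51, 10, 2]

_logits = torch.randn(4, sum(_FakeModel.dist_input_lens))  # batch of 4
_dist = TorchMultiCategorical(_logits, _FakeModel())
_actions = _dist.sample()     # shape (4, 5): one sub-action per Discrete
assert _dist.logp(_actions).shape == (4,)  # log-probs summed over sub-spaces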

class ReproEnv(gym.Env):
    def __init__(self, config):
        self.cur_pos = 0
        self.window_size = config["window_size"]
        self.need_reset = False
        self.action_space = MultiDiscrete([3, 2, 51, 10, 2])
        self.observation_space = Dict(
            {
                "lob": Box(low=-np.inf, high=np.inf, shape=(self.window_size, 40)),
                "unallocated_wealth": Box(low=0, high=1, shape=()),
                "taker_fees": Box(low=-1, high=1, shape=()),
                "maker_fees": Box(low=-1, high=1, shape=()),
                "order": Dict(
                    {
                        "side": Discrete(3),
                        "type": Discrete(2),
                        "size": Box(low=0, high=1, shape=()),
                        "price": Box(low=-np.inf, high=np.inf, shape=()),
                        "filled": Box(low=0, high=1, shape=()),
                    }
                ),
                "position": Dict(
                    {
                        "side": Discrete(3),
                        "size": Box(low=0, high=1, shape=()),
                        "entry_price": Box(low=-np.inf, high=np.inf, shape=()),
                        "unrealized_pnl": Box(low=-100, high=np.inf, shape=()),
                    }
                ),
            }
        )

    def reset(self):
        self.cur_pos = 0
        self.need_reset = False
        return self.step([0, 0, 0, 0, 0])[0]  # Noop

    def step(self, action):
        if self.need_reset:
            raise Exception("You need to reset this environment!")
        self.cur_pos += 1
        assert action in self.action_space, action
        if self.cur_pos >= 1000:
            done = True
        else:
            done = False
        info = {}
        if done:
            self.need_reset = True
        observation = {
            "lob": np.zeros((self.window_size, 40)),
            "taker_fees": _make_f32_array(0),
            "maker_fees": _make_f32_array(0),
            "unallocated_wealth": _make_f32_array(0),
            "order": {
                "side": 0,
                "type": 0,
                "size": _make_f32_array(0),
                "price": _make_f32_array(0),
                "filled": _make_f32_array(0),
            },
            "position": {
                "side": 0,
                "size": _make_f32_array(0),
                "entry_price": _make_f32_array(0),
                "unrealized_pnl": _make_f32_array(0),
            },
        }
        assert observation in self.observation_space, observation
        return observation, 0, done, info
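
# A quick smoke test of the env above (hypothetical standalone usage).
_env = ReproEnv({"window_size": 100})
_obs = _env.reset()
_obs, _reward, _done, _info = _env.step(_env.action_space.sample())
assert _obs in _env.observation_space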

class CNN(nn.Module):
    def __init__(self, dropout=0.2):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(16, 16, kernel_size=(4, 1))
        self.conv3 = nn.Conv2d(16, 16, kernel_size=(4, 1))

        self.conv4 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
        self.conv5 = nn.Conv2d(32, 32, kernel_size=(4, 1))
        self.conv6 = nn.Conv2d(32, 32, kernel_size=(4, 1))

        self.conv7 = nn.Conv2d(32, 64, kernel_size=(1, 10))
        self.conv8 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        self.conv9 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        # Pad to preserve the length in the time domain
        self.pad = nn.ZeroPad2d((0, 0, 0, 3))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(self.pad(x)))
        x = F.leaky_relu(self.conv3(self.pad(x)))
        x = F.leaky_relu(self.conv4(x))
        x = F.leaky_relu(self.conv5(self.pad(x)))
        x = F.leaky_relu(self.conv6(self.pad(x)))
        x = F.leaky_relu(self.conv7(x))
        x = F.leaky_relu(self.conv8(self.pad(x)))
        x = F.leaky_relu(self.conv9(self.pad(x)))
        x = self.dropout(x)
        return x
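
# Shape sanity check (hypothetical): the ZeroPad2d scheme preserves the time
# dimension, so for window_size=100 the CNN maps (batch, 1, 100, 40) to
# (batch, 64, 100, 1) -- hence the 64 * window_size LSTM input size below.
assert CNN()(torch.zeros(2, 1, 100, 40)).shape == (2, 64, 100, 1)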

class TestNet(TorchModelV2, nn.Module):
    def init_hidden(self, hidden_size):
        h0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        c0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        return (h0, c0)

    def __init__(self, obs_space, action_space, num_outputs, config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, config, name)
        nn.Module.__init__(self)
        model_config = config["custom_options"]
        print("Model config:")
        print(model_config)
        dropout = model_config["dropout"]
        window_size = model_config["window_size"]
        self.cnn = CNN(dropout=dropout)
        print(f"Dropout: {dropout}")
        print(f"Window size: {window_size}")

        # Value function
        self._value_branch = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )
        # Policy: Signal
        self.long_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
        self.short_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)

        self.long_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_025_fc = nn.Linear(32, 1)

        self.long_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_050_fc = nn.Linear(32, 1)

        self.long_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_075_fc = nn.Linear(32, 1)

        self.short_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_025_fc = nn.Linear(32, 1)

        self.short_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_050_fc = nn.Linear(32, 1)

        self.short_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_075_fc = nn.Linear(32, 1)
        # Policy: Brain
        self.dumb_fc = nn.Linear(6, 68)

        self.dist_input_lens = [3, 2, 51, 10, 2]
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        long_lstm_h, long_lstm_c = self.init_hidden(64)
        short_lstm_h, short_lstm_c = self.init_hidden(64)

        long_025_lstm_h, long_025_lstm_c = self.init_hidden(32)
        long_050_lstm_h, long_050_lstm_c = self.init_hidden(32)
        long_075_lstm_h, long_075_lstm_c = self.init_hidden(32)

        short_025_lstm_h, short_025_lstm_c = self.init_hidden(32)
        short_050_lstm_h, short_050_lstm_c = self.init_hidden(32)
        short_075_lstm_h, short_075_lstm_c = self.init_hidden(32)

        initial_state = [
            long_lstm_h,
            long_lstm_c,
            short_lstm_h,
            short_lstm_c,
            long_025_lstm_h,
            long_025_lstm_c,
            long_050_lstm_h,
            long_050_lstm_c,
            long_075_lstm_h,
            long_075_lstm_c,
            short_025_lstm_h,
            short_025_lstm_c,
            short_050_lstm_h,
            short_050_lstm_c,
            short_075_lstm_h,
            short_075_lstm_c,
        ]
        return initial_state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        # if seq_lens is None:
        #     raise Exception("seq_lens is None")
        lob = input_dict["obs"]["lob"]
        batch_size, window_length, features = lob.size()
        # assert list(hidden_state[0].size()) == [1, 1, 64]
        # Unpack the hidden_state
        long_lstm_h = hidden_state[0]
        long_lstm_c = hidden_state[1]
        short_lstm_h = hidden_state[2]
        short_lstm_c = hidden_state[3]

        long_025_lstm_h = hidden_state[4]
        long_025_lstm_c = hidden_state[5]
        long_050_lstm_h = hidden_state[6]
        long_050_lstm_c = hidden_state[7]
        long_075_lstm_h = hidden_state[8]
        long_075_lstm_c = hidden_state[9]

        short_025_lstm_h = hidden_state[10]
        short_025_lstm_c = hidden_state[11]
        short_050_lstm_h = hidden_state[12]
        short_050_lstm_c = hidden_state[13]
        short_075_lstm_h = hidden_state[14]
        short_075_lstm_c = hidden_state[15]
        # Build the tuples
        long_lstm_hidden = (long_lstm_h.view(1, -1, 64), long_lstm_c.view(1, -1, 64))
        short_lstm_hidden = (short_lstm_h.view(1, -1, 64), short_lstm_c.view(1, -1, 64))
        long_025_lstm_hidden = (
            long_025_lstm_h.view(1, -1, 32),
            long_025_lstm_c.view(1, -1, 32),
        )
        long_050_lstm_hidden = (
            long_050_lstm_h.view(1, -1, 32),
            long_050_lstm_c.view(1, -1, 32),
        )
        long_075_lstm_hidden = (
            long_075_lstm_h.view(1, -1, 32),
            long_075_lstm_c.view(1, -1, 32),
        )
        short_025_lstm_hidden = (
            short_025_lstm_h.view(1, -1, 32),
            short_025_lstm_c.view(1, -1, 32),
        )
        short_050_lstm_hidden = (
            short_050_lstm_h.view(1, -1, 32),
            short_050_lstm_c.view(1, -1, 32),
        )
        short_075_lstm_hidden = (
            short_075_lstm_h.view(1, -1, 32),
            short_075_lstm_c.view(1, -1, 32),
        )

        c_in = lob.view(batch_size, 1, window_length, features)
        c_out = self.cnn(c_in)
        # Embeddings from the CNN, reshaped to be consumed by the LSTMs
        embeddings = c_out.view(batch_size, 1, -1)

        long_out, long_lstm_hidden = self.long_lstm(embeddings, long_lstm_hidden)
        short_out, short_lstm_hidden = self.short_lstm(embeddings, short_lstm_hidden)
        # Now on to the tail LSTMs
        long_025_out, long_025_lstm_hidden = self.long_025_lstm(
            F.leaky_relu(long_out), long_025_lstm_hidden
        )
        long_050_out, long_050_lstm_hidden = self.long_050_lstm(
            F.leaky_relu(long_out), long_050_lstm_hidden
        )
        long_075_out, long_075_lstm_hidden = self.long_075_lstm(
            F.leaky_relu(long_out), long_075_lstm_hidden
        )

        short_025_out, short_025_lstm_hidden = self.short_025_lstm(
            F.leaky_relu(short_out), short_025_lstm_hidden
        )
        short_050_out, short_050_lstm_hidden = self.short_050_lstm(
            F.leaky_relu(short_out), short_050_lstm_hidden
        )
        short_075_out, short_075_lstm_hidden = self.short_075_lstm(
            F.leaky_relu(short_out), short_075_lstm_hidden
        )
        # Reshape the outputs of the tail LSTMs into (batch, hidden_size)
        long_025_out = long_025_out.view(batch_size, -1)
        long_050_out = long_050_out.view(batch_size, -1)
        long_075_out = long_075_out.view(batch_size, -1)

        short_025_out = short_025_out.view(batch_size, -1)
        short_050_out = short_050_out.view(batch_size, -1)
        short_075_out = short_075_out.view(batch_size, -1)
        # Fully connected at the end
        long_025_q = self.long_025_fc(F.leaky_relu(long_025_out))
        long_050_q = self.long_050_fc(F.leaky_relu(long_050_out))
        long_075_q = self.long_075_fc(F.leaky_relu(long_075_out))

        short_025_q = self.short_025_fc(F.leaky_relu(short_025_out))
        short_050_q = self.short_050_fc(F.leaky_relu(short_050_out))
        short_075_q = self.short_075_fc(F.leaky_relu(short_075_out))
        quantiles = [
            long_025_q,
            long_050_q,
            long_075_q,
            short_025_q,
            short_050_q,
            short_075_q,
        ]
        quantiles = torch.cat(quantiles, dim=1).view(batch_size, 6)
        new_hidden_state = [
            long_lstm_hidden[0],
            long_lstm_hidden[1],
            short_lstm_hidden[0],
            short_lstm_hidden[1],
            long_025_lstm_hidden[0],
            long_025_lstm_hidden[1],
            long_050_lstm_hidden[0],
            long_050_lstm_hidden[1],
            long_075_lstm_hidden[0],
            long_075_lstm_hidden[1],
            short_025_lstm_hidden[0],
            short_025_lstm_hidden[1],
            short_050_lstm_hidden[0],
            short_050_lstm_hidden[1],
            short_075_lstm_hidden[0],
            short_075_lstm_hidden[1],
        ]
        assert list(new_hidden_state[0].size()) == [
            1,
            list(new_hidden_state[0].size())[1],
            64,
        ], new_hidden_state[0].size()
        # Value function
        self._cur_value = self._value_branch(quantiles).squeeze(1)
        logits = self.dumb_fc(quantiles)
        return logits, new_hidden_state

ModelCatalog.register_custom_action_dist("torchmulticategorical", TorchMultiCategorical)
ModelCatalog.register_custom_model("test", TestNet)
register_env("test", lambda config: ReproEnv(config))

ray.init()
tune.run(
    ppo.PPOTrainer,
    config={
        "num_workers": 1,
        "env": "test",
        "log_level": "INFO",
        "use_pytorch": True,
        "num_gpus": 1,
        "vf_share_layers": True,
        "env_config": {"window_size": 100,},
        "model": {
            "custom_action_dist": "torchmulticategorical",
            "custom_model": "test",
            "custom_options": {
                "window_size": 100,
                "dropout": 0.2,
                "use_learned_hidden": True,
            },
        },
    },
)
ericl commented 4 years ago

I think this is because we don't support RNNs with PyTorch yet. cc @sven1977

ludns commented 4 years ago

From the docs: "Similarly, you can create and register custom PyTorch models for use with PyTorch-based algorithms (e.g., A2C, PG, QMIX). See these examples of fully connected, convolutional, and recurrent torch models." It looks like you do! Here is a PyTorch RNN in the codebase: https://github.com/ray-project/ray/blob/master/rllib/agents/qmix/model.py

sven1977 commented 4 years ago

I'll assign myself and take a look. Thanks for filing this!

sven1977 commented 4 years ago

Yes, we'll have to fix LSTM support for torch.

janblumenkamp commented 4 years ago

I have the same problem right now and would be happy to look into it. What has to be done to make it work?

sven1977 commented 4 years ago

Thanks for offering your help, everyone! I'm taking a look right now. ... I'll keep you posted over the next few hours/days on what I find.

janblumenkamp commented 4 years ago

Great, thanks! This is my current minimal example, in case it helps to reproduce the error (latest wheel, Python 3.6):

import argparse
import math
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog

torch, nn = try_import_torch()

class CartPoleStatelessEnv(gym.Env):
    # ... refer to https://github.com/ray-project/ray/blob/5cebee68d681bebfd59255b811338d39e2cc2e7d/rllib/examples/cartpole_lstm.py
    pass  # body elided; see the linked example

def _get_size(obs_space):
    return get_preprocessor(obs_space)(obs_space).size

class RNNModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.obs_size = _get_size(obs_space)
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)

        self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
        return h

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
        h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        self._cur_value = self.value_branch(h).squeeze(1)
        return q, [h]

if __name__ == "__main__":
    import ray
    from ray import tune

    ModelCatalog.register_custom_model("rnnmodel", RNNModel)
    tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())

    ray.init()

    tune.run(
        "PPO",
        stop={"episode_reward_mean": 200},
        config={
            "use_pytorch": True,
            "model": {
                "custom_model": "rnnmodel",
                "lstm_use_prev_action_reward": "store_true",
                "lstm_cell_size": 20,
                "custom_options": {}
            },
            "num_sgd_iter": 5,
            "vf_share_layers": True,
            "vf_loss_coeff": 0.0001,
            "env": "cartpole_stateless",
        }
    )

With the error

Traceback (most recent call last):
  File "[...]/ray/tune/trial_runner.py", line 467, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "[...]/ray/tune/ray_trial_executor.py", line 381, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "[...]/ray/worker.py", line 1505, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=21658, ip=128.232.69.20)
  File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 423, in ray._raylet.execute_task.function_executor
  File "[...]/ray/rllib/agents/trainer.py", line 504, in train
    raise e
  File "[...]/ray/rllib/agents/trainer.py", line 490, in train
    result = Trainable.train(self)
  File "[...]/ray/tune/trainable.py", line 261, in train
    result = self._train()
  File "[...]/ray/rllib/agents/trainer_template.py", line 150, in _train
    fetches = self.optimizer.step()
  File "[...]/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "[...]/ray/rllib/utils/sgd.py", line 115, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "[...]/ray/rllib/evaluation/rollout_worker.py", line 632, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "[...]/ray/rllib/policy/torch_policy.py", line 132, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "[...]/ray/rllib/agents/ppo/ppo_torch_policy.py", line 120, in ppo_surrogate_loss
    max_seq_len = torch.max(train_batch["seq_lens"])
  File "[...]/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
sven1977 commented 4 years ago

Yeah, I think I know what it is: in rllib/policy/tf_policy.py we call _get_loss_inputs_dict (the worker calls Policy.learn_on_batch, which then calls Policy._build_learn_on_batch in the tf case), and that computes and adds the seq_lens. In PyTorch, we don't seem to do anything comparable. I'll fix this now.
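
Roughly, the tf side does something like this (a sketch of the idea, not the actual code):

import numpy as np

# Walk the flat sample batch and emit one length per episode chunk, capped
# at max_seq_len. These lengths become the "seq_lens" column that the torch
# path is currently missing.
def compute_seq_lens(eps_ids, max_seq_len=20):
    seq_lens = []
    count = 0
    for i, eid in enumerate(eps_ids):
        count += 1
        end_of_chunk = (i + 1 == len(eps_ids)) or (eps_ids[i + 1] != eid)
        if count == max_seq_len or end_of_chunk:
            seq_lens.append(count)
            count = 0
    return np.array(seq_lens, dtype=np.int32)

# e.g. episode ids [7, 7, 7, 9, 9] -> seq_lens [3, 2]
assert compute_seq_lens(np.array([7, 7, 7, 9, 9])).tolist() == [3, 2]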

sven1977 commented 4 years ago

Almost there. Just some flaws left in the PPO loss concerning valid_mask. Apologies for the docs mentioning that we generically support PyTorch + LSTMs: we don't (yet)! There will be a PR (probably tomorrow) that fixes this for at least the standard PG algos (PPO, PG, A2C/A3C), iff one uses a custom torch model. Making the "use_lstm" auto-wrapping functionality work will be a follow-up PR.

sven1977 commented 4 years ago

There was quite a bit missing for this to work. I'll do a WIP PR later today and post it here. Got the example running, but CartPole doesn't seem to learn with PPO + torch + LSTM. I'll have to take a closer look.

sven1977 commented 4 years ago

Here is the WIP PR that makes the above minimal example run: https://github.com/ray-project/ray/pull/7797. I will add more tests and make sure the CartPole example learns as well.

sven1977 commented 4 years ago

OK, #7797 is learning the RepeatInitialEnv example using PPO + torch + a custom torch model. Check out this example script (included in the PR; you need this PR for the code below to run and learn). It will be merged within the next few days.

import argparse

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune

torch, nn = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--lstm-cell-size", type=int, default=32)

class RNNModel(RecurrentTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)

        self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
        self._cur_value = None

    @override(ModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(RecurrentTorchModel)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the GRU unit.

        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batch as a List of one item (only one hidden state b/c
                GRU).
        """
        x = nn.functional.relu(self.fc1(inputs))
        h = state[0]
        outs = []
        # Step the GRUCell through the longest sequence in the batch.
        for i in range(torch.max(seq_lens)):
            h = self.rnn(x[:, i], h)  # (B, hidden)
            outs.append(h)
        # Stack along the time axis -> (B, T, hidden).
        outs = torch.stack(outs, 1)
        q = self.fc2(outs)
        self._cur_value = torch.reshape(self.value_branch(outs), [-1])
        # Return the last hidden state as the new state batch.
        return q, [h]

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)
    ModelCatalog.register_custom_model("rnn", RNNModel)
    tune.register_env("repeat_initial", lambda c: RepeatInitialEnv())
    tune.register_env("cartpole_stateless", lambda c: CartPoleStatelessEnv())

    config = {
        "num_workers": 0,
        "num_envs_per_worker": 20,
        "gamma": 0.9,
        "entropy_coeff": 0.001,
        "use_pytorch": True,
        "model": {
            "custom_model": "rnn",
            "lstm_use_prev_action_reward": "store_true",
            "lstm_cell_size": args.lstm_cell_size,
            "custom_options": {}
        },
        "lr": 0.0003,
        "num_sgd_iter": 5,
        "vf_loss_coeff": 1e-5,
        "env": args.env,
    }

    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config=config,
    )
sven1977 commented 4 years ago

@justinglibert (see post above). @janblumenkamp and @justinglibert, thanks for your help, guys!

janblumenkamp commented 4 years ago

Awesome, thank you very much Sven! I will try it!

sven1977 commented 4 years ago

Updated the PR once more. Switched out the GRU for an LSTM, and this works much better now on our test envs.
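
For reference, a minimal sketch of what the LSTM variant could look like (an assumption based on the comment above; it reuses the imports from the example script posted earlier and is not the verbatim merged code):

class LSTMModel(RecurrentTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.cell_size = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.cell_size)
        self.lstm = nn.LSTM(self.cell_size, self.cell_size, batch_first=True)
        self.fc2 = nn.Linear(self.cell_size, num_outputs)
        self.value_branch = nn.Linear(self.cell_size, 1)
        self._cur_value = None

    @override(ModelV2)
    def get_initial_state(self):
        # Two state tensors (h and c) instead of the single GRU hidden state.
        return [
            self.fc1.weight.new(1, self.cell_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.cell_size).zero_().squeeze(0),
        ]

    @override(ModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(RecurrentTorchModel)
    def forward_rnn(self, inputs, state, seq_lens):
        x = nn.functional.relu(self.fc1(inputs))             # (B, T, cell)
        h, c = state[0].unsqueeze(0), state[1].unsqueeze(0)  # (1, B, cell)
        out, (h, c) = self.lstm(x, (h, c))                   # out: (B, T, cell)
        q = self.fc2(out)
        self._cur_value = torch.reshape(self.value_branch(out), [-1])
        return q, [h.squeeze(0), c.squeeze(0)]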

sven1977 commented 4 years ago

The PR that fixes this problem (https://github.com/ray-project/ray/pull/7797) has been merged.

Please see this example code here for a learning example with PPO: https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_rnn_model.py

Closing this issue.