Closed: ludns closed this issue 4 years ago.
I think this is since we don't support RNNs with pytorch yet, cc @sven1977
From the doc: "Similarly, you can create and register custom PyTorch models for use with PyTorch-based algorithms (e.g., A2C, PG, QMIX). See these examples of fully connected, convolutional, and recurrent torch models." It looks like you do! Here is a pytorch rnn in the codebase: https://github.com/ray-project/ray/blob/master/rllib/agents/qmix/model.py
I'll assign myself and take a look. Thanks for filing this!
Yes, we'll have to fix LSTM support for torch.
I have the same problem right now and would be happy to look into it. What has to be done to make it work?
Thanks for offering your help, everyone! I'm taking a look right now. I'll keep you posted over the next few hours/days on what I find.
Great, thanks! This is my current minimal example in case it helps to reproduce the error (latest wheel, Python 3.6):
import argparse
import math

import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog

torch, nn = try_import_torch()


class CartPoleStatelessEnv(gym.Env):
    # Same env as in
    # https://github.com/ray-project/ray/blob/5cebee68d681bebfd59255b811338d39e2cc2e7d/rllib/examples/cartpole_lstm.py
    pass


def _get_size(obs_space):
    return get_preprocessor(obs_space)(obs_space).size


class RNNModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.obs_size = _get_size(obs_space)
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
        self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # Make hidden states on the same device as the model.
        h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
        return h

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
        h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        self._cur_value = self.value_branch(h).squeeze(1)
        return q, [h]


if __name__ == "__main__":
    import ray
    from ray import tune

    ModelCatalog.register_custom_model("rnnmodel", RNNModel)
    tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())
    ray.init()
    tune.run(
        "PPO",
        stop={"episode_reward_mean": 200},
        config={
            "use_pytorch": True,
            "model": {
                "custom_model": "rnnmodel",
                "lstm_use_prev_action_reward": True,
                "lstm_cell_size": 20,
                "custom_options": {}
            },
            "num_sgd_iter": 5,
            "vf_share_layers": True,
            "vf_loss_coeff": 0.0001,
            "env": "cartpole_stateless",
        })
With the error:

Traceback (most recent call last):
  File "[...]/ray/tune/trial_runner.py", line 467, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "[...]/ray/tune/ray_trial_executor.py", line 381, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "[...]/ray/worker.py", line 1505, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=21658, ip=128.232.69.20)
  File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 423, in ray._raylet.execute_task.function_executor
  File "[...]/ray/rllib/agents/trainer.py", line 504, in train
    raise e
  File "[...]/ray/rllib/agents/trainer.py", line 490, in train
    result = Trainable.train(self)
  File "[...]/ray/tune/trainable.py", line 261, in train
    result = self._train()
  File "[...]/ray/rllib/agents/trainer_template.py", line 150, in _train
    fetches = self.optimizer.step()
  File "[...]/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "[...]/ray/rllib/utils/sgd.py", line 115, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "[...]/ray/rllib/evaluation/rollout_worker.py", line 632, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "[...]/ray/rllib/policy/torch_policy.py", line 132, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "[...]/ray/rllib/agents/ppo/ppo_torch_policy.py", line 120, in ppo_surrogate_loss
    max_seq_len = torch.max(train_batch["seq_lens"])
  File "[...]/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
Yeah, I think I know what it is: In rllib/policy/tf_policy.py, we call _get_loss_inputs_dict (the worker calls Policy.learn_on_batch, which then calls Policy._build_learn_on_batch in the tf case), and that calculates and adds the seq_lens. In PyTorch, we don't seem to do anything comparable. I'll fix this now.
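For intuition, here is a rough sketch of the kind of bookkeeping involved in producing seq_lens (the helper name is made up for illustration, not the actual RLlib code): episodes get chopped into chunks of at most max_seq_len timesteps, and the resulting per-chunk lengths are what the loss later uses for masking padded steps.

```python
import numpy as np

def add_seq_lens(episode_lens, max_seq_len=20):
    """Hypothetical sketch: chop episodes into chunks of at most
    `max_seq_len` steps and return the per-chunk lengths."""
    seq_lens = []
    for ep_len in episode_lens:
        while ep_len > 0:
            seq_lens.append(min(ep_len, max_seq_len))
            ep_len -= max_seq_len
    return np.array(seq_lens, dtype=np.int32)

# Two episodes of 45 and 12 steps, chunked at 20 -> [20, 20, 5, 12]
print(add_seq_lens([45, 12], max_seq_len=20))
```

Each chunk is then zero-padded up to max_seq_len before being fed to the recurrent model, which is why the loss needs the true lengths.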
Almost there. Just some flaws left in the PPO loss concerning valid_mask.
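For intuition, a minimal sketch of what such a mask does, shown with NumPy for brevity (the torch version is analogous; sequence_mask and masked_mean are hypothetical names, not the RLlib API): padded timesteps beyond each sequence's true length must be excluded when averaging the per-step loss terms.

```python
import numpy as np

def sequence_mask(seq_lens, max_len):
    # (B, T) boolean mask: True for valid (non-padded) timesteps.
    return np.arange(max_len)[None, :] < np.asarray(seq_lens)[:, None]

def masked_mean(loss_terms, seq_lens):
    # Average per-timestep loss terms over valid steps only.
    mask = sequence_mask(seq_lens, loss_terms.shape[1]).astype(float)
    return (loss_terms * mask).sum() / mask.sum()

# B=2 sequences of true lengths 3 and 1, padded to T=3:
print(sequence_mask([3, 1], 3))
# [[ True  True  True]
#  [ True False False]]
print(masked_mean(np.ones((2, 3)), [3, 1]))  # 1.0 (4 valid of 6 steps)
```

Without the mask, zero-padded steps would drag the mean loss toward zero and distort the gradient.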
Apologies for the docs mentioning that we do generically support pytorch + LSTMs: We don't (yet)! There will be a PR (probably tomorrow), which will fix that for at least the standard PG-algos: PPO, PG, A2C/A3C, iff one uses a custom torch Model. Making the "use_lstm" auto-wrapping functionality work will be a follow-up PR.
There was quite a bit missing for this to work. I'll do a WIP PR later today and post it here. Got the example running, but CartPole doesn't seem to learn with PPO + torch + LSTM. Will have to take a closer look.
Here is the WIP PR that makes the above minimal example run: https://github.com/ray-project/ray/pull/7797 Will add more tests and make sure the CartPole example learns as well.
Ok, #7797 is learning the RepeatInitialEnv example using PPO + torch + a custom torch model. Check out this example script (included in the PR; you need this PR for the code below to run and learn). Will be merged within the next few days.
import argparse

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune

torch, nn = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--lstm-cell-size", type=int, default=32)


class RNNModel(RecurrentTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
        self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
        self._cur_value = None

    @override(ModelV2)
    def get_initial_state(self):
        # Make hidden states on the same device as the model.
        h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(RecurrentTorchModel)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ...) through the GRU unit.

        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN outputs (B x T x ...) as a sequence.
            The state batch as a list of one item (only one hidden state
            b/c GRU).
        """
        x = nn.functional.relu(self.fc1(inputs))
        h = state[0]
        outs = []
        for i in range(int(torch.max(seq_lens))):
            h = self.rnn(x[:, i], h)
            outs.append(h)
        outs = torch.stack(outs, 1)  # -> B x T x hidden
        q = self.fc2(outs)
        self._cur_value = torch.reshape(self.value_branch(outs), [-1])
        return q, [h]


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    ModelCatalog.register_custom_model("rnn", RNNModel)
    tune.register_env("repeat_initial", lambda c: RepeatInitialEnv())
    tune.register_env("cartpole_stateless", lambda c: CartPoleStatelessEnv())
    config = {
        "num_workers": 0,
        "num_envs_per_worker": 20,
        "gamma": 0.9,
        "entropy_coeff": 0.001,
        "use_pytorch": True,
        "model": {
            "custom_model": "rnn",
            "lstm_use_prev_action_reward": True,
            "lstm_cell_size": args.lstm_cell_size,
            "custom_options": {}
        },
        "lr": 0.0003,
        "num_sgd_iter": 5,
        "vf_loss_coeff": 1e-5,
        "env": args.env,
    }
    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config=config,
    )
@justinglibert (see post above). @janblumenkamp and @justinglibert Thanks for your help, guys!
Awesome, thank you very much Sven! I will try it!
Updated the PR once more. Switched out the GRU for an LSTM, and this works much better now on our test envs.
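As a side note on what that switch entails (a minimal standalone sketch, not the RLlib API; TinyLSTM and its methods are made up for illustration): an LSTM carries two state tensors, h and c, instead of the GRU's single hidden state, so both the initial state and the state returned from the forward pass become pairs.

```python
import torch
import torch.nn as nn

class TinyLSTM(nn.Module):
    """Hypothetical minimal module showing LSTM state handling."""

    def __init__(self, in_dim=4, hidden=8, out_dim=2):
        super().__init__()
        self.hidden = hidden
        self.lstm = nn.LSTM(in_dim, hidden, batch_first=True)
        self.out = nn.Linear(hidden, out_dim)

    def initial_state(self, batch=1):
        # An LSTM needs an (h, c) pair, each shaped (num_layers, B, hidden),
        # where a GRU would need only a single h tensor.
        return (torch.zeros(1, batch, self.hidden),
                torch.zeros(1, batch, self.hidden))

    def forward(self, x, state):
        # x: (B, T, in_dim); both state tensors are threaded through.
        y, (h, c) = self.lstm(x, state)
        return self.out(y), (h, c)

m = TinyLSTM()
x = torch.randn(3, 5, 4)                   # B=3, T=5
y, (h, c) = m(x, m.initial_state(batch=3))
print(y.shape)                             # torch.Size([3, 5, 2])
```

In the RLlib model above, this corresponds to get_initial_state() returning two zero tensors and forward_rnn() returning both h and c in its state list.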
The PR that fixes this problem (https://github.com/ray-project/ray/pull/7797) has been merged.
Please see this example code here for a learning example with PPO: https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_rnn_model.py
Closing this issue.
What is the problem?
Installed ray with the nightly wheel. I wrote a custom env, model, and action distribution. I attempt to train it with PPO, but there is a KeyError in one of the internal objects used by RLlib (the batch dict is missing "seq_lens", which is used for masking recurrent models during backpropagation).
2020-02-18 16:11:55,261 ERROR trial_runner.py:513 -- Trial PPO_test_49f2c33a: Error processing event.
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 459, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 377, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/worker.py", line 1522, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=13094, ip=10.0.2.217)
  File "python/ray/_raylet.pyx", line 447, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 425, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 477, in train
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 463, in train
    result = Trainable.train(self)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trainable.py", line 254, in train
    result = self._train()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 122, in _train
    fetches = self.optimizer.step()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/sgd.py", line 111, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 619, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 100, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 112, in ppo_surrogate_loss
    print(train_batch["seq_lens"])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'

Reproduction (REQUIRED)