ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] [Bug] Concatenating obs_space with action_space as input space in RNNSAC build_q_model method causes shape mismatch when building RNN model #21457

Closed wildsky95 closed 1 year ago

wildsky95 commented 2 years ago

Search before asking

Ray Component

RLlib

What happened + What you expected to happen

Hi, I'm trying to train a multi-agent RNNSAC on my custom environment, but I get a shape mismatch error. I tried to resolve this on my own and found that when the q_model is built, the observation space and action space get concatenated, so the model's input takes a shape of action shape + obs shape, and during training the shape mismatch occurs. I can't quite understand why build_q_model concatenates the action and obs.

My custom env's observation space is (9640,) and action space is (4031,), and both are continuous, so with the concatenation in Q-model building I get a shape error. I'm literally running the RNNSAC test algorithm on the model. It's also worth mentioning that it works perfectly well with multi-agent CartPole, but it doesn't work with my custom env. Of course, I tested my custom multi-agent env with PPO and PG and it works fine! The error I get is:

Traceback (most recent call last):
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 773, in setup
    self._init(self.config, self.env_creator)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 873, in _init
    raise NotImplementedError
NotImplementedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/wildsky/Dropbox/AI-AoI-FeLSA/Simulation/marl_test/testSAC.py", line 116, in <module>
    trainer = sac.RNNSACTrainer(config=config, env="multi_agent_aoi")
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/sac.py", line 187, in __init__
    super().__init__(*args, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 690, in __init__
    super().__init__(config, logger_creator, remote_checkpoint_dir,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/tune/trainable.py", line 122, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 788, in setup
    self.workers = self._make_workers(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 1822, in _make_workers
    return WorkerSet(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 123, in __init__
    self._local_worker = self._make_worker(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 479, in _make_worker
    worker = cls(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 587, in __init__
    self._build_policy_map(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1550, in _build_policy_map
    self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
    self[policy_id] = class_(observation_space, action_space,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 280, in __init__
    self._initialize_loss_from_dummy_batch(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 799, in _initialize_loss_from_dummy_batch
    self.compute_actions_from_input_dict(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 294, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 908, in _compute_action_helper
    self.action_distribution_fn(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_policy.py", line 175, in action_distribution_fn
    _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_model.py", line 100, in get_q_values
    return self._get_q_value(model_out, actions, self.q_net, state_in,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_model.py", line 91, in _get_q_value
    out, state_out = net(model_out, state_in, seq_lens)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/recurrent_net.py", line 187, in forward
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 124, in forward
    self._features = self._hidden_layers(self._last_flat_in)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/misc.py", line 160, in forward
    return self._model(x)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x9640 and 13671x10)
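
The two matrix shapes in that last RuntimeError line up with the concatenation described above; spelled out as plain arithmetic (illustration only, using the [10] hidden size from the Q_model config below):

batch_size, obs_dim, act_dim, hidden = 32, 9640, 4031, 10

mat1 = (batch_size, obs_dim)        # what is actually fed in: observations only
mat2 = (obs_dim + act_dim, hidden)  # first Q-net layer, built for obs + action
assert obs_dim + act_dim == 13671   # the 13671 reported in the error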

I use this code to train:

from ray.tune.registry import register_env
from ray.rllib.env.multi_agent_env import make_multi_agent
from env_rllib import Environment

from ray.rllib.models import ModelCatalog
import ray.rllib.agents.sac as sac
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator

from rnn_model import TorchRNNModel, RNNModel

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

MultiAgentAOI = make_multi_agent(Environment)

ModelCatalog.register_custom_model("lstm_model", TorchRNNModel)
ModelCatalog.register_custom_model("lstm_model_tf", RNNModel)
register_env("multi_agent_aoi" , lambda x : MultiAgentAOI({"num_agents": 5}))

config = sac.RNNSAC_DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # Run locally.

config["model"] = {
    "max_seq_len": 100,
}
config["env"]= "multi_agent_aoi"
config["policy_model"] = {
    # "custom_model": "lstm_model",
    "fcnet_hiddens": [10],
    "use_lstm": True,
    "lstm_cell_size": 64,

    "lstm_use_prev_action": True,
    "lstm_use_prev_reward": True,
                          }
config["Q_model"] = {
    # "custom_model": "lstm_model",
    "fcnet_hiddens": [10],
    "use_lstm": True,

    "lstm_cell_size": 64,

    "lstm_use_prev_action": True,
    "lstm_use_prev_reward": True,

}

config["prioritized_replay"] = True

config["burn_in"] = 20
config["zero_init_states"] = True

config["lr"] = 5e-4

num_iterations = 1

for _ in framework_iterator(config, frameworks="torch"):
    trainer = sac.RNNSACTrainer(config=config, env="multi_agent_aoi")
    for i in range(num_iterations):
        results = trainer.train()
        print(results)

I don't quite understand this part of the build_q_model method:

def build_q_model(self, obs_space, action_space, num_outputs,
                      q_model_config, name):
        """Builds one of the (twin) Q-nets used by this SAC.

        Override this method in a sub-class of SACTFModel to implement your
        own Q-nets. Alternatively, simply set `custom_model` within the
        top level SAC `Q_model` config key to make this default implementation
        of `build_q_model` use your custom Q-nets.

        Returns:
            TorchModelV2: The TorchModelV2 Q-net sub-model.
        """
        self.concat_obs_and_actions = False
        if self.discrete:
            input_space = obs_space
        else:
            orig_space = getattr(obs_space, "original_space", obs_space)
            if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
                input_space = Box(
                    float("-inf"),
                    float("inf"),
                    shape=(orig_space.shape[0] + action_space.shape[0], ))
                self.concat_obs_and_actions = True
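
Plugging my spaces into that snippet (illustration only, using gym.spaces.Box as above) shows where the 13671 in the error comes from:

from gym.spaces import Box

obs_space = Box(float("-inf"), float("inf"), shape=(9640,))
action_space = Box(float("-inf"), float("inf"), shape=(4031,))

# Continuous case: the Q-net's input space is obs and action concatenated.
input_space = Box(
    float("-inf"), float("inf"),
    shape=(obs_space.shape[0] + action_space.shape[0], ))
print(input_space.shape)  # (13671,) -- matches the mismatch in the traceback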

Thanks in advance for your guidance.

Versions / Dependencies

v2.0, v1.9

wildsky95 commented 2 years ago

Hi, is there any development on this issue? @avnishn @sven1977

avnishn commented 2 years ago

Hey, sorry about the late response; I haven't gotten to this in time. I'm going to assign @sven1977 as well, so that we can get your issue resolved ASAP.

bektaskemal commented 2 years ago

Hi, is there any update on this? @avnishn @sven1977

avnishn commented 2 years ago

"I can't quite understand why build_q_model concatenates the action and obs."

In RL literature, the Q function's parameters are states and actions. When we represent that as a neural network in code, we concatenate the observations and actions in order to represent that the Q function is a function of these two things.

In the case that your environment has a discrete action space, the Q function isn't really parameterized by actions at all; it behaves more like a value function that takes only the observation as input.

So my best guess here is that your issue has something to do with the observations and actions not being concatenated when they're passed to the _get_q_value function of the RNNSAC torch model. This is probably because at some point self.concat_obs_and_actions is being set to False, meaning it's trying to pass only the observation, instead of the observation concatenated with the action, to your Q function.
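
As a minimal, standalone sketch (plain PyTorch, not RLlib's actual model code) of why a continuous-action Q-net expects the concatenated input:

import torch
import torch.nn as nn

class ContinuousQNet(nn.Module):
    """Sketch of a Q(s, a) network for continuous actions (not RLlib code)."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 10):
        super().__init__()
        # The first layer is sized for the concatenated (obs, action) vector,
        # which is why both must be passed in together.
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([obs, act], dim=-1))

q_net = ContinuousQNet(obs_dim=9640, act_dim=4031)
obs = torch.zeros(32, 9640)
act = torch.zeros(32, 4031)
q_values = q_net(obs, act)  # fine: the input is (32, 13671)
# q_net.net(obs)            # would fail with the same mat1/mat2 shape error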

jamuus commented 2 years ago

Hey,

I've run into the same issue.

From the debugging I've done, it seems the calls to get_q_values in action_distribution_fn don't have actions passed in, so _get_q_value doesn't concatenate anything with the observation and the dimension mismatch occurs.

    _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
    if model.twin_q_net:
        _, twin_q_state_out = model.get_twin_q_values(
            model_out, states_in["twin_q"], seq_lens
        )

With #23814 and passing in input_dict['actions'] it progresses further, but I'm seeing other, seemingly unrelated issues. I also have no idea if those are the actions expected at this point in the algorithm.
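
Concretely, the change described above would look roughly like this sketch. It is not a verified patch: it assumes get_q_values and get_twin_q_values accept an optional actions argument that they forward to _get_q_value (as the traceback's call chain suggests), and that input_dict["actions"] holds the right actions at this point in the algorithm.

# Sketch only, in the context of action_distribution_fn quoted above: pass the
# batch's actions through so _get_q_value can concatenate them with model_out
# instead of feeding the observation features alone.
actions = input_dict["actions"]
_, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens, actions)
if model.twin_q_net:
    _, twin_q_state_out = model.get_twin_q_values(
        model_out, states_in["twin_q"], seq_lens, actions
    )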

It appears the RNNSAC implementation hasn't been tested with continuous actions. It would be good if someone knowledgeable about how it's supposed to work could take a look; I've seen great performance with the torch implementation of regular SAC so far.