GoodarzMehr opened 10 months ago
Update: I think the issue can be fixed (for PPO at least) by changing line 175 of `rllib/models/torch/complex_input_net.py` to this:

```python
post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
if post_fcnet_hiddens:
    self.num_outputs = post_fcnet_hiddens[-1]
else:
    self.num_outputs = concat_size
```

since `concat_size` is the size of the model output before the final FC hidden layers. For RNNSAC, what seems to have worked, in addition to the change above, was adding this at line 96 of `rllib/algorithms/sac/rnnsac_torch_model.py`:

```python
if actions is None:
    actions = model_out["prev_actions"]
```
and changing line 370 of `rllib/algorithms/sac/rnnsac_torch_policy.py` to the following (presumably to unpack the `(q-values, state)` tuple returned by the recurrent model's `get_q_values`; the rest of the call is unchanged):

```python
q_tp1, _ = target_model.get_q_values(
```
That said, even with small replay buffer sizes, RNNSAC consumes so much RAM that workers quit due to memory pressure. I would appreciate it if someone could verify these changes.
@sven1977 We should discuss if and how we could support this in the new stack.
What happened + What you expected to happen
I have been using RLlib with a multi-agent CARLA environment (adapted from this integration) where I have a tuple observation space:
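For illustration, a space of this general shape, i.e. an image plus a flat state vector (the exact shapes and dtypes here are assumptions):

```python
import gymnasium as gym
import numpy as np

# Hypothetical Tuple observation space -- the actual CARLA shapes will differ.
observation_space = gym.spaces.Tuple((
    gym.spaces.Box(0, 255, (84, 84, 3), np.uint8),  # camera image
    gym.spaces.Box(-1.0, 1.0, (10,), np.float32),   # vehicle state vector
))
```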
Training with PPO or SAC works without any errors when using the following model configuration:
However, when I add an LSTM to the model, i.e. change it to this (and use either PPO or RNNSAC):
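(A minimal sketch of what such a change typically looks like, using standard RLlib model-config keys; all values below are assumptions, not the original configuration:)

```python
# Hypothetical model config -- filter and layer sizes are illustrative only.
model_config = {
    "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [128, [3, 3], 1]],
    "fcnet_hiddens": [256, 256],
    "post_fcnet_hiddens": [256],
    # The failing case adds an LSTM wrapper on top:
    "use_lstm": True,
    "lstm_cell_size": 256,
    "max_seq_len": 20,
}
```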
I get this error:
1408 appears to be the concatenated output of the CNN (128 × 3 × 3 = 1152) and the FC branch (256), since 1152 + 256 = 1408. My understanding, however, is that this concatenated tensor should pass through the post-FC hidden layers before being fed to the LSTM. I'm not sure exactly where the dimensions go wrong, and I would appreciate your help in resolving this.
Versions / Dependencies
- Ray 2.9.0
- Torch 1.10.1+cu113
- Python 3.8.10
Reproduction script
I'm using the following training script:
with the following configuration file:
The CarlaEnv environment I'm using is not publicly available, but I think the issue can be reproduced with a dummy environment that has the same observation space (see the sketch below).
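For example, a hypothetical minimal reproduction along those lines (the environment, shapes, and config values are all assumptions, not the original script):

```python
import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig


class DummyTupleEnv(gym.Env):
    """Stand-in for CarlaEnv: Tuple observations of an image plus a vector."""

    def __init__(self, config=None):
        self.observation_space = gym.spaces.Tuple((
            gym.spaces.Box(0, 255, (84, 84, 3), np.uint8),
            gym.spaces.Box(-1.0, 1.0, (10,), np.float32),
        ))
        self.action_space = gym.spaces.Discrete(4)
        self._steps = 0

    def reset(self, *, seed=None, options=None):
        self._steps = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        self._steps += 1
        terminated = self._steps >= 100
        return self.observation_space.sample(), 0.0, terminated, False, {}


config = (
    PPOConfig()
    .environment(DummyTupleEnv)
    .training(
        model={
            "post_fcnet_hiddens": [256],
            "use_lstm": True,  # adding the LSTM wrapper triggers the error
            "lstm_cell_size": 256,
        }
    )
)
algo = config.build()
algo.train()
```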
Issue Severity
High: It blocks me from completing my task.