Hi @george-skal,
The issue is in `centralized_critic_postprocessing`. `SampleBatch.concat_samples` is concatenating in the batch dimension:
```
print(SampleBatch.concat_samples(
    [opponent_n_batch for _, opponent_n_batch
     in other_agent_batches.values()])["obs"].shape)
# (2000, 242)
```
This should do the trick:
```
sample_batch[OPPONENT_OBS] = np.concatenate(
    [opponent_batch[SampleBatch.CUR_OBS]
     for _, opponent_batch in other_agent_batches.values()], -1)
sample_batch[OPPONENT_ACTION] = np.concatenate(
    [opponent_batch[SampleBatch.ACTIONS]
     for _, opponent_batch in other_agent_batches.values()], -1)
```
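To make the difference concrete, here is a tiny NumPy check (purely illustrative; it assumes four opponent batches of 500 timesteps each, which is consistent with the `(2000, 242)` printed above):
```
import numpy as np

# Hypothetical: 4 opponent batches of 500 timesteps each, 242-dim observations.
opponent_obs = [np.zeros((500, 242), dtype=np.float32) for _ in range(4)]

# What SampleBatch.concat_samples effectively does: stack along the batch axis.
print(np.concatenate(opponent_obs, axis=0).shape)   # (2000, 242)

# What the centralized critic needs: concatenate along the feature axis.
print(np.concatenate(opponent_obs, axis=-1).shape)  # (500, 968)
```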
Hi @mvindiola1 ,
Thank you very much for your help. I tried your solution and the code now works with TensorFlow, but not with torch. I am using the code from the centralized_critic.py example, and with torch I get an error similar to the previous one:
```
RuntimeError: Sizes of tensors must match except in dimension 1. Got 32 and 1 in dimension 0 (The offending index is 1)
```
I see that in `central_value_function`, with torch, I have:
```
obs.shape: torch.Size([32, 242])
opponent_obs.shape: torch.Size([1, 968])
opponent_actions.shape: torch.Size([1, 8])
```
while in TensorFlow they are:
```
(?, 242)
(?, 968)
(?, 8)
```
so maybe the problem is that the first dimension in torch is not None. Do you have any idea how to fix it?
Thanks in advance. Best regards, George
@george-skal,
Glad that worked for you. This time the error is coming from the else branch, where the loss is being initialized to infer the trajectory view information. When you created the dummy opponent obs and actions, you forgot to include the batch-size information. This should fix the issue for you:
```
if (pytorch and hasattr(policy, "compute_central_vf")) or \
        (not pytorch and policy.loss_initialized()):
    ...
else:
    # Policy hasn't been initialized yet, use zeros
    batch_size = sample_batch[SampleBatch.CUR_OBS].shape[0]
    sample_batch[OPPONENT_OBS] = np.zeros(
        (batch_size, obs_dim * (n_pursuers - 1)))
    sample_batch[OPPONENT_ACTION] = np.zeros(
        (batch_size, act_dim * (n_pursuers - 1)))
```
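For reference, the two fixes combined look roughly like the sketch below inside the postprocessing function. This is only a sketch modelled on RLlib's centralized_critic.py example; `OPPONENT_OBS`, `OPPONENT_ACTION`, `n_pursuers`, `obs_dim` and `act_dim` are assumed constants, and the central-VF prediction and GAE steps are left as in the original example:
```
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Assumed constants for this environment (5 pursuers, 242-dim obs, 2-dim actions).
OPPONENT_OBS = "opponent_obs"
OPPONENT_ACTION = "opponent_action"
n_pursuers, obs_dim, act_dim = 5, 242, 2


def centralized_critic_postprocessing(policy, sample_batch,
                                      other_agent_batches=None, episode=None):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or \
            (not pytorch and policy.loss_initialized()):
        # Concatenate every opponent's obs/actions along the feature axis,
        # giving shapes (batch, 968) and (batch, 8).
        sample_batch[OPPONENT_OBS] = np.concatenate(
            [b[SampleBatch.CUR_OBS]
             for _, b in other_agent_batches.values()], axis=-1)
        sample_batch[OPPONENT_ACTION] = np.concatenate(
            [b[SampleBatch.ACTIONS]
             for _, b in other_agent_batches.values()], axis=-1)
        # ... overwrite SampleBatch.VF_PREDS with policy.compute_central_vf(...)
        # exactly as in the original centralized_critic.py example.
    else:
        # Loss not initialized yet: dummy zeros with the correct batch size,
        # so that trajectory-view inference sees the right shapes.
        batch_size = sample_batch[SampleBatch.CUR_OBS].shape[0]
        sample_batch[OPPONENT_OBS] = np.zeros(
            (batch_size, obs_dim * (n_pursuers - 1)), dtype=np.float32)
        sample_batch[OPPONENT_ACTION] = np.zeros(
            (batch_size, act_dim * (n_pursuers - 1)), dtype=np.float32)

    # GAE / advantage computation follows here, as in the original example.
    return sample_batch
```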
Sorry for the late reply. These solutions work fine, so the issue can be closed.
Thanks, George
Hello, I am using Ray to customize a centralized critic network. As a beginner, I have run into many doubts and problems. Do I still need to override ModelV2 when customizing? Thank you! Looking forward to your reply!
```
class CentralizedCriticModel(TFModelV2):
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        return self.model.value_function()  # not used
```
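The rest of my model currently follows the centralized_critic.py example roughly as sketched below (the input sizes, layer sizes and the `central_value_function` details are placeholders I filled in myself, so they may not match the official example exactly):
```
import tensorflow as tf
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override

# Placeholder sizes; replace with your own obs/opponent dimensions.
OBS_DIM, OPP_OBS_DIM, OPP_ACT_DIM = 242, 968, 8


class CentralizedCriticModel(TFModelV2):
    """Policy model plus a separate, centralized value network."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(CentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        # Standard policy/value network used for computing the actions.
        self.model = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name)
        self.register_variables(self.model.variables())

        # Central VF: maps (own obs, opponent obs, opponent actions) -> value.
        obs = tf.keras.layers.Input(shape=(OBS_DIM,), name="obs")
        opp_obs = tf.keras.layers.Input(shape=(OPP_OBS_DIM,), name="opp_obs")
        opp_act = tf.keras.layers.Input(shape=(OPP_ACT_DIM,), name="opp_act")
        concat = tf.keras.layers.Concatenate(axis=1)([obs, opp_obs, opp_act])
        hidden = tf.keras.layers.Dense(64, activation=tf.nn.tanh)(concat)
        vf_out = tf.keras.layers.Dense(1, activation=None)(hidden)
        self.central_vf = tf.keras.Model(
            inputs=[obs, opp_obs, opp_act], outputs=vf_out)
        self.register_variables(self.central_vf.variables)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        # Continuous actions are fed in directly; for Discrete actions the
        # official example one-hot encodes opponent_actions first.
        return tf.reshape(
            self.central_vf([obs, opponent_obs, opponent_actions]), [-1])

    @override(ModelV2)
    def value_function(self):
        return self.model.value_function()  # not used
```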
Hi! I am using Ray 1.2.0 and Python 3.6 on Ubuntu 18.04.
I am trying a centralised critic PPO for the waterworld environment from PettingZoo[sisl]: https://www.pettingzoo.ml/sisl/waterworld.
The error I get is:
but if I try (out of curiosity) to double the shape dim in the input layer, setting:
```
opp_obs = tf.keras.layers.Input(shape=(2 * opp_obs_dim, ), name="opp_obs")
```
I get the error:
which seems odd to me, since this is the correct shape that was found (opp_obs_dim = 968).
I think that the problem might be in the initialisation code (I used the solution from here: https://github.com/ray-project/ray/issues/8011).
The environment has:
```
Observation space: Box(low=np.float32(-np.sqrt(2)), high=np.float32(2 * np.sqrt(2)), shape=(self._obs_dim,), dtype=np.float32)
Action space: Box(low=np.float32(-self._max_accel), high=np.float32(self._max_accel), shape=(2,), dtype=np.float32)
```
The observation space shape of one agent is 242. I am new to RLlib, so I am not sure whether this is a bug, but it is something I don't understand, and I would appreciate any help. Also, I am not using one_hot since the environment is continuous, but I am not sure about this and would be happy if someone could clarify it, or point out other things I should change.
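For clarity, the opponent input sizes I use are computed as below (this assumes 5 pursuers, which matches 968 = 4 × 242 from the shapes above):
```
n_pursuers = 5   # assumed; 968 / 242 = 4 opponents implies 5 pursuers
obs_dim = 242    # per-agent observation size
act_dim = 2      # continuous Box action space, shape=(2,)

opp_obs_dim = obs_dim * (n_pursuers - 1)   # 968
opp_act_dim = act_dim * (n_pursuers - 1)   # 8
```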
Also, @korbinian-hoermann, since I partly based my code on the issue you opened here https://github.com/ray-project/ray/issues/12851, please let me know if I am doing something wrong.
I get a similar error with torch.
My code:
Full error message:
Thanks in advance!
Best, George