When a (nested) observation contains a scalar value on some axes
with dimension 1, batch_concat() crashes with a shape mismatch
error (e.g., cannot concatenate tensors of shape [100, 1] and []).
This happens because batch_concat() by default assumes that its input
tensors are batched, but by the time the statistics are computed
the batch dimension (axis=0) has already been eliminated.
Note: This applies to the SAC agent. The PPO agent already has a correct implementation (comparison here).
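
The failure mode can be reproduced outside the library with a minimal NumPy sketch. The helper `batch_concat_sketch` and its `num_batch_dims` parameter are hypothetical stand-ins written for illustration, not the library's actual implementation, and the exact shapes in the real error message may differ. The sketch only shows why flattening unbatched leaves as if they were batched produces mismatched ranks, and that telling the helper there are zero batch dimensions avoids it.

```python
import numpy as np


def batch_concat_sketch(leaves, num_batch_dims=1):
    """Hypothetical stand-in: flatten each leaf past its leading batch
    dims, then concatenate all leaves along the last axis."""
    flat = []
    for x in leaves:
        x = np.asarray(x)
        # Keep the first `num_batch_dims` axes and collapse the rest.
        batch_shape = x.shape[:num_batch_dims]
        flat.append(x.reshape(batch_shape + (-1,)))
    return np.concatenate(flat, axis=-1)


# Batched observation: a vector leaf and a scalar leaf, batch size 3.
batched = [np.ones((3, 4)), np.ones((3,))]
print(batch_concat_sketch(batched).shape)  # (3, 5)

# Statistics with the batch axis already reduced away:
# the vector leaf has shape (4,), the scalar leaf has shape ().
stats = [np.ones((4,)), np.asarray(2.0)]

try:
    # The default treats axis 0 of the vector leaf as a batch axis,
    # so the leaves end up with different ranks and cannot be joined.
    batch_concat_sketch(stats)
except ValueError as err:
    print("shape mismatch:", err)

# Declaring that no batch dimension is present flattens both leaves to 1-D.
print(batch_concat_sketch(stats, num_batch_dims=0).shape)  # (5,)
```

Whether the actual fix passes an explicit batch-dimension count or reshapes the statistics some other way is not stated here; the sketch only illustrates the mismatch described above.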