google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

Use num_batch_dims=0 to deal with observations containing scalars #231

Open wookayin opened 2 years ago

wookayin commented 2 years ago

When a (nested) observation contains a scalar value (i.e., a leaf with a single element per step), batch_concat() crashes with a shape mismatch error (e.g., cannot concatenate tensors of shape [100, 1] and []).

This happens because batch_concat assumes by default that its inputs are batched, but when the statistics are computed, the batch dimension (axis=0) has already been eliminated.
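A minimal sketch of the failure mode, assuming a simplified NumPy stand-in for batch_concat. The helpers batch_concat_sketch and _leaves are hypothetical and are not acme's actual implementation; the "position"/"speed" observation and the batch size of 100 are illustrative only. The sketch keeps the leading num_batch_dims axes of each leaf, flattens the rest, and concatenates along the last axis. Once the batch axis has been reduced away (as when computing a mean), a scalar leaf no longer has the same number of dimensions as the others, and only num_batch_dims=0 lines them up.

```python
import numpy as np


def _leaves(tree):
    """Yield the array leaves of a nested dict/list/tuple in a stable order."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from _leaves(tree[key])
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from _leaves(item)
    else:
        yield np.asarray(tree)


def batch_concat_sketch(values, num_batch_dims=1):
    """Hypothetical, simplified stand-in for batch_concat (assumption).

    Keeps the leading `num_batch_dims` axes of every leaf, flattens the
    remaining axes, and concatenates all leaves along the last axis.
    """
    flat = [x.reshape(x.shape[:num_batch_dims] + (-1,)) for x in _leaves(values)]
    return np.concatenate(flat, axis=-1)


# Hypothetical batch of 100 observations: a 3-vector and a per-step scalar.
obs = {"position": np.ones((100, 3)), "speed": np.zeros(100)}

# Statistics reduce over the batch axis, so the leaves become [3] and [].
mean = {k: v.mean(axis=0) for k, v in obs.items()}

# Batched input: both leaves keep the batch axis -> shape (100, 4).
print(batch_concat_sketch(obs).shape)

# Statistics with the default num_batch_dims=1: the [3] leaf becomes [3, 1]
# while the [] leaf becomes [1], so np.concatenate raises a ValueError.
try:
    batch_concat_sketch(mean)
except ValueError as err:
    print("crash:", err)

# Statistics with num_batch_dims=0: leaves become [3] and [1] -> shape (4,).
print(batch_concat_sketch(mean, num_batch_dims=0).shape)
```

The exact shapes in the error differ from the message quoted above because this is a toy reimplementation, but the mismatch has the same cause: treating an already-reduced axis as a batch axis.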

Note: This applies to the SAC agent. The PPO agent already has a correct implementation (comparison here).

wookayin commented 2 years ago

One additional favor: please rebase (squash) when merging rather than creating a merge commit.