luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0

NormalizeVecObservation Wrapper Shape Mismatch for Mean and Var #21

Open bheijden opened 7 months ago

bheijden commented 7 months ago

Hi,

The mean and var in the NormalizeVecObservation wrapper are shaped as (NUM_ENVS, ) + obs.shape. I think they should be shaped like a single observation, that is, obs[0].shape, since they're supposed to hold a running average computed across NUM_ENVS. That would also match the shape used in the reward normalization wrapper.

While this doesn't seem to affect performance, since the mean and variance are correctly computed across the batch, it unnecessarily increases memory use and has caused some unexpected issues for me, especially when saving the normalization state along with train_state.
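A minimal sketch of the shape being proposed, assuming the batch axis is axis 0 and that the wrapper uses the usual parallel running mean/variance update. `NormState`, `init_state`, and `update` are hypothetical names for illustration, not the repo's actual code:

```python
from typing import NamedTuple

import jax.numpy as jnp


class NormState(NamedTuple):
    # Hypothetical container; field names are illustrative, not the repo's.
    mean: jnp.ndarray   # shape of a single observation, obs.shape[1:]
    var: jnp.ndarray    # shape of a single observation, obs.shape[1:]
    count: jnp.ndarray  # scalar sample count


def init_state(batched_obs):
    # Drop the leading NUM_ENVS axis so the statistics are per-feature.
    single_shape = batched_obs.shape[1:]
    return NormState(
        mean=jnp.zeros(single_shape),
        var=jnp.ones(single_shape),
        count=jnp.asarray(1e-4),
    )


def update(state, batched_obs):
    # Batch statistics are reduced over the NUM_ENVS axis.
    batch_mean = jnp.mean(batched_obs, axis=0)
    batch_var = jnp.var(batched_obs, axis=0)
    batch_count = batched_obs.shape[0]

    # Parallel (Chan et al. style) combination of old and batch statistics.
    delta = batch_mean - state.mean
    tot_count = state.count + batch_count
    new_mean = state.mean + delta * batch_count / tot_count
    m2 = (
        state.var * state.count
        + batch_var * batch_count
        + jnp.square(delta) * state.count * batch_count / tot_count
    )
    new_var = m2 / tot_count

    new_state = NormState(new_mean, new_var, tot_count)
    # Broadcasting applies the per-feature statistics to every env.
    normalized = (batched_obs - new_mean) / jnp.sqrt(new_var + 1e-8)
    return new_state, normalized


obs = jnp.ones((4, 3))  # assumed NUM_ENVS=4, observation dim 3
state = init_state(obs)
state, norm_obs = update(state, obs)
```

With this layout the saved normalization state holds one mean and one variance per observation feature, regardless of NUM_ENVS, and broadcasting still normalizes the whole batch.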

luchris429 commented 7 months ago

Ahh, good point! Do you think you could submit a PR? It should be a quick fix. I'll do it if you don't have time.