google / evojax

Apache License 2.0

fix support for multi-dim observations #22

Open danielgafni opened 2 years ago

danielgafni commented 2 years ago

Hey! I found a bug in the observation normalization code. It occurs when the observations are not a flat array but a multi-dimensional one: the obs_normalizer params are stored as a flat array, so the code fails in that case. This PR fixes it.
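A minimal sketch of the idea, assuming the normalizer keeps a flat running mean and variance of size prod(obs_shape). `normalize_obs`, `running_var`, `obs_shape`, `clip_value` and `eps` are hypothetical names for illustration, not the actual EvoJAX ObsNormalizer internals. Reshaping the flat statistics to the observation shape lets them broadcast against any number of leading dimensions.

```python
import math

import jax.numpy as jnp


def normalize_obs(obs, running_mean, running_var, obs_shape, clip_value=5.0, eps=1e-8):
    """Normalize observations whose trailing dims equal `obs_shape`.

    Hypothetical helper: the flat statistics (size prod(obs_shape)) are
    reshaped to `obs_shape` so they broadcast over any leading dims.
    """
    if running_mean.size != math.prod(obs_shape):
        raise ValueError("flat statistics do not match obs_shape")
    mean = running_mean.reshape(obs_shape)
    std = jnp.sqrt(running_var.reshape(obs_shape) + eps)
    return jnp.clip((obs - mean) / std, -clip_value, clip_value)


# Usage: a population of 64 MNIST-shaped observations with flat stats of size 784.
obs = jnp.ones((64, 28, 28, 1))
running_mean = jnp.zeros(784)
running_var = jnp.ones(784)
normalized = normalize_obs(obs, running_mean, running_var, obs_shape=(28, 28, 1))
```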

lerrytang commented 2 years ago

Thanks for pointing out the bug! I tested train_mnist.py with your PR (setting obs_normalization=True in the trainer), but I still got an error. Did you not run into problems in your tests?

danielgafni commented 2 years ago

Hmm, no, the code works correctly in my project. I didn't test it with the examples, though. I'll do that and fix the errors.

danielgafni commented 2 years ago

OK, so in the MNIST example the obs_buffer has 2 extra dimensions:

>>> obs_buffer.shape
(1, 64, 1024, 28, 28, 1)

The dimensions are: [?, pop_size, batch_size, height, width, channels]

While

>>> running_mean.shape
(784,)

My code handles the trailing dimensions, which are expected, but the two extra ones are causing the error:
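To make the shape mismatch concrete, here is a sketch that infers the number of leading dimensions from the observation shape and collapses them into a single sample axis, so the result lines up with flat (784,) statistics. `flatten_obs_buffer` and `obs_shape` are hypothetical names; the test uses a smaller buffer than the real (1, 64, 1024, 28, 28, 1) one, which would flatten to (65536, 784).

```python
import math

import jax
import jax.numpy as jnp


def flatten_obs_buffer(obs_buffer, obs_shape):
    """Collapse every leading axis of `obs_buffer` into one sample axis.

    Hypothetical helper: the trailing axes must equal `obs_shape`; they are
    flattened so each row matches flat statistics of size prod(obs_shape).
    """
    num_leading = obs_buffer.ndim - len(obs_shape)
    if num_leading < 0 or obs_buffer.shape[num_leading:] != tuple(obs_shape):
        raise ValueError(f"buffer shape {obs_buffer.shape} does not end with {obs_shape}")
    return obs_buffer.reshape(-1, math.prod(obs_shape))


# Usage: a small MNIST-like buffer [?, pop_size, batch_size, height, width, channels].
obs_buffer = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 28, 28, 1))
flat = flatten_obs_buffer(obs_buffer, (28, 28, 1))  # shape (1 * 4 * 8, 784)
batch_mean = flat.mean(axis=0)                      # shape (784,), like running_mean
```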

  1. I don't know where the very first dimension (of size 1) comes from. Is it the number of agents? ...
  2. It looks like the third dimension, batch_size, causes another error. The MNIST task has an obs_size of (28, 28, 1), but the actual obs_buffer also has the batch_size dimension introduced by the sample_batch function, and the ObsNormalizer doesn't know anything about the batch size. It seems the normalizer needs another argument, something like reduce_dims, where we would specify our custom batch_size dimension. Maybe you can suggest another fix? I hope I explained the problem clearly enough.

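A sketch of what the proposed reduce_dims argument could look like. `update_running_stats` and `reduce_dims` are hypothetical (reduce_dims is only the suggestion above, not an existing EvoJAX parameter). It merges batch moments into running (count, mean, var) with the Chan et al. parallel variance formula, which is swapped in for illustration and may differ from how ObsNormalizer actually accumulates its statistics.

```python
import jax
import jax.numpy as jnp


def update_running_stats(stats, obs_buffer, reduce_dims):
    """Merge statistics over `reduce_dims` into running (count, mean, var).

    Hypothetical: `reduce_dims` lists the axes to average over (e.g. step,
    population and batch axes); the remaining axes must match `mean`/`var`.
    """
    count, mean, var = stats
    batch_count = 1
    for ax in reduce_dims:
        batch_count *= obs_buffer.shape[ax]
    batch_mean = jnp.mean(obs_buffer, axis=reduce_dims)
    batch_var = jnp.var(obs_buffer, axis=reduce_dims)

    # Parallel combination of two sets of moments (Chan et al.).
    total = count + batch_count
    delta = batch_mean - mean
    new_mean = mean + delta * batch_count / total
    m2 = var * count + batch_var * batch_count + delta**2 * count * batch_count / total
    return total, new_mean, m2 / total


# Usage: reduce over the leading [?, pop_size, batch_size] axes of an MNIST-like buffer.
obs = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 28, 28, 1))
stats = (0, jnp.zeros((28, 28, 1)), jnp.zeros((28, 28, 1)))
stats = update_running_stats(stats, obs, reduce_dims=(0, 1, 2))
```

Merging moments (rather than storing every batch) keeps the update O(obs_size) in memory no matter how many leading dimensions the buffer has.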
danielgafni commented 2 years ago

Hey @lerrytang! Any update on this issue?

danielgafni commented 2 years ago

For example, take a look at the Brax implementation:

https://github.com/google/brax/blob/main/brax/training/normalization.py

It has a num_leading_batch_dims parameter for the normalizer. Maybe EvoJAX could do the same?
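A rough sketch of the num_leading_batch_dims idea applied here. This is not Brax's code: `NormState`, `init_state`, `update` and `normalize` are hypothetical names, and the sum / sum-of-squares accumulation is an assumption chosen for brevity. The caller states how many leading axes are batch-like, and everything after them must match the observation shape.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NormState(NamedTuple):
    n_seen: jnp.ndarray     # number of samples accumulated so far (scalar)
    obs_sum: jnp.ndarray    # per-feature running sum, shape == obs_shape
    obs_sumsq: jnp.ndarray  # per-feature running sum of squares


def init_state(obs_shape):
    return NormState(jnp.zeros(()), jnp.zeros(obs_shape), jnp.zeros(obs_shape))


def update(state, obs, num_leading_batch_dims):
    """Accumulate statistics over the first `num_leading_batch_dims` axes."""
    if obs.shape[num_leading_batch_dims:] != state.obs_sum.shape:
        raise ValueError(f"trailing dims of {obs.shape} do not match {state.obs_sum.shape}")
    axes = tuple(range(num_leading_batch_dims))
    n = 1
    for d in obs.shape[:num_leading_batch_dims]:
        n *= d
    return NormState(
        state.n_seen + n,
        state.obs_sum + obs.sum(axis=axes),
        state.obs_sumsq + (obs**2).sum(axis=axes),
    )


def normalize(state, obs, eps=1e-8, clip_value=5.0):
    """Standardize `obs` with the accumulated mean and variance, then clip."""
    mean = state.obs_sum / state.n_seen
    var = jnp.maximum(state.obs_sumsq / state.n_seen - mean**2, 0.0)
    return jnp.clip((obs - mean) / jnp.sqrt(var + eps), -clip_value, clip_value)


# Usage: 3 leading batch dims ([?, pop_size, batch_size]) for an MNIST-like buffer.
buf = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(1), (1, 4, 8, 28, 28, 1))
state = update(init_state((28, 28, 1)), buf, num_leading_batch_dims=3)
```

Sum and sum-of-squares are simple to accumulate, but they can lose precision over very long runs; a Welford-style merge is the more robust choice there.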