danielgafni opened this issue 2 years ago
Thanks for pointing out the bug!
I did a test on train_mnist.py with your PR (set obs_normalization=True in the trainer), however, I still got an error.
Did you not have problems in your tests?
Hmm, no, the code is working correctly in my project. I didn't test it with the examples, though. Will do that and fix the errors.
Ok, so 2 more dimensions are present in the obs_buffer in the MNIST example:
>>> obs_buffer.shape
(1, 64, 1024, 28, 28, 1)
Their meaning is: [?, pop_size, batch_size, height, width, channels]
while
>>> running_mean.shape
(784,)
My code handles the last dimensions; they are expected. But the first two are causing the error.
The MNIST task has an obs_size of (28, 28, 1). However, the actual obs_buffer also has the batch_size dim introduced by the sample_batch function. The ObsNormalizer doesn't know anything about the batch size. It seems like it needs another argument, something like reduce_dims, where we would specify our custom batch_size dimension. Maybe you can suggest another fix? I hope I explained the problem clearly enough.

hey @lerrytang! any update on this issue?
For example, take a look at the brax implementation:
https://github.com/google/brax/blob/main/brax/training/normalization.py
They have a num_leading_batch_dims parameter for the normalizer. Seems like evojax can do the same?
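A minimal sketch of what such a parameter could look like. This is a hypothetical illustration (the function name and signature are not evojax's or brax's actual API): flatten only the observation dims, leave the leading batch dims untouched, and the flat running stats broadcast correctly.

```python
import numpy as np

def normalize(obs, running_mean, running_std, num_leading_batch_dims=1):
    # Keep the leading batch dims as-is; flatten only the obs dims so that
    # the flat (784,) stats broadcast against the trailing axis.
    batch_shape = obs.shape[:num_leading_batch_dims]
    flat = obs.reshape(*batch_shape, -1)
    normed = (flat - running_mean) / running_std
    return normed.reshape(obs.shape)

obs = np.random.rand(1, 4, 8, 28, 28, 1)
out = normalize(obs, np.zeros(784), np.ones(784), num_leading_batch_dims=3)
print(out.shape)  # (1, 4, 8, 28, 28, 1)
```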
Hey! I found a bug in the observation normalization code. It occurs when the observations are not a flat array but a multi-dimensional array: the obs_normalizer params are stored as a flat array, so the code fails in this case. Here is the fix for this bug.