Closed xiangyuy closed 2 years ago
Hello, I was just wondering whether you have tried flattening the output of the policy module and then reshaping it later, when it is passed to the step function of the environment. It would be a simple check to pinpoint the error.
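A minimal sketch of the flatten-then-reshape check suggested above, in plain Python so it runs without any RL library. It assumes a hypothetical MultiDiscrete([3, 3, 3, 3]) space and made-up logits; the helper names flatten_logits and unflatten_logits are illustrative only and are not part of coax or gym.

```python
from itertools import accumulate


def flatten_logits(logits_per_component):
    """Concatenate the per-component logits into one flat list."""
    return [x for row in logits_per_component for x in row]


def unflatten_logits(flat, nvec):
    """Split a flat list back into one list of logits per component.

    nvec holds the number of categories of each component, in order.
    """
    if len(flat) != sum(nvec):
        raise ValueError(f"expected {sum(nvec)} logits, got {len(flat)}")
    offsets = [0, *accumulate(nvec)]
    return [flat[start:stop] for start, stop in zip(offsets, offsets[1:])]


# Hypothetical space: 4 components with 3 categories each (assumption).
nvec = [3, 3, 3, 3]
logits = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
]

flat = flatten_logits(logits)            # what the policy would output
restored = unflatten_logits(flat, nvec)  # what is rebuilt before env.step
```

If the error disappears with a flat output, that points at how the structured output is handled rather than at the environment itself.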
@UweGensheimer Thank you for your suggestion. Flattening the output is kind of complicated, so instead I tried setting the size of my MultiDiscrete to one (i.e., the policy function would return ({'logits': jnp.array})), and it seems to run without error. So I think using a MultiDiscrete action space is the problem here. I'm not sure what to do next...
Thanks for reporting this!
I can confirm that this is a proper bug. It has to do with the way variates are pre/post-processed.
I'll have a closer look at it as soon as I have time, which is either tonight or tomorrow.
Thank you so much for your speedy response! I see that #22 has passed all tests. Will it be merged to main soon, or do more tests and examinations need to be done?
Update: I edited _composite.py according to 82bcd674c7b6ff2d8efd3f3d0e8e0b68184c4a84 and it seems to be running without error. Thank you again for your help!
Hi @xiangyuy, thanks for your patience. PR #22 is merged. I bumped the version number, because I also fixed a few other little things in the same PR.
First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently doing a baseline study with my customized gym environment and VanillaPG, but I encounter the bug shown below and could not figure it out. My understanding is that it is complaining that the shape of log_pi should not be (4,). But I do have a MultiDiscrete action space, and its corresponding log_pi should be something like (4,) or (1, 4). I also attached the output when I call coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.
So my questions are:
I would appreciate any feedback. Thank you!
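To make the two candidate shapes of log_pi concrete, here is a small sketch in plain Python. It is not coax's implementation: it assumes the four components of the MultiDiscrete action are sampled independently, and the logits and action values are invented for illustration. Under that assumption there is one log-probability per component (length 4), and their sum is the scalar log-probability of the joint action.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def multidiscrete_log_prob(logits_per_component, action):
    """Return (per-component log-probs, joint log-prob) for one action.

    Assumes the components are independent categorical variables, so the
    joint log-probability is the sum of the per-component ones.
    """
    per_component = [
        log_softmax(row)[a] for row, a in zip(logits_per_component, action)
    ]
    return per_component, sum(per_component)


# Hypothetical example: 4 components, 3 categories each, uniform logits.
logits = [[0.0, 0.0, 0.0] for _ in range(4)]
action = [2, 0, 1, 1]

per_component, joint = multidiscrete_log_prob(logits, action)
# per_component has one entry per action component (length 4);
# joint is a single float for the whole action.
```

Code that expects one scalar log-probability per transition would need the summed value, while code that keeps the components separate would see a length-4 vector.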
Error message
Example data
Policy function