corl-team / xland-minigrid

JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid 🏎️
Apache License 2.0

Fix issue #3. #5

Closed: floringogianu closed this 8 months ago

floringogianu commented 8 months ago

This addresses #3. Checked on Python 3.12 and JAX 0.4.23.
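The thread does not say what #3 was. Assuming it is the common Python 3.11+ failure when a JAX array is used as a dataclass default, the cause is a change in the standard library: since 3.11, dataclasses rejects any unhashable default value as a proxy for mutability, and JAX arrays are unhashable. The usual remedy is field(default_factory=...). The minimal sketch below shows that pattern with the standard library only. FakeArray, bare_default_rejected and State are hypothetical names, and FakeArray stands in for a JAX array so the snippet runs without JAX installed.

```python
import dataclasses
import sys


class FakeArray:
    """Hypothetical stand-in for a JAX array: unhashable, like jax.Array."""

    __hash__ = None  # unhashable, so Python 3.11+ treats it as a mutable default

    def __init__(self, values):
        self.values = list(values)


def bare_default_rejected() -> bool:
    """Return True if this interpreter rejects an unhashable object as a plain default."""
    try:
        @dataclasses.dataclass
        class Broken:
            grid: FakeArray = FakeArray([0, 0])  # plain default: rejected on 3.11+
    except ValueError:
        return True
    return False


@dataclasses.dataclass
class State:
    # default_factory builds a fresh object for every instance, which every
    # Python version accepts regardless of the default's hashability.
    grid: FakeArray = dataclasses.field(default_factory=lambda: FakeArray([0, 0]))


a, b = State(), State()
print(sys.version_info[:2], "bare default rejected:", bare_default_rejected())
print(a.grid is b.grid)  # False: each instance gets its own default object
```

Because the check runs when the class is defined, the error surfaces as soon as the module is imported, so simply importing the package on a recent interpreter is enough to reproduce or rule it out.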

Howuhh commented 8 months ago

Hi, @floringogianu! Thank you for this! Do you think it would be better to use field from flax.struct? It seems to be just a wrapper, but it is potentially more compatible with PyTreeNode.
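As the comment says, flax.struct.field is a thin wrapper over dataclasses.field. In Flax it records a pytree_node flag in the field's metadata, and PyTreeNode reads that flag to separate pytree children from static fields, treating fields without the flag as children. The stdlib-only sketch below mimics that idea so it runs without Flax; the field and split_fields helpers and the State class are hypothetical re-creations for illustration, not Flax's actual source.

```python
import dataclasses
from typing import Any


def field(pytree_node: bool = True, **kwargs: Any) -> Any:
    """Hypothetical re-creation of the idea behind flax.struct.field:
    forward everything to dataclasses.field and record in the field's
    metadata whether it should be a pytree child."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pytree_node"] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)


def split_fields(cls: type) -> tuple[list[str], list[str]]:
    """Sort a dataclass's field names into pytree children and static fields,
    treating a field as a child when the metadata key is absent."""
    children, static = [], []
    for f in dataclasses.fields(cls):
        target = children if f.metadata.get("pytree_node", True) else static
        target.append(f.name)
    return children, static


@dataclasses.dataclass(frozen=True)
class State:
    grid: list = field(default_factory=lambda: [0, 0])  # pytree child (traced under JAX)
    height: int = field(pytree_node=False, default=2)   # static field
    step: int = dataclasses.field(default=0)            # plain field: no metadata key


print(split_fields(State))  # (['grid', 'step'], ['height'])
```

The sketch suggests why plain dataclasses.field(default_factory=...) would also work for data fields, while the wrapper additionally lets a field be marked static. In Flax itself the declaration would look along the lines of grid: jax.Array = struct.field(default_factory=...) inside a struct.PyTreeNode subclass.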

floringogianu commented 8 months ago

Hi, you are right, it should use flax.struct.field. I made the changes; let me know if I should squash the commits.