corl-team / xland-minigrid

JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid 🏎️
Apache License 2.0

`ValueError: mutable default…` when importing the library on Python > 3.10 #3

Closed: ntoxeg closed this issue 5 months ago

ntoxeg commented 7 months ago

Steps to reproduce: `import xminigrid`

Error: `ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field position is not allowed: use default_factory`

Env: `xminigrid = "^0.0.2"`, Python version: 3.11.4

Howuhh commented 7 months ago

Hi @ntoxeg, could you also tell me which version of JAX you are using? I would also try Python <= 3.10; I haven't tested anything newer yet. I cannot reproduce the error locally, in Docker, or in Colab. It seems this error is related to jax and jaxlib rather than to xminigrid.
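To gather the versions asked for here in one go, a small stdlib-only sketch could look like this. The helper name `environment_report` is hypothetical (not part of xminigrid), and it assumes the packages are installed under the distribution names `xminigrid`, `jax` and `jaxlib`; anything missing is reported as "not installed" instead of raising.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("xminigrid", "jax", "jaxlib")):
    """Hypothetical helper: collect the Python and package versions
    relevant to a bug report like this one."""
    report = {"python": ".".join(map(str, sys.version_info[:3]))}
    for name in packages:
        try:
            # Reads the installed distribution's metadata; does not import it.
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


print(environment_report())
```

Because `importlib.metadata` only reads package metadata, this works even when `import xminigrid` itself fails with the error above.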

ntoxeg commented 7 months ago

jax==0.4.21. I will test with 3.10 tomorrow. I am not 100% sure, but I think 3.11 is stricter about mutable defaults in dataclasses (which is a genuine error, just not caught in earlier versions).
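The version dependence can be reproduced without JAX. This is a minimal sketch that assumes a hypothetical `FrozenVec` class as a stand-in for a JAX array: it is never mutated, but it defines `__eq__` without `__hash__`, so Python sets `FrozenVec.__hash__` to `None` and it is unhashable. Up to Python 3.10, `dataclasses` only rejected `list`, `dict` and `set` defaults; from 3.11 it rejects any default whose class is unhashable.

```python
import sys
from dataclasses import dataclass


class FrozenVec:
    """Hypothetical stand-in for a JAX array: immutable in practice, but
    unhashable because it defines __eq__ without __hash__."""

    def __init__(self, *values):
        self._values = tuple(values)

    def __eq__(self, other):
        return isinstance(other, FrozenVec) and self._values == other._values


def define_state_class():
    """Try to define a dataclass whose field default is an unhashable object."""

    @dataclass
    class State:
        position: FrozenVec = FrozenVec(0, 0)

    return State


try:
    define_state_class()
    outcome = "accepted"  # expected on Python <= 3.10
except ValueError as err:
    outcome = f"rejected: {err}"  # expected on Python >= 3.11

print(sys.version_info[:2], outcome)
```

On 3.11+ the message has the same shape as the one in the report (`mutable default <class '...FrozenVec'> for field position is not allowed: use default_factory`).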

Howuhh commented 7 months ago

The thing is, all the defaults in the dataclasses are JAX arrays, which are immutable, but I could be wrong.

ntoxeg commented 7 months ago

If they are not hashable, they are treated as mutable, so `field(default_factory=…)` probably has to be used. From the Python documentation:

> The assumption is that if a value is unhashable, it is mutable. This is a partial solution, but it does protect against many common errors.
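The heuristic quoted above comes down to one question: does the default's class set `__hash__` to `None`? The sketch below mirrors that test with a hypothetical helper, `rejected_as_default`, which is not a `dataclasses` API, and reuses a hypothetical `FrozenVec` stand-in for an unhashable array. It also shows why the docs call this a "partial solution": a tuple holding a list is mutable but still passes.

```python
class FrozenVec:
    """Hypothetical immutable vector; defining __eq__ without __hash__
    leaves it unhashable, like the JAX arrays in this issue."""

    def __init__(self, *values):
        self._values = tuple(values)

    def __eq__(self, other):
        return isinstance(other, FrozenVec) and self._values == other._values


def rejected_as_default(value):
    """Hypothetical helper mirroring the Python 3.11+ dataclasses rule:
    a field default is refused when its class sets __hash__ to None."""
    return type(value).__hash__ is None


samples = {
    "list": [],
    "dict": {},
    "tuple": (0, 0),
    "tuple holding a list": ([],),  # mutable contents, yet hashable class
    "FrozenVec": FrozenVec(0, 0),  # never mutated, yet unhashable
}
for label, value in samples.items():
    verdict = "rejected" if rejected_as_default(value) else "accepted"
    print(f"{label:>22}: {verdict}")
```

So immutability alone does not help: what matters to `dataclasses` is hashability, which is why the array defaults are refused.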

Howuhh commented 7 months ago

Indeed! I'll fix and test this as soon as possible. For now, you can use Python 3.10.
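A minimal sketch of the `default_factory` approach suggested in this thread, again using a hypothetical `FrozenVec` in place of a JAX array. This is not the actual xminigrid patch; in the library the factory would presumably build a JAX array instead (for example `lambda: jnp.zeros(2)`, where the shape is an assumption).

```python
from dataclasses import dataclass, field


class FrozenVec:
    """Hypothetical stand-in for an unhashable JAX array."""

    def __init__(self, *values):
        self._values = tuple(values)

    def __eq__(self, other):
        return isinstance(other, FrozenVec) and self._values == other._values


@dataclass
class State:
    # A plain default (position: FrozenVec = FrozenVec(0, 0)) raises
    # ValueError on Python 3.11+. The factory is called once per instance,
    # so there is no shared default object for dataclasses to inspect.
    position: FrozenVec = field(default_factory=lambda: FrozenVec(0, 0))


a, b = State(), State()
print(a.position == b.position)  # True: same value
print(a.position is b.position)  # False: each instance gets its own object
```

This class definition is accepted on every Python version, so it removes the need to pin Python 3.10.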