ntoxeg closed this issue 5 months ago.
Hi @ntoxeg, can you also tell me your jax version? I would also try Python <= 3.10; I haven't tested higher versions yet. I cannot reproduce the error locally, in Docker, or in Colab. It seems this error is related to jax and jaxlib, not xminigrid.
`jax==0.4.21`. I will test with 3.10 tomorrow. I'm not 100% sure, but I think 3.11 is stricter about mutable defaults in dataclasses (which is an actual error, just not caught in earlier versions).
The thing is, all defaults in the dataclasses are jax arrays, which are immutable, but I could be wrong.
If they are not hashable, they are treated as mutable, so `field(default_factory=…)` probably has to be used.
From the Python documentation:

> The assumption is that if a value is unhashable, it is mutable. This is a partial solution, but it does protect against many common errors.
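A minimal, stdlib-only sketch of both the rule and the `default_factory` fix. Since Python 3.11, `dataclasses` rejects any default whose type has `__hash__` set to `None` (3.10 and earlier only rejected `list`, `dict` and `set`). The check runs when `@dataclass` processes the class body, so a module-level dataclass with such a default fails as soon as the module is imported. `ArrayLike`, `State` and `FixedState` are hypothetical stand-ins, not xminigrid's classes; only the `position` field name comes from the error message. Like jax arrays, `ArrayLike` defines `__eq__` without `__hash__`, which makes Python set `__hash__` to `None`.

```python
import sys
from dataclasses import dataclass, field


class ArrayLike:
    """Hypothetical stand-in for a jax array: defining __eq__ without
    __hash__ makes Python set __hash__ = None, so instances are unhashable."""

    def __init__(self, *values):
        self.values = tuple(values)

    def __eq__(self, other):
        return isinstance(other, ArrayLike) and self.values == other.values


def try_plain_default():
    """Define a dataclass with an unhashable default; return the error or None."""
    try:
        @dataclass
        class State:
            position: ArrayLike = ArrayLike(0, 0)  # rejected on Python 3.11+
    except ValueError as err:
        return str(err)
    return None


@dataclass
class FixedState:
    # default_factory builds a fresh default in every __init__ call,
    # so no shared (and unhashable) instance is stored as the class default.
    position: ArrayLike = field(default_factory=lambda: ArrayLike(0, 0))


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
    print("ArrayLike hashable:", ArrayLike.__hash__ is not None)
    print("plain default:", try_plain_default() or "accepted")
    print("default_factory:", FixedState().position.values)
```

In real code the factory would construct the jax array instead, for example `field(default_factory=lambda: jnp.array([0, 0]))`, assuming `import jax.numpy as jnp`.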
Indeed! I'll fix and test this as soon as possible. For now, you can use Python 3.10.
Steps to reproduce:
import xminigrid
Error:
ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field position is not allowed: use default_factory
Env:
`xminigrid = "^0.0.2"`
Python version: 3.11.4
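To collect the environment details asked for in this thread (Python, jax and jaxlib versions) in one go, a small helper like the following can be run in the affected environment. It is a sketch using only the standard library's `importlib.metadata`; `report_versions` is a hypothetical helper name, and the distribution names `jax`, `jaxlib` and `xminigrid` are assumed to match the installed packages.

```python
import sys
from importlib import metadata


def report_versions(packages=("jax", "jaxlib", "xminigrid")):
    """Return the interpreter version plus the installed version of each
    distribution in `packages` ("not installed" if it cannot be found)."""
    info = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version}")
```

Reporting "not installed" instead of raising keeps the helper usable in a partially set up environment, which is often exactly the one being debugged.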