RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field reward_timestep is not allowed: use default_factory #50

Closed: carlosgmartin closed this issue 1 year ago

carlosgmartin commented 1 year ago

I'm getting the following error:

$ python3 -c "import gymnax"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.11/site-packages/gymnax/__init__.py", line 1, in <module>
    from .registration import make, registered_envs
  File "/usr/local/lib/python3.11/site-packages/gymnax/registration.py", line 1, in <module>
    from .environments import (
  File "/usr/local/lib/python3.11/site-packages/gymnax/environments/__init__.py", line 9, in <module>
    from .bsuite import (
  File "/usr/local/lib/python3.11/site-packages/gymnax/environments/bsuite/__init__.py", line 3, in <module>
    from .discounting_chain import DiscountingChain
  File "/usr/local/lib/python3.11/site-packages/gymnax/environments/bsuite/discounting_chain.py", line 17, in <module>
    @struct.dataclass
     ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/flax/struct.py", line 101, in dataclass
    data_clz = dataclasses.dataclass(frozen=True)(clz) # type: ignore
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/dataclasses.py", line 1210, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/dataclasses.py", line 958, in _process_class
    cls_fields.append(_get_field(cls, name, type, kw_only))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/dataclasses.py", line 815, in _get_field
    raise ValueError(f'mutable default {type(f.default)} for field '
ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field reward_timestep is not allowed: use default_factory

Version information:

RobertTLange commented 1 year ago

I believe this should be fixed now on the main branch. Can you check this on your machine?

carlosgmartin commented 1 year ago

@RobertTLange Works now.
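
For anyone hitting the same traceback: since Python 3.11, `dataclasses` rejects any unhashable object as a field default, where earlier versions only rejected `list`, `dict`, and `set`. That appears to be why an array default on `reward_timestep` started failing at import time. The thread does not show what the actual fix on main was, so the following is only a minimal standard-library sketch of the pattern and of the `default_factory` workaround that the error message suggests. `FakeArray`, `define_with_plain_default`, `define_with_default_factory`, and the values `[1, 3, 10]` are hypothetical stand-ins, not gymnax or JAX code.

```python
import dataclasses
import sys


class FakeArray:
    """Hypothetical stand-in for an array type.

    Defining __eq__ without __hash__ sets __hash__ to None, so instances
    are unhashable, which is what the Python 3.11+ dataclass check looks at.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return isinstance(other, FakeArray) and self.value == other.value


def define_with_plain_default():
    # Mirrors the failing pattern: an unhashable object used directly
    # as a field default.
    @dataclasses.dataclass(frozen=True)
    class Params:
        reward_timestep: FakeArray = FakeArray([1, 3, 10])  # placeholder values

    return Params


def define_with_default_factory():
    # Workaround suggested by the error message: build the default lazily,
    # once per instance, through default_factory.
    @dataclasses.dataclass(frozen=True)
    class Params:
        reward_timestep: FakeArray = dataclasses.field(
            default_factory=lambda: FakeArray([1, 3, 10])  # placeholder values
        )

    return Params


try:
    define_with_plain_default()
    plain_default_error = None  # no error on Python 3.10 and earlier
except ValueError as err:
    plain_default_error = str(err)  # raised on Python 3.11 and later

FixedParams = define_with_default_factory()
params = FixedParams()

print(sys.version_info[:2], plain_default_error)
print(params.reward_timestep.value)
```

On Python 3.11 or newer, the first definition should raise the same kind of `ValueError` as in the traceback, while the `default_factory` version defines and instantiates cleanly on either side of that version boundary.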