vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev
Other
5.26k stars 602 forks

DDPG JAX breaks with python ~3.7 #309

Closed: vwxyzjn closed this issue 9 months ago

vwxyzjn commented 1 year ago

Problem Description

Running poetry run python cleanrl/ddpg_continuous_action_jax.py with Python ~3.7 fails with:

Traceback (most recent call last):
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/flax/linen/module.py", line 375, in wrapped
    hash_value = hash_fn(self)
  File "<string>", line 3, in __hash__
TypeError: unhashable type: 'numpy.ndarray'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "cleanrl/ddpg_continuous_action_jax.py", line 228, in <module>
    actions = actor.apply(actor_state.params, obs)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/_src/api.py", line 531, in cache_miss
    donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/core.py", line 1963, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/core.py", line 1979, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/core.py", line 689, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/_src/dispatch.py", line 234, in _xla_call_impl
    keep_unused, *arg_specs)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/jax/linear_util.py", line 282, in memoized_fun
    cache = fun_caches.setdefault(fun.f, {})
  File "/home/costa/.pyenv/versions/3.7.8/lib/python3.7/weakref.py", line 489, in setdefault
    return self.data.setdefault(ref(key, self._remove),default)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.7/lib/python3.7/site-packages/flax/linen/module.py", line 379, in wrapped
    f'Module={self}') from exc
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Failed to hash Flax Module.  The module probably contains unhashable attributes.  Module=Actor(
    # attributes
    action_dim = 6
    action_scale = array([[1., 1., 1., 1., 1., 1.]], dtype=float32)
    action_bias = array([[0., 0., 0., 0., 0., 0.]], dtype=float32)
)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "cleanrl/ddpg_continuous_action_jax.py", line 228, in <module>
    actions = actor.apply(actor_state.params, obs)
  File "/home/costa/.pyenv/versions/3.7.8/lib/python3.7/weakref.py", line 489, in setdefault
    return self.data.setdefault(ref(key, self._remove),default)
TypeError: Failed to hash Flax Module.  The module probably contains unhashable attributes.  Module=Actor(
    # attributes
    action_dim = 6
    action_scale = array([[1., 1., 1., 1., 1., 1.]], dtype=float32)
    action_bias = array([[0., 0., 0., 0., 0., 0.]], dtype=float32)
)
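The root TypeError can be reproduced without Flax or JAX. The sketch below assumes, based on the "File "<string>", line 3, in __hash__" frame, that the Actor's hash is a dataclass-generated __hash__ that hashes a tuple of all its fields, so any numpy array attribute (here action_scale and action_bias) makes hashing fail. ArrayActorConfig, TupleActorConfig and is_hashable are hypothetical names, and frozen=True is only a stand-in for however Flax actually builds its dataclass. Storing the values as tuples and rebuilding an array at call time is one possible workaround, not necessarily what cleanrl adopted; the thread itself resolved the issue by dropping Python 3.7.

```python
import dataclasses

import numpy as np


# Hypothetical stand-in for the Actor's attributes. With frozen=True (and the
# default eq=True) the dataclass generates __hash__ from a tuple of all fields.
@dataclasses.dataclass(frozen=True)
class ArrayActorConfig:
    action_dim: int
    action_scale: np.ndarray  # ndarray is unhashable, so hash(self) fails
    action_bias: np.ndarray


# Hypothetical variant holding the same values as tuples of Python floats.
@dataclasses.dataclass(frozen=True)
class TupleActorConfig:
    action_dim: int
    action_scale: tuple
    action_bias: tuple

    def scale_array(self):
        # Rebuild an array only where it is needed (e.g. inside __call__).
        return np.asarray(self.action_scale, dtype=np.float32)


def is_hashable(obj):
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


scale = np.ones(6, dtype=np.float32)
bias = np.zeros(6, dtype=np.float32)

array_cfg = ArrayActorConfig(6, scale, bias)
tuple_cfg = TupleActorConfig(6, tuple(scale.tolist()), tuple(bias.tolist()))

print(is_hashable(array_cfg))  # False: hashing reaches the ndarray fields
print(is_hashable(tuple_cfg))  # True
```

Tuples of floats compare and hash by value, so two configs built from the same action space are equal and share a hash, which is what a hash-keyed cache expects.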

Should we drop Python 3.7 support?


joaogui1 commented 1 year ago

I'm in favor of dropping it.

vwxyzjn commented 9 months ago

Dropped Python 3.7 support.