google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

Dependency trouble #317

Open Timisorean opened 5 months ago

Timisorean commented 5 months ago

Greetings! I am having a lot of trouble installing acme because of dependency mismatches.

When I install the latest pip package with pip install dm-acme[jax,tf], I get:

Traceback (most recent call last):
  File "...", line 4, in <module>
    from acme.agents.jax import dqn
  File ".../python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
    from acme.agents.jax.dqn.actor import behavior_policy
  File ".../python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
    from acme.agents.jax import actor_core as actor_core_lib
  File ".../python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
    from acme.jax import networks as networks_lib
  File ".../python3.9/site-packages/acme/jax/networks/__init__.py", line 45, in <module>
    from acme.jax.networks.multiplexers import CriticMultiplexer
  File ".../python3.9/site-packages/acme/jax/networks/multiplexers.py", line 20, in <module>
    from acme.jax import utils
  File ".../python3.9/site-packages/acme/jax/utils.py", line 190, in <module>
    devices: Optional[Sequence[jax.xla.Device]] = None,
  File ".../python3.9/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'

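This happens when trying to use the JAX DQN agent (acme.agents.jax.dqn). Apparently the installed jax/jaxlib version is too new for acme, whose acme/jax/utils.py still references jax.xla.Device. I saw that you have already pinned an older version in commit 7560b96543eff8f5e04d8a57dcca8545dd17f0ac. But if I try to install acme directly from git, I get: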

ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.3 (from dm-acme[jax]) (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26)
ERROR: No matching distribution found for jaxlib==0.4.3

And indeed, jaxlib 0.4.3 seems to be missing from PyPI (https://pypi.org/project/jaxlib/#history). Even when I install this specific version directly from git, I still get errors:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.86 requires jax>=0.4.16, but you have jax 0.4.3 which is incompatible.
flax 0.8.2 requires jax>=0.4.19, but you have jax 0.4.3 which is incompatible.
orbax-checkpoint 0.5.9 requires jax>=0.4.9, but you have jax 0.4.3 which is incompatible.
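The conflicts above can be checked mechanically. A minimal sketch, with hypothetical helper names and version numbers copied from the pip messages in this issue: it takes the tightest jax lower bound reported by chex, flax and orbax-checkpoint and lists which of the jaxlib releases pip offered satisfy it. It assumes jaxlib is kept at the same version as jax, which is an illustrative assumption rather than something stated here.

```python
from typing import Dict, List, Tuple

# jaxlib releases pip offered in the "Could not find a version" error above.
AVAILABLE_JAXLIB: List[str] = [
    "0.4.6", "0.4.7", "0.4.9", "0.4.10", "0.4.11", "0.4.12", "0.4.13",
    "0.4.14", "0.4.16", "0.4.17", "0.4.18", "0.4.19", "0.4.20", "0.4.21",
    "0.4.22", "0.4.23", "0.4.24", "0.4.25", "0.4.26",
]

# Minimum jax versions from pip's dependency-conflict report above.
JAX_LOWER_BOUNDS: Dict[str, str] = {
    "chex 0.1.86": "0.4.16",
    "flax 0.8.2": "0.4.19",
    "orbax-checkpoint 0.5.9": "0.4.9",
}


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn a plain dotted release such as '0.4.19' into (0, 4, 19).

    Only purely numeric versions are handled, which covers this issue.
    """
    return tuple(int(part) for part in version.split("."))


def minimum_jax(bounds: Dict[str, str]) -> str:
    """The tightest of several '>=' bounds is the largest one."""
    return max(bounds.values(), key=parse_version)


def satisfying(available: List[str], minimum: str) -> List[str]:
    """Keep the releases that are >= minimum, in their original order."""
    floor = parse_version(minimum)
    return [v for v in available if parse_version(v) >= floor]


if __name__ == "__main__":
    floor = minimum_jax(JAX_LOWER_BOUNDS)
    print("jax must be >=", floor)
    print("0.4.3 pin satisfies it:", parse_version("0.4.3") >= parse_version(floor))
    # Assumption: jaxlib is installed at the same version as jax.
    print("candidate jaxlib releases:", satisfying(AVAILABLE_JAXLIB, floor))
```

Versions are compared as integer tuples because plain string comparison would rank "0.4.3" above "0.4.19". With the numbers in this issue the floor is 0.4.19 (from flax), 0.4.3 falls below it, and the jaxlib 0.4.20 pin used in the reply is one of the remaining candidates.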

It would be great if someone could share a pip freeze or some other way to get acme working. Thanks!
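For the version information requested above, pip freeze is the most complete answer, but a narrower report is easier to paste into an issue. A minimal sketch, assuming the distributions in PACKAGES are the relevant ones (the list and helper names are illustrative, not part of acme): it prints each package's installed version and whether the installed jax still exposes the xla attribute that acme/jax/utils.py fails on in the traceback above.

```python
import importlib
import importlib.metadata
from typing import Dict, Iterable, Optional

# Illustrative selection of PyPI distribution names relevant to this issue.
PACKAGES = ("dm-acme", "jax", "jaxlib", "chex", "flax", "orbax-checkpoint")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def module_has_attr(module_name: str, attr: str) -> bool:
    """Return True if module_name can be imported and exposes attr."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    # hasattr treats an AttributeError from a module __getattr__ as "missing".
    return hasattr(module, attr)


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    # acme/jax/utils.py annotates a parameter with jax.xla.Device.
    print("jax exposes jax.xla:", module_has_attr("jax", "xla"))
```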

HarrisonFah commented 5 months ago

I was having similar issues trying to run their tutorial in Google Colab, but I finally got it working by using:

pip install jaxlib==0.4.20
pip install git+https://github.com/deepmind/acme.git#egg=dm-acme[tf,envs] --use-deprecated=legacy-resolver
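When the install has to be scripted, the two commands from this reply can be run through the current interpreter's pip, in the same order, so they target the same environment. A sketch under that assumption: the helper names are hypothetical, the jaxlib version and flags are copied from the reply, and newer pip releases may no longer ship the legacy resolver. With dry_run=True the commands are only printed.

```python
import subprocess
import sys
from typing import List

# acme requirement spec copied from the reply above.
ACME_SPEC = "git+https://github.com/deepmind/acme.git#egg=dm-acme[tf,envs]"


def workaround_commands(jaxlib_version: str = "0.4.20") -> List[List[str]]:
    """Build the two pip invocations: pin jaxlib first, then install acme."""
    pip = [sys.executable, "-m", "pip", "install"]
    return [
        pip + [f"jaxlib=={jaxlib_version}"],
        pip + [ACME_SPEC, "--use-deprecated=legacy-resolver"],
    ]


def run(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command and, unless dry_run, execute it (stopping on failure)."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # No shell is involved, so the [tf,envs] brackets need no quoting.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run(workaround_commands(), dry_run=True)
```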