ethanabrooks closed this issue 3 years ago.
Thanks for reporting this, Ethan.
From your description, it seems that the problem affects anyone who uses Reverb and TFP with JAX, so dm-reverb sounds like the right place to address it. I've notified the Reverb team and sent them a link to this thread.
Thanks for doing that. Should I open a new issue with them?
As I said, I've notified the team, but unfortunately the thread is internal and I can't add you to it.
If you'd like to stay in the loop, feel free to open an issue with them. In that case, please post a link to the Reverb issue here so that I can make sure the related issues are properly cross-referenced.
Hi Ethan, have you tried using the Reverb nightly? Unfortunately, we have needed some behavior that is only available in the nightlies. Otherwise we'd like to avoid them, because the nightlies tend to propagate and issues like yours crop up.
Commit https://github.com/deepmind/acme/commit/a619d1836c58b558193fcfa715fd8fb29569d5fa moved Acme to dm-reverb-nightly,
and just the other day, in https://github.com/deepmind/acme/commit/c7d3970d898c7b01be19d570a2740520984bc2e3, we pinned it to a known working combination of all the nightlies (this is not strictly necessary, but occasionally the nightlies do introduce a temporary hiccup).
OK, I was able to get things to run. I've documented the dependencies I used in this pyproject.toml:
[tool.poetry]
name = "impala"
version = "0.1.0"
description = ""
authors = ["Ethan Brooks <ethanabrooks@gmail.com>"]
[tool.poetry.dependencies]
python = ">=3.7,<4.0"
dm-reverb-nightly = "^0.3.0-alpha.20210701"
tfp-nightly = "^0.14.0-alpha.20210630"
dm-haiku = "^0.0.4"
jax = "^0.2.16"
dm-acme = {extras = ["envs", "jax"], version = "^0.2.1"}
tf-nightly-cpu = {url = "https://files.pythonhosted.org/packages/80/c3/3e276aca325a81ad61eaeaa80e74565df035f3c8aa462f758a99a2bbc6d7/tf_nightly_cpu-2.7.0.dev20210701-cp38-cp38-manylinux2010_x86_64.whl"}
jaxlib = {url = "https://storage.googleapis.com/jax-releases/cuda112/jaxlib-0.1.65+cuda112-cp38-none-manylinux2010_x86_64.whl"}
optax = "^0.0.8"
pytz = "^2021.1"
dm-sonnet = "^2.0.0"
trfl = "^1.1.0"
[tool.poetry.dev-dependencies]
ipython = "^7.25.0"
ipdb = "^0.13.9"
black = "^21.6b0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
And this poetry.lock file.
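To confirm that an environment actually matches a pinned combination like this one, a small check built on the standard library's importlib.metadata may help. This is only a sketch and is not part of Acme or Poetry: `installed_version` and `pin_mismatches` are hypothetical helpers. `PINS` holds only the two exact versions that appear verbatim in the wheel filenames of the pyproject.toml above, since the caret constraints on the other packages are ranges, not exact pins.

```python
"""Sketch: report pinned distributions whose installed version differs."""
from importlib import metadata
from typing import Callable, Dict, Optional

# Hypothetical pin set: the two exact versions taken from the wheel
# filenames in the pyproject.toml in this thread.
PINS: Dict[str, str] = {
    "tf-nightly-cpu": "2.7.0.dev20210701",
    "jaxlib": "0.1.65+cuda112",
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of `dist`, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def pin_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each pinned distribution whose version differs to the version found.

    A value of None means the distribution is not installed at all.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for dist, wanted in pins.items():
        found = lookup(dist)
        if found != wanted:
            mismatches[dist] = found
    return mismatches


if __name__ == "__main__":
    for dist, found in pin_mismatches(PINS).items():
        print(f"{dist}: expected {PINS[dist]}, found {found or 'not installed'}")
```

The `lookup` parameter exists so the comparison can be exercised without touching the real environment, for example by passing a plain dict's `.get` method.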
Hi,
I am attempting to run your jax implementation of IMPALA using this script. I noticed that dm-reverb depends on tensorflow (>=2.5.0,<2.6.0). Meanwhile, acme/jax/networks/distributional.py depends on tensorflow-probability. However, I have found that, per this issue, jax only gets along with tfp-nightly; otherwise I get the error reported in that issue. On the other hand, using tfp-nightly without tf-nightly results in the following error:
Looking at the TensorFlow source code, it is clear that linear_operator.py was updated in a recent TensorFlow release. However, dm-reverb prevents me from updating to tf-nightly because of its tensorflow (>=2.5.0,<2.6.0) requirement.
Thank you for your help!
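The conflicts described in this report can be summarized as a few rules over the set of installed distributions. The sketch below encodes only those rules as reported in this thread; `find_conflicts` is a hypothetical helper, and the rules are not an exhaustive compatibility matrix for Reverb, TFP, TensorFlow or JAX.

```python
"""Sketch: flag the package combinations this thread reports as broken."""
from typing import List, Mapping

# Distribution names that count as a TensorFlow nightly build (assumption:
# the thread itself only uses tf-nightly-cpu).
TF_NIGHTLY = {"tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu"}


def find_conflicts(installed: Mapping[str, str]) -> List[str]:
    """Return a description of each reported conflict found in `installed`.

    `installed` maps distribution names to version strings; only the names
    are inspected. Names are lowercased and underscores become hyphens.
    """
    names = {name.lower().replace("_", "-") for name in installed}
    has_tf_nightly = bool(names & TF_NIGHTLY)
    problems: List[str] = []
    if "tensorflow-probability" in names and "jax" in names:
        problems.append(
            "stable tensorflow-probability with jax: reported to work only with tfp-nightly"
        )
    if "tfp-nightly" in names and not has_tf_nightly:
        problems.append(
            "tfp-nightly without a tf-nightly build: linear_operator.py mismatch reported"
        )
    if "dm-reverb" in names and has_tf_nightly:
        problems.append(
            "stable dm-reverb requires tensorflow>=2.5.0,<2.6.0 and cannot be "
            "combined with tf-nightly; dm-reverb-nightly was suggested instead"
        )
    return problems
```

To run it against the current environment, one option is `find_conflicts({d.metadata["Name"]: d.version for d in importlib.metadata.distributions()})`. With the combination from the pyproject.toml in this thread (dm-reverb-nightly, tfp-nightly, tf-nightly-cpu and jax) it returns an empty list.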