Open gustavodemari opened 10 months ago
How are you, @gustavodemari?
In my opinion, this is not a bug.
See this link: `flatten_trajectories` creates `next_obs` and `dones` automatically.
In this code, which GAIL uses for training, you can see a member of the `flatten_trajectories` family called `flatten_trajectories_with_rew`.
So you just choose whether or not to use `dones` and `next_obs` when initializing `BasicRewardNet`.
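To illustrate the point about `next_obs` and `dones` being derived automatically, here is a minimal sketch of how a single trajectory can be flattened into transitions. The function `flatten_one_trajectory` and its arguments are hypothetical names chosen for this example; it is a simplified stand-in, not the imitation library's actual implementation.

```python
from typing import Any, List, Sequence, Tuple

Transition = Tuple[Any, Any, Any, bool]


def flatten_one_trajectory(
    obs: Sequence[Any], acts: Sequence[Any], terminal: bool
) -> List[Transition]:
    """Hypothetical helper: turn one trajectory into (obs, act, next_obs, done) tuples.

    Assumes the trajectory stores one more observation than actions,
    so next_obs for step t is simply obs[t + 1].
    """
    if len(obs) != len(acts) + 1:
        raise ValueError("expected len(obs) == len(acts) + 1")
    last_step = len(acts) - 1
    transitions = []
    for t, act in enumerate(acts):
        # done is True only on the final step of a trajectory that terminated.
        done = terminal and t == last_step
        transitions.append((obs[t], act, obs[t + 1], done))
    return transitions


transitions = flatten_one_trajectory(obs=[0.0, 0.1, 0.2], acts=[1, 0], terminal=True)
for transition in transitions:
    print(transition)
```

Under these assumptions, every transition carries `next_obs` and `done` whether or not the reward network ends up consuming them.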
Bug description
The `predict_processed` method of `RewardNet` only works when given the `state`, `action`, `next_state` and `done` arguments, even when the network was trained using only `state` and `action`.

For example, `BasicRewardNet` by default trains a network using only `state` and `action`, i.e., $R(s, a)$. However, `predict_processed` still requires `state`, `action`, `next_state` and `done`.

Thus, maybe `predict_processed` should make `next_state` and `done` optional (see below), and the method should check whether `next_state` and `done` are `None` and change its behavior accordingly.

Steps to reproduce
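As a rough sketch of the optional-argument behaviour suggested in the bug description, the wrapper below fills in placeholder values when `next_state` and `done` are omitted. The names `predict_with_optional_inputs` and `toy_reward` are hypothetical and exist only for this example; the wrapper is not part of the imitation API, and the placeholders are only safe under the assumption that the underlying model really ignores `next_state` and `done`.

```python
from typing import Callable, List, Optional, Sequence

RewardFn = Callable[[Sequence, Sequence, Sequence, Sequence], List[float]]


def predict_with_optional_inputs(
    reward_fn: RewardFn,
    state: Sequence,
    action: Sequence,
    next_state: Optional[Sequence] = None,
    done: Optional[Sequence] = None,
) -> List[float]:
    """Hypothetical wrapper: call a four-argument reward function with two optional inputs."""
    if len(state) != len(action):
        raise ValueError("state and action must have the same batch size")
    if next_state is None:
        # Placeholder only: valid when the model is R(s, a) and ignores next_state.
        next_state = list(state)
    if done is None:
        # Placeholder only: valid when the model ignores done.
        done = [False] * len(state)
    return reward_fn(state, action, next_state, done)


def toy_reward(state, action, next_state, done):
    """Toy R(s, a) model that ignores next_state and done."""
    return [s + a for s, a in zip(state, action)]


rewards = predict_with_optional_inputs(toy_reward, [1.0, 2.0], [0.5, 0.5])
print(rewards)
```

If the network was built to use `next_state` or `done`, silently substituting placeholders would give wrong rewards, so a real implementation would probably want to raise an error in that case instead.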
Environment
`pip freeze --all`:
Pip Freeze
absl-py==2.0.0 aiohttp==3.9.1 aiosignal==1.3.1 alembic==1.13.1 anyio==4.2.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asttokens==2.4.1 async-lru==2.0.4 async-timeout==4.0.3 attrs==23.2.0 Babel==2.14.0 backcall==0.2.0 beautifulsoup4==4.12.2 bleach==6.1.0 cachetools==5.3.2 certifi==2023.11.17 cffi==1.16.0 charset-normalizer==3.3.2 cloudpickle==3.0.0 colorama==0.4.6 colorlog==6.8.0 comm==0.2.1 contourpy==1.1.1 cycler==0.12.1 Cython==3.0.7 dataclasses==0.6 datasets==2.16.1 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 dfa==2.1.2 dill==0.3.7 docopt==0.6.2 exceptiongroup==1.2.0 execnet==2.0.2 executing==2.0.1 Farama-Notifications==0.0.4 fastjsonschema==2.19.1 filelock==3.13.1 fonttools==4.47.0 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2023.10.0 funcy==1.18 gitdb==4.0.11 GitPython==3.1.40 google-auth==2.26.1 google-auth-oauthlib==1.0.0 GPy==1.10.0 GPyOpt==1.2.6 greenlet==3.0.3 grpcio==1.60.0 gym==0.26.2 gym-notices==0.0.8 gymnasium==0.29.1 h5py==3.10.0 huggingface-hub==0.20.1 huggingface-sb3==3.0 idna==3.6 imitation==1.0.0 importlib-metadata==7.0.1 importlib-resources==6.1.1 iniconfig==2.0.0 ipykernel==6.28.0 ipython==8.12.3 isoduration==20.11.0 istype==0.2.0 jedi==0.19.1 Jinja2==3.1.2 joblib==1.3.2 json5==0.9.14 jsonpickle==3.0.2 jsonpointer==2.4 jsonschema==4.20.0 jsonschema-specifications==2023.12.1 jupyter-events==0.9.0 jupyter-lsp==2.2.1 jupyter_client==8.6.0 jupyter_core==5.7.0 jupyter_server==2.12.2 jupyter_server_terminals==0.5.1 jupyterlab==4.0.10 jupyterlab_pygments==0.3.0 jupyterlab_server==2.25.2 kiwisolver==1.4.5 lazytree==0.3.2 lenses==0.5.0 Mako==1.3.0 Markdown==3.5.1 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.7.4 matplotlib-inline==0.1.6 mdurl==0.1.2 mistune==3.0.2 mpmath==1.3.0 multidict==6.0.4 multiprocess==0.70.15 munch==4.0.0 mypy-extensions==1.0.0 nbclient==0.9.0 nbconvert==7.14.0 nbformat==5.9.2 nest-asyncio==1.5.8 networkx==3.1 notebook_shim==0.2.3 numpy==1.24.4 nvidia-cublas-cu12==12.1.3.1 
nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 optuna==3.5.0 orderedset==2.0.3 overrides==7.4.0 packaging==23.2 pandas==2.0.3 pandocfilters==1.5.0 paramz==0.9.5 parso==0.8.3 pexpect==4.9.0 pickleshare==0.7.5 pillow==10.2.0 pip==23.3.1 pkgutil_resolve_name==1.3.10 platformdirs==4.1.0 pluggy==1.3.0 probabilistic-automata==0.4.2 prometheus-client==0.19.0 prompt-toolkit==3.0.43 protobuf==4.25.1 psutil==5.9.7 ptyprocess==0.7.0 pure-eval==0.2.2 py==1.11.0 py-cpuinfo==9.0.0 py-spy==0.3.14 pyarrow==14.0.2 pyarrow-hotfix==0.6 pyasn1==0.5.1 pyasn1-modules==0.3.0 pycparser==2.21 pygame==2.5.2 Pygments==2.17.2 pyparsing==3.1.1 pyrsistent==0.20.0 pytest==7.4.4 pytest-forked==1.6.0 pytest-xdist==2.5.0 python-dateutil==2.8.2 python-json-logger==2.0.7 pytz==2023.3.post1 PyYAML==6.0.1 pyzmq==25.1.2 referencing==0.32.1 requests==2.31.0 requests-oauthlib==1.3.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.0 rpds-py==0.16.2 rsa==4.9 sacred==0.8.5 scikit-learn==1.3.2 scipy==1.10.1 seals==0.2.1 Send2Trash==1.8.2 setuptools==68.2.2 singledispatch==4.1.0 six==1.16.0 smmap==5.0.1 sniffio==1.3.0 soupsieve==2.5 SQLAlchemy==2.0.25 stable-baselines3==2.2.1 stack-data==0.6.3 structlog==23.3.0 sympy==1.12 tensorboard==2.14.0 tensorboard-data-server==0.7.2 terminado==0.18.0 threadpoolctl==3.2.0 tinycss2==1.2.1 tomli==2.0.1 torch==2.1.2 tornado==6.4 tqdm==4.66.1 traitlets==5.14.1 triton==2.1.0 types-python-dateutil==2.8.19.20240106 typing-inspect==0.5.0 typing_extensions==4.9.0 tzdata==2023.4 uri-template==1.3.0 urllib3==2.1.0 wasabi==1.1.2 wcwidth==0.2.12 webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 Werkzeug==3.0.1 wheel==0.41.2 wrapt==1.16.0 xeus-python==0.15.12 
xeus-python-shell==0.5.0 xxhash==3.4.1 yarl==1.9.4 zipp==3.17.0