christianjcc opened this issue 6 months ago (status: Open)
According to https://imitation.readthedocs.io/en/latest/, imitation is built on and compatible with Stable Baselines 3 (SB3). Would it have trouble running with Ray RLlib, or are there any plans to support imitation with RLlib?
Bug description
I was testing the imitation library with a custom Gymnasium environment and ran into a shortcoming in `imitation/util/util.py` that produces the error below:
```
Traceback (most recent call last):
  File "/home/anaconda3/envs/env/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 802, in make
    env = env_creator(**env_spec_kwargs)
TypeError: __init__() missing 1 required positional argument: 'config'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "runner.py", line 118, in <module>
    main()
  File "runner.py", line 58, in main
    env = make_vec_env(
  File "/home/anaconda3/envs/educagent_libtraffic/lib/python3.8/site-packages/imitation/util/util.py", line 117, in make_vec_env
    tmp_env = gym.make(env_name)
  File "/home/anaconda3/envs/env/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 814, in make
    raise type(e)(
TypeError: __init__() missing 1 required positional argument: 'config' was raised from the environment creator for custom/mycustom-env with kwargs ({})
```
Below is an example of how I define and register the custom Gym environment and pass the config dictionary to the env instance:

```python
import gymnasium as gym

# Define your config dictionary
config = {
    "parameters": "example",
    "render_mode": "human",
}

env_name = "custom/mycustom-env"
gym.register(
    id=env_name,
    entry_point=MyCustomEnv,  # custom env class; its __init__ requires `config`
    max_episode_steps=500,
)
```
I can get it to work with the way the imitation library currently calls `gym.make` by doing the following instead:
Define a factory function that returns the custom environment instantiated with the desired configuration.
This occurs because `make_vec_env` creates its temporary environment without forwarding `env_make_kwargs`:

`tmp_env = gym.make(env_name)`

as implemented here: https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/util/util.py#L117

instead of:

`tmp_env = gym.make(env_name, **env_make_kwargs)`

which would pass the kwargs along as described in https://gymnasium.farama.org/api/registry/#gymnasium.make.
Implementing this change would let users skip defining a factory function and simplify setup.
Environment
Output of `pip freeze --all`:

```
absl-py==0.15.0 aiohttp==3.9.5 aiosignal==1.2.0 ale-py==0.8.1 alembic==1.13.1 anyio==3.7.1 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 asttokens==2.4.1 astunparse==1.6.3 async-generator==1.10 async-timeout==4.0.3 attrs==22.2.0 AutoROM==0.4.2 AutoROM.accept-rom-license==0.6.1 backcall==0.2.0 beautifulsoup4==4.12.3 black==22.6.0 bleach==4.1.0 bottle==0.12.25 cached-property==1.5.2 cachetools==4.2.4 certifi==2021.5.30 cffi==1.15.1 cfgv==3.4.0 charset-normalizer==2.0.12 clang==5.0 click==8.0.4 cloudpickle==2.2.1 codecov==2.1.13 codespell==2.1.0 colorama==0.4.5 colorlog==6.8.2 comm==0.1.4 commonmark==0.9.1 conan==1.60.1 coverage==6.4.4 cycler==0.11.0 darglint==1.8.1 datasets==2.19.1 debugpy==1.8.1 decorator==4.4.2 defusedxml==0.7.1 Deprecated==1.2.13 dill==0.3.8 distlib==0.3.7 distro==1.8.0 dm-tree==0.1.8 docker-pycreds==0.4.0 docopt==0.6.2 docstring-parser==0.13 entrypoints==0.4 exceptiongroup==1.2.1 execnet==2.1.1 executing==2.0.1 Farama-Notifications==0.0.4 fasteners==0.19 fastjsonschema==2.19.1 filelock==3.7.1 fire==0.4.0 flake8==4.0.1 flake8-blind-except==0.2.1 flake8-builtins==1.5.3 flake8-commas==2.1.0 flake8-debugger==4.1.2 flake8-docstrings==1.6.0 flake8-isort==4.1.2.post0 flatbuffers==1.12 fonttools==4.34.4 frozenlist==1.2.0 fsspec==2024.3.1 gast==0.3.3 gitdb==4.0.11 GitPython==3.1.43 google-api-python-client==1.12.8 google-auth==2.29.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==0.4.6 google-crc32c==1.3.0 google-pasta==0.2.0 greenlet==3.0.3 grpcio==1.43.0 gym==0.26.2 gym-notices==0.0.8 gymnasium==0.29.1 gymnasium-notices==0.0.1 h5py==2.10.0 highway-env==1.8.2 httplib2==0.20.2 huggingface-hub==0.23.0 huggingface-sb3==3.0 hypothesis==6.54.6 identify==2.5.36 idna==3.4 imageio==2.15.0 imageio-ffmpeg==0.4.9 imitation==1.0.0 importlab==0.8.1 importlib-resources==5.4.0 importlib_metadata==7.1.0 iniconfig==2.0.0 ipykernel==6.15.3 ipython==8.12.3 ipython-genutils==0.2.0 ipywidgets==7.8.1 isort==5.13.2 jedi==0.17.2 Jinja2==3.1.4 joblib==1.4.2
jsonpickle==3.0.4 jsonschema==3.2.0 jupyter==1.0.0 jupyter-client==6.1.12 jupyter-console==6.4.2 jupyter-server==1.24.0 jupyter-server-mathjax==0.2.6 jupyter_core==5.7.2 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.1.7 keras==2.6.0 Keras-Preprocessing==1.1.2 kfp==1.8.9 kfp-pipeline-spec==0.1.13 kfp-server-api==1.7.1 kiwisolver==1.3.1 kubernetes==18.20.0 libcst==1.1.0 lz4==3.1.10 Mako==1.3.3 Markdown==3.3.7 MarkupSafe==2.0.1 matplotlib==3.3.4 matplotlib-inline==0.1.7 mccabe==0.6.1 memory-profiler==0.61.0 mistune==3.0.2 moviepy==1.0.3 mpmath==1.3.0 msgpack==1.0.5 multidict==6.0.5 multiprocess==0.70.16 munch==4.0.0 mypy==0.991 mypy-extensions==1.0.0 nbclient==0.5.13 nbconvert==7.16.4 nbdime==4.0.1 nbformat==5.10.4 nest-asyncio==1.5.8 networkx==2.5.1 ninja==1.11.1.1 node-semver==0.6.1 nodeenv==1.8.0 notebook==6.4.10 numpy==1.21.0 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 opencv-python==4.9.0.80 opt-einsum==3.3.0 optuna==3.6.1 packaging==21.3 pandas==1.4.4 pandocfilters==1.5.0 parso==0.7.1 patch-ng==1.17.4 pathspec==0.9.0 pathtools==0.1.2 pexpect==4.8.0 pickleshare==0.7.5 Pillow==8.4.0 pip==24.0 platformdirs==2.6.2 plotly==5.18.0 pluggy==1.5.0 pluginbase==1.0.1 pre-commit==3.5.0 proglog==0.1.10 prometheus-client==0.17.1 promise==2.3 prompt-toolkit==3.0.36 protobuf==3.20.3 psutil==5.9.8 ptyprocess==0.7.0 pure-eval==0.2.2 py==1.11.0 py-cpuinfo==9.0.0 pyarrow==16.0.0 pyarrow-hotfix==0.6 pyasn1==0.5.0 pyasn1-modules==0.3.0 pycnite==2023.10.11 pycocotools==2.0.4 pycodestyle==2.8.0 pycparser==2.21 pydantic==1.8.2 pydocstyle==6.3.0 pydot==2.0.0 pyflakes==2.4.0 pygame==2.5.2 Pygments==2.14.0 PyJWT==2.4.0 pyparsing==3.1.1 pypiserver==2.0.1
pyrsistent==0.18.0 pytest==7.1.3 pytest-cov==3.0.0 pytest-forked==1.6.0 pytest-timeout==2.1.0 pytest-xdist==2.5.0 pytest_notebook==0.8.0 python-dateutil==2.8.2 pytype==2023.9.27 pytz==2023.3.post1 PyWavelets==1.1.1 PyYAML==6.0 pyzmq==25.1.1 qtconsole==5.2.2 QtPy==2.0.1 ray==2.0.1 requests==2.27.1 requests-oauthlib==1.3.1 requests-toolbelt==0.9.1 rich==12.6.0 rsa==4.9 sacred==0.8.5 scikit-image==0.17.2 scikit-learn==1.3.2 scipy==1.9.3 seals==0.2.1 Send2Trash==1.8.2 sentry-sdk==2.1.1 setproctitle==1.3.3 setuptools==69.5.1 setuptools-scm==7.0.5 Shimmy==0.2.1 shortuuid==1.0.13 six==1.15.0 smmap==5.0.1 sniffio==1.3.1 snowballstemmer==2.2.0 sortedcontainers==2.4.0 soupsieve==2.5 SQLAlchemy==2.0.30 stable_baselines3==2.3.2 stack-data==0.6.3 strip-hints==0.1.10 sympy==1.12 tabulate==0.8.10 tenacity==8.3.0 tensorboard==2.11.2 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorboardX==2.6.2.2 tensorflow==2.2.0 tensorflow-estimator==2.2.0 termcolor==1.1.0 terminado==0.12.1 terminaltables==3.1.10 testpath==0.6.0 threadpoolctl==3.5.0 tifffile==2020.9.3 tinycss2==1.3.0 tokenize-rt==5.2.0 toml==0.10.2 tomli==1.2.3 torch==2.3.0 tornado==6.1 tqdm==4.64.1 traitlets==5.14.3 triton==2.3.0 typed-ast==1.5.5 typer==0.9.0 typing-inspect==0.9.0 typing_extensions==4.11.0 uritemplate==3.0.1 urllib3==1.26.17 validators==0.20.0 virtualenv==20.17.1 wandb==0.12.21 wasabi==1.1.2 wcwidth==0.2.9 webencodings==0.5.1 websocket-client==1.2.1 Werkzeug==2.0.3 wheel==0.43.0 widgetsnbextension==3.6.6 wrapt==1.12.1 xxhash==3.4.1 yarl==1.9.4 zipp==3.6.0
```