RLE-Foundation / RLeXplore

RLeXplore provides stable baselines of exploration methods in reinforcement learning, such as intrinsic curiosity module (ICM), random network distillation (RND) and rewarding impact-driven exploration (RIDE).
https://docs.rllte.dev/
MIT License

RLeXplore with Stable Baselines3 example issue #16

Open edofazza opened 4 months ago

edofazza commented 4 months ago

When I run the following code:

import torch as th

from rllte.xplore.reward import RND
from rllte.env import make_mario_env
from rllte.agent import PPO, DDPG

if __name__ == '__main__':
    n_steps: int = 2048 * 16
    device = 'cuda' if th.cuda.is_available() else 'cpu'
    envs = make_mario_env('SuperMarioBros-1-1-v0', device=device, num_envs=1,
                          asynchronous=False, frame_stack=4, gray_scale=True)
    print(device, envs.observation_space, envs.action_space)
    # create the intrinsic reward module
    irs = RND(envs, device=device)
    # create the PPO agent
    agent = PPO(envs, device=device)
    # set the intrinsic reward module
    agent.set(reward=irs)
    # train the agent
    agent.train(n_steps * 153, eval_interval=n_steps // 8, save_interval=n_steps)

I receive the following error:

/opt/conda/lib/python3.10/site-packages/gym/envs/registration.py:555: UserWarning: WARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.
  logger.warn(
/opt/conda/lib/python3.10/site-packages/gym/envs/registration.py:627: UserWarning: WARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']
  logger.warn(
/opt/conda/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.
  logger.warn(
/opt/conda/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
  logger.warn(
/opt/conda/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
  logger.warn(
cuda Box(0, 255, (4, 84, 84), uint8) Discrete(7)
/opt/conda/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
[05/24/2024 04:14:52 PM] - [INFO.] - Invoking RLLTE Engine...
[05/24/2024 04:14:52 PM] - [INFO.] - ================================================================================
[05/24/2024 04:14:52 PM] - [INFO.] - Tag               : default
[05/24/2024 04:14:52 PM] - [INFO.] - Device            : NVIDIA A100-SXM4-40GB
[05/24/2024 04:14:52 PM] - [DEBUG] - Agent             : PPO
[05/24/2024 04:14:52 PM] - [DEBUG] - Encoder           : MnihCnnEncoder
[05/24/2024 04:14:52 PM] - [DEBUG] - Policy            : OnPolicySharedActorCritic
[05/24/2024 04:14:52 PM] - [DEBUG] - Storage           : VanillaRolloutStorage
[05/24/2024 04:14:52 PM] - [DEBUG] - Distribution      : Categorical
[05/24/2024 04:14:52 PM] - [DEBUG] - Augmentation      : None
[05/24/2024 04:14:52 PM] - [DEBUG] - Intrinsic Reward  : RND
[05/24/2024 04:14:52 PM] - [DEBUG] - ================================================================================
Traceback (most recent call last):
  File "/workdir/got-it-memorized/src/run_rnd2.py", line 20, in <module>
    agent.train(n_steps * 153, eval_interval=n_steps // 8, save_interval=n_steps)
  File "/opt/conda/lib/python3.10/site-packages/rllte/common/prototype/on_policy_agent.py", line 105, in train
    obs, infos = self.env.reset(seed=self.seed)
  File "/opt/conda/lib/python3.10/site-packages/rllte/env/utils.py", line 152, in reset
    obs, infos = self.env.reset(seed=seed, options=options)
  File "/opt/conda/lib/python3.10/site-packages/gymnasium/wrappers/record_episode_statistics.py", line 78, in reset
    obs, info = super().reset(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/gymnasium/core.py", line 467, in reset
    return self.env.reset(seed=seed, options=options)
  File "/opt/conda/lib/python3.10/site-packages/gymnasium/vector/vector_env.py", line 140, in reset
    return self.reset_wait(seed=seed, options=options)
  File "/opt/conda/lib/python3.10/site-packages/gymnasium/vector/sync_vector_env.py", line 122, in reset_wait
    observation, info = env.reset(**kwargs)
ValueError: too many values to unpack (expected 2)

yuanmingqi commented 4 months ago

Sorry, I forgot to update the rllte-core package. You can install the new version with:

pip install rllte-core --upgrade

edofazza commented 4 months ago

I upgraded the package, but I still face the same issue. It works when I use RIDE, but with RND and Disagreement the error is still present.

yuanmingqi commented 4 months ago

It looks like a problem with gymnasium. Can you provide the version numbers of your gymnasium and gym packages?
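A minimal sketch of one way to collect just the relevant versions, rather than a full package listing. The package names are the ones that appear in this thread; the helper name `installed_versions` is hypothetical, and only the standard-library `importlib.metadata` module is used.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from this thread; adjust as needed.
PACKAGES = ["gym", "gymnasium", "rllte-core", "nes-py", "gym-super-mario-bros"]


def installed_versions(names):
    """Map each distribution name to its installed version string.

    Packages that are not installed are reported as "not installed"
    instead of raising an exception.
    """
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:<24}{ver}")
```

Running the script prints one line per package, which is usually enough to spot a gym/gymnasium mismatch.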

edofazza commented 4 months ago

Here:

absl-py                    2.1.0
attrs                      23.2.0
AutoROM                    0.4.2
AutoROM.accept-rom-license 0.6.1
bitmath                    1.3.3.1
Brotli                     1.0.9
certifi                    2024.2.2
cffi                       1.16.0
chardet                    4.0.0
charset-normalizer         2.0.4
click                      7.1.2
cloudpickle                3.0.0
configobj                  5.0.8
contourpy                  1.2.0
cryptography               41.0.7
cycler                     0.12.1
decorator                  4.4.2
deprecation                2.1.0
dill                       0.3.8
docker-pycreds             0.4.0
dulwich                    0.22.1
easydict                   1.9
enum-tools                 0.12.0
everett                    3.1.0
Farama-Notifications       0.0.4
filelock                   3.13.1
Flask                      1.1.4
fonttools                  4.47.2
fsspec                     2023.10.0
gitdb                      4.0.11
GitPython                  3.1.43
gmpy2                      2.1.2
graphviz                   0.20.3
grpcio                     1.64.0
gym                        0.26.2
gym-notices                0.0.8
gym-super-mario-bros       7.4.0
gymnasium                  0.29.1
h5py                       3.11.0
hbutils                    0.9.3
hickle                     5.0.3
huggingface-hub            0.14.1
idna                       3.4
imageio                    2.34.1
imageio-ffmpeg             0.4.9
itsdangerous               1.1.0
Jinja2                     2.11.3
joblib                     1.4.2
jsonschema                 4.22.0
jsonschema-specifications  2023.12.1
kiwisolver                 1.4.5
lz4                        4.3.3
Markdown                   3.6
markdown-it-py             3.0.0
MarkupSafe                 2.0.1
matplotlib                 3.6.0
mdurl                      0.1.2
moviepy                    1.0.3
mpire                      2.10.2
mpmath                     1.3.0
nes-py                     8.2.1
networkx                   3.1
numpy                      1.26.3
opencv-python              4.9.0.80
opencv-python-headless     4.9.0.80
packaging                  23.2
pandas                     2.1.4
Pillow                     10.0.1
pip                        23.3.1
platformdirs               4.2.2
proglog                    0.1.10
protobuf                   3.20.3
psutil                     5.9.8
pycparser                  2.21
pyglet                     1.5.21
Pygments                   2.18.0
pynng                      0.8.0
pynvml                     11.5.0
pyOpenSSL                  23.2.0
pyparsing                  3.1.1
PySocks                    1.7.1
python-box                 6.1.0
python-dateutil            2.8.2
pytimeparse                1.1.8
pytz                       2023.3.post1
PyYAML                     6.0.1
redis                      5.0.4
referencing                0.35.1
requests                   2.31.0
requests-toolbelt          1.0.0
responses                  0.12.1
rich                       13.7.1
rllte-core                 1.0.0
rpds-py                    0.18.1
sb3-contrib                2.2.1
scikit-learn               1.5.0
scipy                      1.13.0
seaborn                    0.12.2
semantic-version           2.10.0
sentry-sdk                 2.2.0
setproctitle               1.3.3
setuptools                 66.1.1
Shimmy                     1.3.0
simplejson                 3.19.2
six                        1.16.0
smmap                      5.0.1
sniffio                    1.3.1
stable-baselines3          2.2.1
sympy                      1.12
tabulate                   0.9.0
tensorboard                2.16.2
tensorboard-data-server    0.7.2
tensorboardX               2.2
termcolor                  2.4.0
threadpoolctl              3.5.0
torch                      2.1.0.post100
torch-summary              1.4.5
torchvision                0.15.2a0
tqdm                       4.66.1
treevalue                  1.4.12
trueskill                  0.4.5
typing_extensions          4.7.1
tzdata                     2023.4
urllib3                    1.26.18
URLObject                  2.4.3
wandb                      0.17.0
websocket-client           1.8.0
Werkzeug                   1.0.1
wheel                      0.41.2
wrapt                      1.16.0
wurlitzer                  3.1.0
yapf                       0.29.0
yattag                     1.15.2

yuanmingqi commented 4 months ago

Okay, give me a moment. I will check now.

yuanmingqi commented 4 months ago

It's a problem with the SuperMarioBros env; the same code works with the Atari envs.

Could you please use other envs for now? We'll fix this ASAP.
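The thread does not state the root cause. One plausible reading of the traceback, offered here only as an assumption, is that some env in the chain follows the older Gym convention where `reset()` returns the observation alone, while the caller expects the `(obs, info)` pair. The sketch below reproduces that mismatch with hypothetical stand-in classes (`LegacyResetEnv`, `ResetCompat`); it is not the fix that was shipped in rllte-core.

```python
class LegacyResetEnv:
    """Hypothetical env whose reset() returns only the observation."""

    def reset(self):
        # Four stacked 84x84 frames, mimicking the (4, 84, 84) shape in the log.
        return [[[0] * 84 for _ in range(84)] for _ in range(4)]


class ResetCompat:
    """Hypothetical adapter that always returns an (obs, info) pair."""

    def __init__(self, env):
        self.env = env

    def reset(self, seed=None, options=None):
        # seed and options are accepted for signature compatibility only;
        # this sketch does not forward them to the wrapped env.
        result = self.env.reset()
        # Pass through results that already look like (obs, info).
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result
        # Otherwise treat the result as a bare observation.
        return result, {}


def try_unpack(env):
    """Return "ok" if reset() unpacks into two values, else the error text."""
    try:
        obs, info = env.reset()
        return "ok"
    except ValueError as exc:
        return str(exc)


if __name__ == "__main__":
    print(try_unpack(LegacyResetEnv()))               # unpacking 4 frames into 2 names fails
    print(try_unpack(ResetCompat(LegacyResetEnv())))  # adapter supplies the missing info dict
```

Unpacking the bare four-frame observation into two names raises the same kind of `ValueError` seen in the traceback, which is why a reset-signature mismatch is a reasonable first thing to check.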

edofazza commented 4 months ago

I tried with the Atari env just now and it works, but I am interested in the SMB env for research I am doing.

yuanmingqi commented 4 months ago

The issue has been fixed. Please run pip install rllte-core --upgrade

edofazza commented 4 months ago

Now it works, perfect! Thank you very much!