RLE-Foundation / rllte

Long-Term Evolution Project of Reinforcement Learning
https://docs.rllte.dev/
MIT License

[Bug]: Impossible to load model to use it for training #49

Open edofazza opened 2 months ago

edofazza commented 2 months ago

🐛 Bug

I'm trying to load a trained model to use it for testing, but `load_state_dict` fails with the error below. Thank you.

To Reproduce

import torch as th
import os
from rllte.xplore.reward import RND, Disagreement, RIDE
from rllte.env import make_mario_env
from rllte.agent import PPO, DDPG

if __name__ == '__main__':
    n_steps: int = 2048 * 16
    device = 'cuda' if th.cuda.is_available() else 'cpu'
    envs = make_mario_env('SuperMarioBros-1-1-v0', device=device, num_envs=1,
                          asynchronous=False, frame_stack=4, gray_scale=True)
    print(device, envs.observation_space, envs.action_space)
    # create the intrinsic reward module
    #irs = Disagreement(envs, device=device)
    # create the PPO agent
    agent = PPO(envs,
                device=device,
                batch_size=512,
                n_epochs=10,
                num_steps=n_steps//8,
                pretraining=True)
    agent.policy.load_state_dict(th.load("ride_1_1_1507328.pth", map_location=th.device('cpu')))
    agent.eval(100)

Relevant log output / Error message

/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:555: UserWarning: WARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:627: UserWarning: WARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
  logger.warn(
cpu Box(0, 255, (4, 84, 84), uint8) Discrete(7)
Traceback (most recent call last):
  File "/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/tests.py", line 22, in <module>
    agent.policy.load_state_dict(th.load("/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/ride_1_1_1507328.pth",
  File "/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2103, in load_state_dict
    raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
TypeError: Expected state_dict to be dict-like, got <class 'rllte.common.utils.ExportModel'>.

System Info

No response

yuanmingqi commented 2 months ago

Please follow: https://docs.rllte.dev/tutorials/mt/quick_start/
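
For context on the traceback above: `load_state_dict` raises this `TypeError` when the object it receives is not dict-like, and here `th.load` returned an `rllte.common.utils.ExportModel` instance, so the `.pth` file appears to contain a whole pickled model rather than a state dict. The sketch below shows one way to handle both kinds of checkpoint. It is an illustration, not rllte's API: `extract_state_dict` and `FakeExportedModel` are hypothetical names, and the stand-in class only assumes that the saved object exposes a `state_dict()` method, as a `torch.nn.Module` would.

```python
from collections.abc import Mapping
from typing import Any


def extract_state_dict(checkpoint: Any) -> Mapping:
    """Return a dict-like state dict from whatever a checkpoint file contained.

    Hypothetical helper, not part of rllte or PyTorch.
    """
    # Case 1: the file already holds a state dict, which is what
    # load_state_dict expects, so pass it through unchanged.
    if isinstance(checkpoint, Mapping):
        return checkpoint
    # Case 2: the file holds a whole pickled model object; ask it for
    # its parameters instead of passing the object itself.
    state_dict_fn = getattr(checkpoint, "state_dict", None)
    if callable(state_dict_fn):
        return state_dict_fn()
    raise TypeError(f"Cannot get a state dict from {type(checkpoint).__name__}")


class FakeExportedModel:
    """Hypothetical stand-in for a pickled model such as ExportModel."""

    def state_dict(self) -> dict:
        return {"actor.weight": [0.0, 1.0]}


if __name__ == "__main__":
    # A plain dict is returned as-is.
    print(extract_state_dict({"actor.weight": [0.0]}))
    # A model-like object is converted through its state_dict() method.
    print(extract_state_dict(FakeExportedModel()))
```

With the script from the report, this would be applied to the result of `th.load("ride_1_1_1507328.pth", map_location="cpu")`. Even then, the extracted keys may not match the layout of `agent.policy`, so `load_state_dict` could still fail with a key mismatch. If the saved object is itself a model, calling it directly for inference may be the simpler route. Recent PyTorch releases also default `th.load` to `weights_only=True`, which refuses to unpickle arbitrary classes, so loading a whole-model file may require `weights_only=False` and should only be done with trusted files.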