I'm trying to load a trained model to use it for testing, but I am facing with an error.
Thank you.
To Reproduce
import torch as th
import os
from rllte.xplore.reward import RND, Disagreement, RIDE
from rllte.env import make_mario_env
from rllte.agent import PPO, DDPG
if __name__ == '__main__':
n_steps: int = 2048 * 16
device = 'cuda' if th.cuda.is_available() else 'cpu'
envs = make_mario_env('SuperMarioBros-1-1-v0', device=device, num_envs=1,
asynchronous=False, frame_stack=4, gray_scale=True)
print(device, envs.observation_space, envs.action_space)
# create the intrinsic reward module
#irs = Disagreement(envs, device=device)
# create the PPO agent
agent = PPO(envs,
device=device,
batch_size=512,
n_epochs=10,
num_steps=n_steps//8,
pretraining=True)
agent.policy.load_state_dict(th.load("ride_1_1_1507328.pth", map_location=th.device('cpu')),)
agent.eval(100)
Relevant log output / Error message
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:555: UserWarning: WARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.
logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:627: UserWarning: WARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']
logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.
logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
logger.warn(
cpu Box(0, 255, (4, 84, 84), uint8) Discrete(7)
Traceback (most recent call last):
File "/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/tests.py", line 22, in <module>
agent.policy.load_state_dict(th.load("/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/ride_1_1_1507328.pth",
File "/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2103, in load_state_dict
raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
TypeError: Expected state_dict to be dict-like, got <class 'rllte.common.utils.ExportModel'>.
System Info
No response
Checklist
[X] I have checked that there is no similar issue in the repo
🐛 Bug
I'm trying to load a trained model to use it for testing, but I am facing with an error. Thank you.
To Reproduce
Relevant log output / Error message
System Info
No response
Checklist