Closed · oppure closed this 3 months ago
Thanks for raising this. You can solve this by telling load which device to put the agent on:

import torch
from agilerl.algorithms.dqn_rainbow import RainbowDQN

# Load the checkpoint onto the GPU if available, otherwise the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
rainbow_dqn = RainbowDQN.load("RainbowDQN_0_200.pt", device=device)
That said, I have also updated the framework so that loading agents onto devices is more robust going forward; you can get these updates by upgrading to the latest version of AgileRL.
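To pick up those changes, a standard upgrade (assuming a pip-based install) is:

pip install --upgrade agilerl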
PS: I noticed in your evaluation/rendering script that when calling agent.getAction you need to pass the flag training=False so the agent exploits the learned policy rather than exploring:
# Greedy action selection: disable exploration during evaluation
action, *_ = rainbow_dqn.getAction(state, training=False)
I have updated the tutorial to reflect this too.
Thanks for using AgileRL!
What version of AgileRL are you using? 0.1.27
What operating system and processor architecture are you using? Windows 10 x64
Steps to reproduce the behaviour:
1. Create a .py file and copy/paste the code from https://docs.agilerl.com/en/latest/tutorials/gymnasium/agilerl_rainbow_dqn_tutorial.html (paragraphs: Dependencies, Create the Environment, Instantiate an Agent, Experience Replay, and Training and Saving an Agent - Using AgileRL train_off_policy function).
2. Execute the code.
3. Create another .py file and copy/paste the code from the Load agent and Test loop for inference paragraphs, adding the required imports on top and the INIT_HP definition. This is the code:
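(A minimal sketch of that test script, reconstructed from the tutorial structure and the traceback further down; the environment id is inferred from the cartpole script path, and test_env, render_mode, and the episode count are assumptions. The getAction call is kept exactly as reported, i.e. without training=False or a device argument.)

import gymnasium as gym
import numpy as np
from agilerl.algorithms.dqn_rainbow import RainbowDQN

# Load the trained agent; as reported, no device argument is passed,
# so the checkpoint's CUDA tensors stay on cuda:0
rainbow_dqn = RainbowDQN.load("RainbowDQN_0_200.pt")

test_env = gym.make("CartPole-v1", render_mode="rgb_array")
frames = []

for episode in range(5):  # number of test episodes is an assumption
    state, _ = test_env.reset()
    score = 0
    finished = False
    while not finished:
        if len(np.shape(state)) > 1:  # channels-last image observation (assumption)
            state = np.moveaxis(state, [-1], [-3])

        # Get next action from agent (line 66 of the traceback below)
        action, *_ = rainbow_dqn.getAction(state)

        # Save the frame for this step and append to frames list
        frame = test_env.render()
        frames.append(frame)

        state, reward, terminated, truncated, _ = test_env.step(action)
        score += reward
        finished = terminated or truncated

    print(f"--------------- Episode: {episode} ---------------")
    print(f"Episodic Reward: {score}")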
What did you expect to see?

--------------- Episode: 0 ---------------
Episodic Reward: 103.0
--------------- Episode: 1 ---------------
Episodic Reward: 107.0
...
What did you see instead? Describe the bug.
Traceback (most recent call last):
  File "C:\...\miniconda3\envs\agilerl\Lib\site-packages\spyder_kernels\py3compat.py", line 356, in compat_exec
    exec(code, globals, locals)
  File "d:\...\python\agile_rl\cartpole\testrender.py", line 66, in <module>
    action, *_ = rainbow_dqn.getAction(state)
  File "C:\...\miniconda3\envs\agilerl\Lib\site-packages\agilerl\algorithms\dqn_rainbow.py", line 287, in getAction
    action_values = self.actor(state)
  File "C:\...\miniconda3\envs\agilerl\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\...\miniconda3\envs\agilerl\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\...\miniconda3\envs\agilerl\Lib\site-packages\agilerl\networks\evolvable_mlp.py", line 316, in forward
    x = torch.sum(x * self.support, dim=2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
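This is PyTorch's generic complaint when a single operation mixes CUDA and CPU tensors; here one operand of x * self.support lives on cuda:0 and the other on the CPU. A minimal sketch of the same failure mode (the tensor names and the 51-atom support are illustrative; requires a CUDA-capable machine):

import torch

dist = torch.ones(1, 2, 51, device="cuda")  # e.g. network output on the GPU
support = torch.linspace(-10.0, 10.0, 51)   # e.g. a support buffer left on the CPU
x = torch.sum(dist * support, dim=2)        # RuntimeError: tensors on cuda:0 and cpu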
Additional context: miniconda, Python 3.11.9, torch 2.3.0+cu118, CUDA 11.8.
Adding:
after
rainbow_dqn = RainbowDQN.load("RainbowDQN_0_200.pt")
fixes the problem.
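The added snippet itself is omitted above. A typical workaround of this shape (an illustration only, not necessarily the reporter's exact code) moves the loaded network onto one device by hand:

# Hypothetical manual fix: put the loaded actor on the active device
device = "cuda" if torch.cuda.is_available() else "cpu"
rainbow_dqn.actor = rainbow_dqn.actor.to(device)

Passing device= to RainbowDQN.load, as shown at the top of the thread, is the cleaner route.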