Closed yongduek closed 3 years ago
The following error occurred from running `02_dqn_pong.py`:

```
$ python 02_dqn_pong.py
...
9290: done 10 games, reward -20.200, eps 0.94, speed 651.18 f/s
Traceback (most recent call last):
  File ".\02_dqn_pong.py", line 189, in <module>
    loss_t = calc_loss(batch, net, tgt_net, device=device)
  File ".\02_dqn_pong.py", line 106, in calc_loss
    state_action_values = net(states_v).gather(
RuntimeError: gather_out_cpu(): Expected dtype int64 for index
```
So I changed the code as follows (see the line for `actions_v`), and the error is gone:
```python
def calc_loss(batch, net, tgt_net, device="cpu"):
    states, actions, rewards, dones, next_states = batch

    states_v = torch.tensor(np.array(states, copy=False)).to(device)
    next_states_v = torch.tensor(np.array(next_states, copy=False)).to(device)
    actions_v = torch.tensor(actions, dtype=torch.int64).to(device)  # action index must be int64
    rewards_v = torch.tensor(rewards).to(device)
    done_mask = torch.BoolTensor(dones).to(device)

    state_action_values = net(states_v).gather(
        1, actions_v.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        next_state_values = tgt_net(next_states_v).max(1)[0]
        next_state_values[done_mask] = 0.0
        next_state_values = next_state_values.detach()

    expected_state_action_values = next_state_values * GAMMA + rewards_v
    return nn.MSELoss()(state_action_values, expected_state_action_values)
```
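For anyone curious why the cast is needed: `torch.gather` accepts only an int64 (`LongTensor`) index, regardless of the values it holds. Here is a minimal standalone sketch (not from the book's repo, tensor values are made up) reproducing the error and the fix:

```python
import torch

# A toy batch of Q-values: 2 states x 2 actions
q_values = torch.tensor([[0.1, 0.9],
                         [0.8, 0.2]])

# An int32 index triggers the same RuntimeError as in the issue:
# "gather(): Expected dtype int64 for index"
bad_idx = torch.tensor([1, 0], dtype=torch.int32)
try:
    q_values.gather(1, bad_idx.unsqueeze(-1))
except RuntimeError as e:
    print("int32 index fails:", e)

# Casting the index to int64 (as in the fix above) works:
good_idx = bad_idx.to(torch.int64)
selected = q_values.gather(1, good_idx.unsqueeze(-1)).squeeze(-1)
print(selected)  # selects Q[0, 1] = 0.9 and Q[1, 0] = 0.8
```

The same applies whether the index comes from `torch.tensor(actions)` or elsewhere: if the source `actions` array is int32 (e.g. a NumPy array with `dtype=np.int32`), the resulting tensor inherits that dtype, hence the explicit `dtype=torch.int64` in the fix.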
Posting this just in case you run into the same error.
Hi @yongduek
Thanks for providing the solution.
Regards,