PacktPublishing / Deep-Reinforcement-Learning-Hands-On-Second-Edition

Deep-Reinforcement-Learning-Hands-On-Second-Edition, published by Packt

Chapter06, RuntimeError: gather_out_cpu(): Expected dtype int64 for index #45

Closed · yongduek closed this issue 3 years ago

yongduek commented 3 years ago

The following error occurred when running $ python 02_dqn_pong.py:

...
9290: done 10 games, reward -20.200, eps 0.94, speed 651.18 f/s
Traceback (most recent call last):
  File ".\02_dqn_pong.py", line 189, in <module>
    loss_t = calc_loss(batch, net, tgt_net, device=device)
  File ".\02_dqn_pong.py", line 106, in calc_loss
    state_action_values = net(states_v).gather(
RuntimeError: gather_out_cpu(): Expected dtype int64 for index

So I changed the code as follows (see the line for actions_v), and the error is gone. The actions array apparently comes back as int32 (likely NumPy's default integer type on Windows), while gather() requires int64 indices.

import numpy as np
import torch
import torch.nn as nn


def calc_loss(batch, net, tgt_net, device="cpu"):
    # GAMMA is a module-level constant in 02_dqn_pong.py
    states, actions, rewards, dones, next_states = batch

    states_v = torch.tensor(np.array(states, copy=False)).to(device)
    next_states_v = torch.tensor(np.array(next_states, copy=False)).to(device)
    # action indices must be int64: torch.gather() rejects other dtypes
    actions_v = torch.tensor(actions, dtype=torch.int64).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    done_mask = torch.BoolTensor(dones).to(device)

    # Q(s, a) for the actions actually taken in the batch
    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        # max_a' Q_target(s', a'), zeroed out for terminal states
        next_state_values = tgt_net(next_states_v).max(1)[0]
        next_state_values[done_mask] = 0.0
        next_state_values = next_state_values.detach()

    # Bellman target: r + GAMMA * max_a' Q_target(s', a')
    expected_state_action_values = next_state_values * GAMMA + \
                                   rewards_v
    return nn.MSELoss()(state_action_values,
                        expected_state_action_values)

Just in case you run into the same error.
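For reference, torch.gather() only accepts int64 (LongTensor) indices regardless of the values they hold. Below is a minimal sketch of the failure and the fix, standalone and independent of the book code (the tensor names are just illustrative):

import torch

q_values = torch.randn(4, 6)                       # fake Q-values: batch of 4 states, 6 actions
actions = torch.tensor([0, 3, 5, 1], dtype=torch.int32)

# q_values.gather(1, actions.unsqueeze(-1))        # RuntimeError: Expected dtype int64 for index
picked = q_values.gather(1, actions.to(torch.int64).unsqueeze(-1)).squeeze(-1)
print(picked.shape)                                # torch.Size([4])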

sarvesh-packt commented 3 years ago

Hi @yongduek

Thanks for providing the solution.

Regards,