PacktPublishing / Deep-Reinforcement-Learning-Hands-On

Hands-on Deep Reinforcement Learning, published by Packt
MIT License

Chapter 6 02_dqn_pong.py RuntimeError #7

Closed: knicholes closed this issue 5 years ago.

knicholes commented 5 years ago

I'm running this on Windows 10 (I installed the Atari library with pip install -U git+https://github.com/Kojoley/atari-py.git), and training crashes with the following output:

    9765: done 10 games, mean reward -20.100, eps 0.90, speed 1036.60 f/s
    Traceback (most recent call last):
      File "02_dqn_pong.py", line 170, in <module>
        loss_t = calc_loss(batch, net, tgt_net, device=device)
      File "02_dqn_pong.py", line 97, in calc_loss
        state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    RuntimeError: Expected object of type torch.cuda.LongTensor but found type torch.cuda.IntTensor for argument #3 'index'

stevenapsel commented 5 years ago

I had the same issue. I fixed it on line 93 of "02_dqn_pong.py", in calc_loss, by converting the actions tensor to long (int64):

    actions_v = torch.tensor(actions, dtype=torch.int64).to(device)
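The fix can be sketched outside the script. This is a minimal, hypothetical stand-in for the relevant lines of calc_loss, not the book's full function: a hand-written Q-value tensor replaces net(states_v), the helper name select_action_values is made up for illustration, and actions is an int32 NumPy array of the kind that produced the IntTensor in the traceback. Passing dtype=torch.int64 to torch.tensor() makes the index a LongTensor regardless of the incoming array's dtype, which is what gather() asks for in the error message.

```python
import numpy as np
import torch


def select_action_values(q_values: torch.Tensor, actions: np.ndarray,
                         device: str = "cpu") -> torch.Tensor:
    """Pick Q(s, a) for the action taken in each row of the batch.

    Hypothetical helper mirroring the gather() call in calc_loss.
    """
    # Force an int64 (LongTensor) index instead of inheriting the
    # NumPy dtype, which may be int32 depending on the platform.
    actions_v = torch.tensor(actions, dtype=torch.int64).to(device)
    q_values = q_values.to(device)
    # unsqueeze(-1) turns shape (batch,) into (batch, 1) so the index has
    # the same number of dimensions as q_values; squeeze(-1) drops the
    # extra dimension from the result.
    return q_values.gather(1, actions_v.unsqueeze(-1)).squeeze(-1)


if __name__ == "__main__":
    q = torch.tensor([[0.1, 0.5, 0.2],
                      [0.3, 0.0, 0.9]])
    acts = np.array([1, 2], dtype=np.int32)  # int32 actions, as in the report
    print(select_action_values(q, acts))  # tensor([0.5000, 0.9000])
```

Casting at tensor-creation time keeps the change local to calc_loss; the alternative is to cast where the batch is built, so every consumer of the actions array sees int64.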

knicholes commented 5 years ago

@stevenapsel Ah, you rock! That totally worked. Thank you.

Shmuma commented 5 years ago

I tried to reproduce this, but on a Linux machine it works fine with both torch 0.4.0 and 0.4.1.
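One plausible explanation for the Windows/Linux difference, offered as an assumption rather than something confirmed in this thread: before NumPy 2.0, NumPy's default integer type on Windows was 32-bit (it followed the C long), while on 64-bit Linux it is 64-bit. torch.tensor() keeps the dtype of a NumPy array it is given, so if the sampled actions are collected with a plain np.array(actions), they would become an IntTensor on Windows and a LongTensor on Linux. The sketch below shows an explicit cast that removes the dependence on the platform default; as_index_array is a hypothetical helper name.

```python
import numpy as np


def as_index_array(actions) -> np.ndarray:
    """Return the actions as an int64 array, whatever the platform default.

    Hypothetical helper: with this cast, a later torch.tensor(...) call
    produces a LongTensor on Windows and Linux alike.
    """
    return np.asarray(actions, dtype=np.int64)


if __name__ == "__main__":
    actions = [0, 3, 1, 2]
    default = np.array(actions)
    # int32 on Windows with NumPy < 2.0, int64 on 64-bit Linux.
    print("platform default:", default.dtype)
    fixed = as_index_array(actions)
    print("after cast:", fixed.dtype)  # int64 on every platform
```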