dxyang / DQN_pytorch

Vanilla DQN, Double DQN, and Dueling DQN implemented in PyTorch
428 stars, 94 forks

Fixed unsqueeze error #3

Open bgreenawald opened 6 years ago

bgreenawald commented 6 years ago

A fix for issue https://github.com/dxyang/DQN_pytorch/issues/2. With newer versions of PyTorch, training fails with the following error:

RuntimeError: invalid argument 3: Index tensor must have same dimensions as input tensor at /pytorch/torch/lib/THC/generic/THCTensorScatterGather.cu:199

This change checks which version of PyTorch is installed. On '0.2.0', nothing changes. On any newer version, line 229 of "learn.py" is changed from q_s_a.backward(clipped_error.data.unsqueeze(1)) to q_s_a.backward(clipped_error.data), which resolves the error.
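A minimal sketch of the version-gated call, assuming q_s_a holds the Q(s, a) values for the taken actions squeezed to shape (batch,), and clipped_error is the per-sample gradient with that same shape. The helper name backprop_clipped_error, the LEGACY_TORCH flag, and the startswith("0.2.0") check are hypothetical illustrations, not the PR's actual code. One plausible reading of the error: newer PyTorch requires the gradient passed to Tensor.backward to match the tensor's shape exactly, so a (batch, 1) gradient produced by unsqueeze(1) no longer fits a (batch,) output.

```python
import torch
import torch.nn as nn

# Hypothetical version gate: only PyTorch 0.2.0 keeps the old unsqueeze(1) call.
# The exact comparison used in the PR may differ.
LEGACY_TORCH = torch.__version__.startswith("0.2.0")


def backprop_clipped_error(q_s_a: torch.Tensor, clipped_error: torch.Tensor) -> None:
    """Backpropagate a precomputed gradient (the clipped Bellman error) into Q(s, a)."""
    grad = clipped_error.detach()  # gradient values only, no autograd history
    if LEGACY_TORCH:
        # Old learn.py behaviour: reshape the gradient to (batch, 1)
        grad = grad.unsqueeze(1)
    # On newer PyTorch, grad must have exactly the same shape as q_s_a
    q_s_a.backward(grad)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, obs_dim, n_actions = 4, 5, 3
    q_net = nn.Linear(obs_dim, n_actions)  # stand-in for the Q-network (assumption)
    obs = torch.randn(batch, obs_dim)
    act = torch.tensor([0, 2, 1, 0])
    target = torch.randn(batch)  # stand-in for r + gamma * max_a' Q_target(s', a')

    # gather -> (batch, 1), then squeeze(1) -> (batch,)
    q_s_a = q_net(obs).gather(1, act.unsqueeze(1)).squeeze(1)
    # Gradient of 0.5 * (target - Q)^2 w.r.t. Q is -(target - Q); clip it to [-1, 1]
    clipped_error = -(target - q_s_a).clamp(-1, 1)

    backprop_clipped_error(q_s_a, clipped_error)
    print(q_s_a.shape, q_net.weight.grad.shape)
```

Passing the gradient unchanged keeps its shape aligned with q_s_a, so no reshaping is needed on newer versions; the unsqueeze(1) branch survives only for the legacy release.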