Closed xcodeburpx closed 4 years ago
Hello.
I had some problems regarding the DeepQNetwork implementation using PyTorch.
I ran the code shown in your YouTube video and got this error:
```
~/PROJECTS/PYTORCH_TUTORIAL/main_DQN_file.py in <module>
     35     brain.store_transition(observation, action, reward, observation_, done)
     36
---> 37     brain.learn()
     38     observation = observation_
     39     scores.append(score)

~/PROJECTS/PYTORCH_TUTORIAL/simple_DQN.py in learn(self)
    123         print("Q_Target slice: ", q_target[batch_index, actions_random])
    124         q_target[batch_index, action_indices] = reward_batch + \
--> 125             self.gamma*T.max(q_next, dim=1)[0]*terminal_batch
    126
    127         self.epsilon = self.epsilon*self.eps_dec if self.epsilon > \

IndexError: The shape of the mask [64] at index 0 does not match the shape of the indexed tensor [64, 4] at index 1
```
This error appears when the action indices are calculated with the dot operator. When I use the np.argmax function instead, the whole network works properly.
Have you encountered this type of problem?
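For context, here is a sketch of the two decoding approaches being compared (the exact repo code may differ; `action_batch` and `action_values` are assumed names for the one-hot action memory and the index vector). It shows that both produce the same values, but with different dtypes, which turns out to matter when they are later used as PyTorch indices:

```python
import numpy as np

n_actions, batch_size = 4, 64

# one-hot action memory stored as uint8 (assumed layout; this is what makes the dtype sneaky)
action_batch = np.eye(n_actions, dtype=np.uint8)[np.random.randint(0, n_actions, batch_size)]
action_values = np.arange(n_actions, dtype=np.uint8)

# "dot operator" decoding: uint8 . uint8 stays uint8
action_indices_dot = np.dot(action_batch, action_values)

# np.argmax decoding: returns a platform integer dtype, not uint8
action_indices_argmax = np.argmax(action_batch, axis=1)

# same values, different dtypes
assert (action_indices_dot == action_indices_argmax).all()
```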
I have encountered the same error, and I don't quite understand this sentence:

> When I use np.argmax function whole network works properly.

Can you tell me how to modify the code so it runs?
Thank you!
Hello, guys! I found a solution to this problem. You need to change line 124 from

```python
q_target[batch_index, action_indices] = reward_batch + self.gamma*T.max(q_next, dim=1)[0]*terminal_batch
```

to:

```python
q_target[action_batch] = reward_batch + \
    self.GAMMA*T.max(q_next, dim=1)[0]*terminal_batch
```

I hope this could help.
Hello! I got the same index error as you guys. I fixed it by assigning each element one by one, as follows:

```python
target_update = reward_batch + \
    self.gamma*T.max(q_next, dim=1)[0]*terminal_batch
for i in range(len(batch_index)):
    q_target[batch_index[i], action_indices[i]] = target_update[i]
```

This should fix the error for other people too, I hope. It worked for me.
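The loop above can also be collapsed into a single vectorized assignment once both index arrays have an integer dtype. A sketch with dummy tensors (names follow the snippet above):

```python
import numpy as np
import torch as T

batch_size, n_actions = 64, 4
q_target = T.zeros(batch_size, n_actions)
target_update = T.rand(batch_size)
batch_index = np.arange(batch_size, dtype=np.int64)
action_indices = np.random.randint(0, n_actions, size=batch_size)  # int64 by default

# one fancy-indexing assignment instead of a Python loop over the batch
q_target[batch_index, action_indices] = target_update
```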
The code still runs on my local machine, and I'm scratching my head trying to figure out where the issues are cropping up. I have no doubt you guys are having problems, but just posting a snippet isn't super helpful.
Can you post your version in a git repo and link it so I can view the code? There could be something subtle elsewhere that leads to the issue.
Using just the action batch will not work, as you end up with the wrong dimensions (batch_size x batch_size, I believe).
In hindsight, there is no reason to go to a one-hot encoding and then back. It's needlessly complex and just introduces the potential for bugs. It's been so long since I made the video that I can't remember my thought process behind it.
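A sketch of the simpler scheme (assumed names throughout): store the integer action directly in the replay memory and index with it, skipping the one-hot round trip entirely.

```python
import numpy as np

mem_size = 1000
# store the raw integer action instead of a one-hot row
action_memory = np.zeros(mem_size, dtype=np.int64)

def store_transition(mem_cntr, action):
    action_memory[mem_cntr % mem_size] = action

# in learn(), the sampled actions are already valid column indices:
store_transition(0, 2)
batch = np.array([0])
action_batch = action_memory[batch]  # no decoding step, and the dtype is already int64
```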
> Hello, guys! I found a solution to this problem. You need to change line 124 from
>
> ```python
> q_target[batch_index, action_indices] = reward_batch + self.gamma*T.max(q_next, dim=1)[0]*terminal_batch
> ```
>
> to:
>
> ```python
> q_target[action_batch] = reward_batch + \
>     self.GAMMA*T.max(q_next, dim=1)[0]*terminal_batch
> ```
>
> I hope this could help.
This will fix the dimensional mismatch, but gives the incorrect values for q_target. You can verify this by setting the batch size to something small (say 8), printing q_target before you index it with action_indices, and then setting q_target[action_indices] = dummy_value and printing q_target again. You will see that you don't get what you expect (or want).
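A minimal version of that test with dummy values. It shows that a single index array indexes dimension 0, so whole rows get overwritten rather than individual (row, action) entries:

```python
import torch as T

batch_size, n_actions = 8, 4
q_target = T.zeros(batch_size, n_actions)
action_indices = T.tensor([1, 3, 0, 2, 1, 0, 3, 2])

# indexes dim 0: rows 0..3 are each overwritten entirely, rows 4..7 are untouched
q_target[action_indices] = -1.0

print(q_target)  # not the per-action update we want
```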
In a strange turn of events, I hosed my Anaconda install while trying to install manimlib. I had to reinstall Anaconda, and after doing so I get the same error.
The error comes in because the data type of action_indices is np.uint8. PyTorch interprets a uint8 index tensor as a boolean mask rather than as integer indices, which is why the message complains about "the shape of the mask". Combining the uint8 action_indices with the int32 batch_index therefore causes the error. Switching the datatype of action_indices to np.int32 fixes the dimensional mismatch and yields the expected results under the test I proposed in the comment above.
Fixed code is up on the repo.
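A condensed sketch of the fix with dummy tensors (`action_indices_u8` stands in for the uint8 result of the dot-product decoding):

```python
import numpy as np
import torch as T

batch_size, n_actions = 64, 4
q_target = T.zeros(batch_size, n_actions)
batch_index = np.arange(batch_size, dtype=np.int32)

action_indices_u8 = np.random.randint(0, n_actions, batch_size).astype(np.uint8)
# q_target[batch_index, action_indices_u8] = 1.0  # uint8 index -> treated as a mask -> IndexError

action_indices = action_indices_u8.astype(np.int32)  # the fix: a proper integer dtype
q_target[batch_index, action_indices] = 1.0          # now a normal (row, column) assignment
```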