Closed: laz8 closed this issue 4 years ago
```python
def compute_td_loss(batch_size):
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state      = Variable(torch.FloatTensor(np.float32(state)))
    next_state = Variable(torch.FloatTensor(np.float32(next_state)))
    action     = Variable(torch.LongTensor(action))
    reward     = Variable(torch.FloatTensor(reward))
    done       = Variable(torch.FloatTensor(done))

    q_values      = current_model(state)
    next_q_values = target_model(next_state)

    q_value          = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value     = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss
```
no forward call?
Edit: sorry, I found my issue; it was caused by `Variable` (deprecated since PyTorch 0.4), not a "missing forward". It works the same without calling `forward()` explicitly, so this can be closed.
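For reference, calling a module directly, as in `current_model(state)`, goes through `nn.Module.__call__`, which dispatches to `forward()` after running any registered hooks. A minimal sketch (the `TinyQNet` network below is a hypothetical stand-in, not the model from this repo):

```python
import torch
import torch.nn as nn

class TinyQNet(nn.Module):
    """Hypothetical stand-in for current_model."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

torch.manual_seed(0)
model = TinyQNet()
x = torch.randn(1, 4)

# model(x) invokes forward() via nn.Module.__call__, so the outputs match.
# Prefer model(x) over model.forward(x) so forward/backward hooks still run.
assert torch.equal(model(x), model.forward(x))
```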