nikhilbarhate99 / PPO-PyTorch

Minimal implementation of clipped objective Proximal Policy Optimization (PPO) in PyTorch

Question regarding state_values.detach() #29

Closed junkwhinger closed 4 years ago

junkwhinger commented 4 years ago

Hi, thanks for the great implementation. I learned a lot about PPO by reading your code.

I have one question regarding the state_values.detach() when updating PPO.

# Finding Surrogate Loss:
advantages = rewards - state_values.detach()

When you detach a tensor, it loses the computation record used in backpropagation. So I checked whether the weights of the policy's value layer get updated, and they did not. Surprisingly, in my own experiment, training performance was better with .detach() than without it. But I still find it difficult to understand the use of detach() theoretically.
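
A minimal, self-contained sketch of the kind of check I mean (simplified, not the actual repo code): build a tiny value head and policy head, do one update that uses the detached advantage plus an MSE value loss, and see whether the value head's weights move.

    # Simplified stand-in for the repo's actor-critic; names here are illustrative only.
    import copy
    import torch
    import torch.nn as nn

    value_head = nn.Linear(4, 1)
    policy_head = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(
        list(value_head.parameters()) + list(policy_head.parameters()), lr=1e-2)

    states = torch.randn(16, 4)
    rewards = torch.randn(16)
    logprobs = torch.log_softmax(policy_head(states), dim=-1)[:, 0]

    state_values = value_head(states).squeeze(-1)
    advantages = rewards - state_values.detach()          # detached, as in the repo
    loss = -(logprobs * advantages).mean() + nn.functional.mse_loss(state_values, rewards)

    before = copy.deepcopy(value_head.state_dict())       # snapshot value weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # whether the value head's weights changed after one update
    print(any(not torch.equal(before[k], value_head.state_dict()[k]) for k in before))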

Thank you.

nikhilbarhate99 commented 4 years ago

state_values.detach() does not detach the tensor state_values from the computational graph; rather, it returns a new detached tensor. For calculating the advantages we do not need to backpropagate through the value network, since the advantages are only used to update the policy network; hence we use detach().

For updating the value network, we later calculate the MSE loss without detach(). So theoretically, it makes sense.
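
A small illustrative sketch (not code from this repo): the detached copy carries no graph, but the original state_values stays connected, so the MSE term still sends gradients to the value network.

    import torch
    import torch.nn as nn

    value_net = nn.Linear(4, 1)        # stand-in for the value network
    states = torch.randn(8, 4)
    rewards = torch.randn(8)

    state_values = value_net(states).squeeze(-1)
    advantages = rewards - state_values.detach()
    print(advantages.requires_grad)            # False: no gradient reaches value_net from the advantages

    value_loss = nn.functional.mse_loss(state_values, rewards)
    value_loss.backward()
    print(value_net.weight.grad is not None)   # True: value_net is still trained through the MSE loss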

As to why it performs better, I do not know.

junkwhinger commented 4 years ago

Hi, thank you for the reply. I must have made some mistake in my experiment. As you said, .detach() returns a new tensor without the computation graph, so it wouldn't stop the value layer from being updated by the undetached state values. My mistake!

fatalfeel commented 4 years ago

Thank you for your example. I modified it like this:

    critic_vpi = self.policy_next.network_critic(curr_states)
    critic_vpi = torch.squeeze(critic_vpi)
    qsa_sub_vs = rewards - critic_vpi.detach()   # A(s,a) => Q(s,a) - V(s), V(s) is critic
    advantages = (qsa_sub_vs - qsa_sub_vs.mean()) / (qsa_sub_vs.std() + 1e-5)

    # Optimize policy for K epochs:
    for _ in range(self.train_epochs):
        # next_critic_values is V(s) as in actor-critic (A3C) theory: the critic network takes the state and outputs a value estimate
        critic_actlogprobs, next_critic_values, entropy = self.policy_next.calculation(curr_states, curr_actions)
        #critic_actlogprobs, entropy = self.policy_next.calculation(curr_states, curr_actions)

        # Finding the ratio (pi_theta / pi_theta__old):
        # log(critic) - log(curraccu) = log(critic/curraccu)
        # ratios = e^log(critic/curraccu)
        ratios  = torch.exp(critic_actlogprobs - curr_logprobs.detach())

        #advantages = curr_stdscore - cstate_value.detach()
        surr1   = ratios * advantages
        surr2   = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages

        # mseLoss is Mean Square Error = (target - output)^2
        loss    = -torch.min(surr1, surr2) + 0.5*self.mseLoss(rewards, next_critic_values) - 0.01*entropy

        # take gradient step
        self.optimizer.zero_grad()
        loss.mean().backward()  # compute gradients
        self.optimizer.step()   # update the parameters with Adam

Refer to: https://github.com/ASzot/ppo-pytorch/blob/master/train.py