pranz24 / pytorch-soft-actor-critic

PyTorch implementation of soft actor critic
MIT License

Derivative in reparametrization trick? #11

Closed ZeratuuLL closed 5 years ago

ZeratuuLL commented 5 years ago

Hi, I ran into a problem understanding the log-likelihood again, and hopefully you can help me!

In the .sample() method of the GaussianPolicy class, https://github.com/pranz24/pytorch-soft-actor-critic/blob/86412e14e1a94243b6ab4b7a96b092373f2bc1bc/model.py#L88 generates a new x_t with gradients flowing back to the network, and https://github.com/pranz24/pytorch-soft-actor-critic/blob/86412e14e1a94243b6ab4b7a96b092373f2bc1bc/model.py#L90 calculates the log-likelihood of x_t. However, I think this log-likelihood has a gradient only with respect to the std part of the network, not the mean part. Am I correct?

Here is what I think: x_t = mean + noise * std, where mean and std are outputs of .forward() and noise is a freshly sampled N(0, 1) value with no gradient. Then when you calculate log_prob = normal.log_prob(x_t), the function returns something like

-((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

which is -0.5 - log_scale - math.log(math.sqrt(2 * math.pi)), so it only has a gradient with respect to the std part, not the mean part. Do you agree with me?

I am also trying to implement this algorithm, but it seems that when I include the entropy term in the policy loss everything goes wrong, while without it everything is fine. I am looking into the problem, so I am comparing all the details between my implementation and yours. Please help me if you can. Thank you very much!
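For context, a condensed, runnable sketch of the sampling path being discussed (the linear layers and the state tensor below are made-up stand-ins for the policy network, not the repo's actual code):

```python
import torch
from torch.distributions import Normal

# Made-up stand-ins for the policy network's mean and log-std heads.
state = torch.randn(1, 4)
mean_head = torch.nn.Linear(4, 1)
log_std_head = torch.nn.Linear(4, 1)

mean = mean_head(state)
log_std = log_std_head(state)
std = log_std.exp()

normal = Normal(mean, std)
x_t = normal.rsample()           # reparametrization: x_t = mean + std * eps, eps ~ N(0, 1)
action = torch.tanh(x_t)         # squashed action
log_prob = normal.log_prob(x_t)  # the term whose gradient w.r.t. the mean is in question
```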

pranz24 commented 5 years ago

I don't understand how -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) = -0.5 - log_scale - math.log(math.sqrt(2 * math.pi))

To relate it more to your notation:
-((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) -> the log of the normal density is basically -((x_t - mean) ** 2) / (2 * (std ** 2)) - log_std - math.log(math.sqrt(2 * math.pi))

Why is ((x_t - mean) ** 2) / (2 * (std ** 2)) = 0.5?

Removing the entropy term makes SAC very similar to DDPG with a stochastic actor and without a target actor. It is understandable that it would work when you remove the entropy term from both the critic and the policy loss. It is weird that it works when you remove the entropy term only from the policy loss.

ZeratuuLL commented 5 years ago

Hi, thank you for your reply!

Yeah, I made a mistake. ((x_t - mean) ** 2) / (2 * (std ** 2)) is not 0.5, but it is some value that does not have anything to do with mean, right? So this term still has no gradient with respect to the mean part, only the std part. Am I correct about this?
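Written out (in my notation: μ for mean, σ for std, ε for the fixed N(0, 1) noise), the substitution being described looks like this:

```latex
\[
% With x_t = \mu + \sigma\varepsilon and \varepsilon treated as a constant:
\frac{(x_t - \mu)^2}{2\sigma^2}
  = \frac{(\sigma\varepsilon)^2}{2\sigma^2}
  = \frac{\varepsilon^2}{2},
\qquad
\frac{\partial}{\partial\mu}\,\frac{\varepsilon^2}{2} = 0 .
\]
```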

And I think removing the entropy term from only the policy loss is similar to 'maximize expected rewards + entropy of the next step'. It does not explicitly force the policy to have larger entropy, but instead treats the entropy of the next step as part of the reward.

ZeratuuLL commented 5 years ago

Hi! Thank you for the reply. Actually the main point is this:

After x_t = normal.rsample() and log_prob = normal.log_prob(x_t), what is the value of d(log_prob)/d(mean)? Personally I think it's 0.

And if you use log_prob = normal.log_prob(x_t.detach()), things will be somewhat different.

I just can't figure out which one I should use....
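One way to check this question directly, with made-up scalar parameters standing in for the network's outputs (a minimal sketch, not the repo's code):

```python
import torch
from torch.distributions import Normal

# Made-up scalars standing in for the policy's mean and log-std outputs.
mean = torch.tensor([0.5], requires_grad=True)
log_std = torch.tensor([-0.3], requires_grad=True)
std = log_std.exp()

normal = Normal(mean, std)
x_t = normal.rsample()  # x_t = mean + std * eps, with eps carrying no gradient

# Case 1: log_prob of the reparametrized sample itself.
log_prob = normal.log_prob(x_t)
g_mean, g_std = torch.autograd.grad(log_prob.sum(), (mean, log_std), retain_graph=True)
print(g_mean, g_std)  # g_mean is 0: the (x_t - mean) term cancels; g_std is non-zero (from -log_scale)

# Case 2: log_prob of the detached sample, treated as a constant.
log_prob_detached = normal.log_prob(x_t.detach())
g_mean, g_std = torch.autograd.grad(log_prob_detached.sum(), (mean, log_std))
print(g_mean, g_std)  # g_mean is generally non-zero here: (x_t - mean) / std**2
```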

ZeratuuLL commented 5 years ago

Here is what I think.

We are updating the policy network, call it pi. There are two parts in your code: one part decides the mean and the other part decides the std; let's call them pi_mean and pi_std here.

So if a loss function L is to change pi_mean or pi_std, it must have a non-zero gradient through those networks, and this requires dL/d(mean) != 0 and dL/d(std) != 0 respectively, since d(mean)/d(pi_std) = 0 and d(std)/d(pi_mean) = 0.

Now we consider the policy loss, specifically the entropy part:
-log(1 - actions) contributes some gradient to both pi_mean and pi_std.
normal.log_prob(x_t) contributes only to pi_std: here, through x_t, we actually don't have access to mean.
normal.log_prob(x_t.detach()) contributes to both pi_mean and pi_std: this way we do have access to mean and std through x_t.
The thing to mention here is that normal.log_prob() really only cares about (x_t - mean) = std * N(0, 1), which is no longer related to mean.

So there is a difference. In practice normal.log_prob(x_t) works better, and I can understand why by thinking of its effect on the mean.
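As a side check on the first point above, the squashing-correction term does receive gradient through x_t for both heads (a minimal sketch with made-up scalar parameters; the exact form of the correction in the repo may differ):

```python
import torch
from torch.distributions import Normal

# Made-up scalars standing in for the policy's mean and log-std outputs.
mean = torch.tensor([0.2], requires_grad=True)
log_std = torch.tensor([-0.4], requires_grad=True)
std = log_std.exp()

x_t = Normal(mean, std).rsample()                   # reparametrized pre-squash sample
action = torch.tanh(x_t)                            # squashed action
correction = -torch.log(1 - action.pow(2) + 1e-6)   # tanh squashing correction (illustrative form)

g_mean, g_std = torch.autograd.grad(correction.sum(), (mean, log_std))
print(g_mean, g_std)  # both generally non-zero: the term depends on x_t, hence on both heads
```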

I guess somehow I don't really understand the reparametrization trick mathematically..... BTW would you mind letting me know how you typed $\pi$ and $\nabla$ etc? I think I just cannot....

pranz24 commented 5 years ago

Oh wow, yes! (x_t - mean) = std * N(0, 1), got it. OK, so you are right (and whatever I have written is absolute garbage :unamused:). At least now I understand the problem. :expressionless: Now I also think that you are right.

A long time ago I was using a scientific keyboard to write these expressions for a project, and now I just copy them from there. (Not a very efficient way of doing it. :sweat_smile:)

I'll get back to you when I have an answer, thanks. :sweat_smile:

ZeratuuLL commented 5 years ago

I think my way of expressing things needs a lot of improvement :(

Thank you very much! I didn't even know scientific keyboards existed.... Thank you for that as well!

pranz24 commented 5 years ago

The (negative) entropy of the normal distribution is log_pi = -0.5 * (log(2π * std^2) + 1) (see https://en.wikipedia.org/wiki/Differential_entropy and http://users.ics.aalto.fi/ahonkela/dippa/node94.html). So we don't need to compute the gradient with respect to the mean. (If that is the answer you are looking for.)
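A quick numerical check of that closed form against PyTorch (a minimal sketch; the 0.7 value for std is arbitrary) shows the entropy depends only on std, not on the mean:

```python
import math
import torch
from torch.distributions import Normal

std = torch.tensor([0.7])                                       # arbitrary illustrative value
closed_form = -0.5 * (torch.log(2 * math.pi * std ** 2) + 1)    # negative entropy, as quoted above
from_torch = -Normal(torch.zeros_like(std), std).entropy()      # PyTorch's entropy, negated; the mean does not appear
print(closed_form, from_torch)                                  # the two values should match
```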

pranz24 commented 5 years ago

Oh, damn!! I finally got time to go back over my notes and now I feel stupid. Super sorry I wasted a whole day on this issue. There was no problem with your question or your expression; it was just that I was in a completely different mindset when the question was asked. (And I don't know why I had an urge to answer :persevere:) I'll close the issue if I have your permission. (Also, I deleted my old comments because they were embarrassing [hope you don't mind :disappointed:].)

ZeratuuLL commented 5 years ago

Hi! I did not close this since I was still thinking about the reparametrization thing. Now perhaps I have an answer, which is that of the noise and the new action, only one of them can be independent of the parameters.... Somehow I guess that solves my problem.... Thank you for your help!