nikhilbarhate99 / PPO-PyTorch

Minimal implementation of clipped objective Proximal Policy Optimization (PPO) in PyTorch
MIT License

how did you figure out continuous? #3

Closed nyck33 closed 5 years ago

nyck33 commented 5 years ago

I can see that your Actor network has a tanh activation on the output layer, but then I am totally lost as to what you do here:

def act(self, state, memory):
        action_mean = self.actor(state)
        dist = MultivariateNormal(action_mean, torch.diag(self.action_var).to(device))
        action = dist.sample()
        action_logprob = dist.log_prob(action)

Especially action_mean = self.actor(state). Does this mean you have one output node and assume that the output is the mean of a Gaussian distribution over the action space?

Then similar code appears here:

def evaluate(self, state, action):
        action_mean = self.actor(state)
        dist = MultivariateNormal(torch.squeeze(action_mean), torch.diag(self.action_var))

        action_logprobs = dist.log_prob(torch.squeeze(action))
        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy

Also, is self.policy like a dummy Actor-Critic network that you use just to get updated parameters to load into self.policy_old? I know this isn't Stack Overflow, but if you could look at my implementation and let me know how I can adapt it for a continuous action space, that'd be great.

my PPO discrete action space

nikhilbarhate99 commented 5 years ago

Hey, the action distribution is assumed to be a multivariate normal (Gaussian) distribution with a diagonal covariance matrix. So the last layer outputs the means of all variables in the action space (i.e. a mean vector), and the covariance matrix is just the diagonal matrix of the square of a fixed standard deviation (hyperparameter: action_std). From the mean vector and the covariance matrix we can construct a multivariate normal distribution using the standard PyTorch class.

Regarding self.policy: since we update self.policy for k_epochs (i.e. k times) in one PPO update, we keep self.policy_old as a copy of the old network weights to compute the ratios. I think you should refer to the original PPO paper for more detail.
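
To illustrate the idea, here is a minimal sketch (not the repo's exact code) of building the action distribution described above: the actor outputs a mean vector with one entry per action dimension, and the covariance is a fixed diagonal matrix derived from action_std. The network shape, state_dim, action_dim, and action_std values below are assumptions for the example.

    import torch
    import torch.nn as nn
    from torch.distributions import MultivariateNormal

    state_dim = 8           # assumed observation dimensionality
    action_dim = 4          # assumed action-space dimensionality
    action_std = 0.5        # hyperparameter: fixed std for every action dimension

    # Toy actor MLP ending in Tanh, so the predicted means lie in [-1, 1]
    actor = nn.Sequential(
        nn.Linear(state_dim, 64), nn.Tanh(),
        nn.Linear(64, action_dim), nn.Tanh(),
    )

    state = torch.randn(state_dim)                  # dummy state
    action_mean = actor(state)                      # mean vector, shape (action_dim,)
    action_var = torch.full((action_dim,), action_std ** 2)
    cov_mat = torch.diag(action_var)                # diagonal covariance: action_std^2 * I

    dist = MultivariateNormal(action_mean, cov_mat)
    action = dist.sample()                          # continuous action vector
    logprob = dist.log_prob(action)                 # log-probability of that action

    # During the PPO update, the clipped-objective ratio is computed as
    # exp(new_logprob - old_logprob), where old_logprob was stored by
    # self.policy_old (the frozen copy) at the time the action was taken.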