Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

[Question] What is the difference between old_distribution and distribution in train function of TRPO #240

Closed. 0Addicted0 closed this issue 2 months ago

0Addicted0 commented 2 months ago

❓ Question

🙏 Thanks for the highly extensible code in sb3-contrib.

I am referring to the MaskablePPO implementation in order to add action masking to TRPO, and in the train function of TRPO I found the following code:

    with th.no_grad():
        # Note: is copy enough, no need for deepcopy?
        # If using gSDE and deepcopy, we need to use `old_distribution.distribution`
        # directly to avoid PyTorch errors.
        old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations))

    distribution = self.policy.get_distribution(rollout_data.observations)
    log_prob = distribution.log_prob(actions)

    advantages = rollout_data.advantages
    if self.normalize_advantage:
        advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8)

    # ratio between old and new policy, should be one at the first iteration
    ratio = th.exp(log_prob - rollout_data.old_log_prob)

    # surrogate policy objective
    policy_objective = (advantages * ratio).mean()

    # KL divergence
    kl_div = kl_divergence(distribution, old_distribution).mean()

❓ Does it look like old_distribution and distribution are exactly the same here (so kl_div equals 0), or did I misread something?
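
For intuition, here is a small standalone sketch (an editorial illustration, not SB3 code) of why the copy.copy matters: SB3-style policies reuse a single distribution wrapper object and re-parameterise it in place on every forward pass, so without the shallow copy both names would alias the same object and the KL would stay 0 even after the policy parameters change. The ReusedDistribution class below is a hypothetical stand-in for that behaviour. At the first computation in the quoted snippet the two distributions are built from the same parameters, so kl_div is indeed expected to be close to 0; the copy only starts to matter once the parameters are updated later in train().

    import copy

    import torch as th
    from torch.distributions import Categorical
    from torch.distributions.kl import kl_divergence


    class ReusedDistribution:
        """Toy stand-in for a distribution wrapper that is re-parameterised in place."""

        def __init__(self):
            self.distribution = None

        def proba_distribution(self, logits):
            # Re-parameterise and return the *same* wrapper object on every call
            self.distribution = Categorical(logits=logits)
            return self


    wrapper = ReusedDistribution()

    # Snapshot of the "old" policy: the shallow copy keeps its own reference
    # to the old inner Categorical.
    old_dist = copy.copy(wrapper.proba_distribution(th.tensor([0.0, 0.0, 0.0])))

    # A later forward pass (e.g. after a line-search step changed the parameters)
    # re-parameterises the shared wrapper in place.
    new_dist = wrapper.proba_distribution(th.tensor([1.0, 0.0, -1.0]))

    # Non-zero, because old_dist still holds the old parameters; without copy.copy,
    # old_dist and new_dist would be the same object and this would always be 0.
    print(kl_divergence(new_dist.distribution, old_dist.distribution))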

By the way, may I also ask: does adding action_masks to TRPO require providing the corresponding masks before computing the distribution used for kl_div?

🙂Thanks a lot


araffin commented 2 months ago

https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/trpo/trpo.py#L282

Not sure about the second question, but probably yes.
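
As a rough sketch of what that could look like (an editorial illustration, not code from the library), assuming a MaskableActorCriticPolicy-style get_distribution(obs, action_masks=...) and a rollout buffer that stores action_masks as MaskablePPO's does, both the old and the current distribution would be built with the same masks, so that log_prob and kl_div compare masked policies with masked policies:

    # Inside a hypothetical masked-TRPO train() loop (fragment, mirrors the quoted snippet).
    with th.no_grad():
        old_distribution = copy.copy(
            self.policy.get_distribution(
                rollout_data.observations, action_masks=rollout_data.action_masks
            )
        )

    distribution = self.policy.get_distribution(
        rollout_data.observations, action_masks=rollout_data.action_masks
    )
    log_prob = distribution.log_prob(actions)
    kl_div = kl_divergence(distribution, old_distribution).mean()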

0Addicted0 commented 2 months ago

Thank you for your timely response.

I think I misunderstood the meaning of line 246 of stable-baselines3-contrib/sb3_contrib/trpo/trpo.py.

For the second question, providing the action masks as you said works; at least it runs well in my custom environment.

🙂Thank you very much