mimoralea / gdrl

Grokking Deep Reinforcement Learning
https://www.manning.com/books/grokking-deep-reinforcement-learning
BSD 3-Clause "New" or "Revised" License

dynamic tau parameter fixed in update_networks fn #14

Closed · TroddenSpade closed this issue 2 years ago

TroddenSpade commented 3 years ago

Dear @mimoralea, in the DDPG.train function, after the actor and critic networks are defined, the parameters of the online and target networks should be made equal. As the call self.update_networks(tau=1.0) suggests, passing tau=1.0 to update_networks should copy the online parameters into the target networks. However, in the following function, the pre-defined self.tau is used as the weight of the online parameters instead:

def update_networks(self, tau=None):
        tau = self.tau if tau is None else tau
        for target, online in zip(self.target_value_model.parameters(),
                                  self.online_value_model.parameters()):
            target_ratio = (1.0 - self.tau) * target.data  # <-- should use tau
            online_ratio = self.tau * online.data          # <-- should use tau
            mixed_weights = target_ratio + online_ratio
            target.data.copy_(mixed_weights)

        for target, online in zip(self.target_policy_model.parameters(),
                                  self.online_policy_model.parameters()):
            target_ratio = (1.0 - self.tau) * target.data  # <-- should use tau
            online_ratio = self.tau * online.data          # <-- should use tau
            mixed_weights = target_ratio + online_ratio
            target.data.copy_(mixed_weights)
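To see why this matters: with a small soft-update factor, say self.tau = 0.005 (an illustrative value, not necessarily the book's setting), calling self.update_networks(tau=1.0) barely moves the target weights instead of copying the online ones, so the two networks never actually start out equal. A minimal numeric sketch of the mixing rule:

# Minimal numeric sketch (values are illustrative, not from the book)
import torch

self_tau = 0.005                 # assumed small Polyak factor
target = torch.zeros(3)          # stand-in for a target-network parameter
online = torch.ones(3)           # stand-in for an online-network parameter

# Buggy mixing: the tau=1.0 argument is ignored in favor of self_tau
mixed = (1.0 - self_tau) * target + self_tau * online
print(mixed)                     # tensor([0.0050, 0.0050, 0.0050]): not a copy

# Intended hard copy with tau=1.0
tau = 1.0
mixed = (1.0 - tau) * target + tau * online
print(mixed)                     # tensor([1., 1., 1.]): exact copy of online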

Instead, the tau defined in the first line of the function should be used as the weight in the subsequent calculations:

def update_networks(self, tau=None):
        tau = self.tau if tau is None else tau
        for target, online in zip(self.target_value_model.parameters(), 
                                  self.online_value_model.parameters()):
            target_ratio = (1.0 - tau) * target.data
            online_ratio = tau * online.data
            mixed_weights = target_ratio + online_ratio
            target.data.copy_(mixed_weights)

        for target, online in zip(self.target_policy_model.parameters(), 
                                  self.online_policy_model.parameters()):
            target_ratio = (1.0 - tau) * target.data
            online_ratio = tau * online.data
            mixed_weights = target_ratio + online_ratio
            target.data.copy_(mixed_weights)
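As a quick sanity check of the fix, here is a small self-contained sketch of the same mixing rule applied to two toy networks (the helper name polyak_update, the layer sizes, and the tau values are illustrative assumptions, not the notebook's settings): tau=1.0 produces an exact copy, while a small tau performs the usual Polyak soft update.

import torch
import torch.nn as nn

def polyak_update(online_model, target_model, tau):
    # target <- (1 - tau) * target + tau * online
    for target, online in zip(target_model.parameters(),
                              online_model.parameters()):
        target.data.copy_((1.0 - tau) * target.data + tau * online.data)

online_model, target_model = nn.Linear(4, 2), nn.Linear(4, 2)

# Soft update: the target moves only slightly toward the online weights
polyak_update(online_model, target_model, tau=0.005)

# Hard copy: with tau=1.0 the target becomes an exact copy of the online model
polyak_update(online_model, target_model, tau=1.0)
assert all(torch.equal(t.data, o.data)
           for t, o in zip(target_model.parameters(),
                           online_model.parameters()))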

These changes have been applied to the DDPG class in chapter-12.ipynb.

mimoralea commented 2 years ago

I definitely meant it the way you describe.

Thank you for the fix! Will merge this weekend.

mimoralea commented 2 years ago

BTW, sorry for the delay, I'm not sure how I missed this.

mimoralea commented 2 years ago

Adding https://github.com/mimoralea/gdrl/pull/22 because there are multiple notebooks with this issue. Thanks for reporting this and for the pull request!