rll / rllab

rllab is a framework for developing and evaluating reinforcement learning algorithms, fully compatible with OpenAI Gym.

Why is a MeanKLBefore not equal to zero in TRPO? #126

Open rarilurelo opened 7 years ago

rarilurelo commented 7 years ago

Hi! MeanKLBefore is computed in optimize_policy in npo.py:

    def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()
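
For context, the constraint value logged here is the mean KL divergence between the action distribution recorded at sampling time (agent_infos) and the distribution the current policy produces for the same observations. For a diagonal Gaussian policy this is the standard closed-form Gaussian KL; a minimal numpy sketch of that formula (an illustration, not the rllab implementation) shows it is exactly zero when the two sets of means and log-stds coincide:

    import numpy as np

    def diag_gaussian_kl(mean1, log_std1, mean2, log_std2):
        # Closed-form KL(N(mean1, std1) || N(mean2, std2)) for diagonal Gaussians,
        # summed over the action dimensions.
        std1, std2 = np.exp(log_std1), np.exp(log_std2)
        return np.sum(
            log_std2 - log_std1
            + (np.square(std1) + np.square(mean1 - mean2)) / (2.0 * np.square(std2))
            - 0.5,
            axis=-1,
        )

    mean = np.array([[0.3, -1.2]])
    log_std = np.array([[0.0, -0.5]])
    print(diag_gaussian_kl(mean, log_std, mean, log_std))  # [0.] -- identical parameters give zero KL

So any nonzero MeanKLBefore has to come from the stored means/log-stds differing from what the policy outputs at optimization time.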

I think policy1, which samples the actions used to collect trajectories, is strictly equal to policy2, which is used to compute the KL, so KL(policy1 || policy2) (MeanKLBefore) should be equal to zero. However, it takes a slightly nonzero value. To check the difference between policy1 and policy2, I ran the example script trpo_gym.py with the following print debugging inserted:

    def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)

        means1 = agent_infos['mean']
        log_std1 = agent_infos['log_std']
        observations = all_input_values[0]
        _, d = self.policy.get_actions(observations)
        means2 = d['mean']
        log_std2 = d['log_std']
        print('Is mean different?', (means1-means2).any())
        print('Is log_std different?', (log_std1-log_std2).any())

        loss_before = self.optimizer.loss(all_input_values)
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

The result:

Is mean different? True
Is log_std different? False

This shows that MeanKLBefore is nonzero because the means differ. My question is: what causes the difference in the means?

Thanks for your help!

dementrock commented 7 years ago

It depends on how large the difference is. If it's e.g. around 1e-8, it might just be a numerical precision issue.
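
One way to check how large the difference actually is (the .any() prints above only say whether any element differs, not by how much) would be something along these lines, a hypothetical extension of the debug snippet above, assuming means1 and means2 are the numpy arrays from that snippet:

    import numpy as np

    # Hypothetical follow-up to the debug prints: report the magnitude of the
    # discrepancy instead of just whether one exists.
    diff = np.abs(means1 - means2)
    print('max abs diff:', diff.max())
    print('within ~float32 precision?', np.allclose(means1, means2, atol=1e-6))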

rarilurelo commented 7 years ago

If the KL were exactly zero, i.e. if no numerical precision error occurred, then the gradient of the KL with respect to the parameters would be zero, and the Hessian-vector products would also be zero. Does this implementation rely on numerical precision error? That sounds strange to me.

dementrock commented 7 years ago
  1. The gradient is taken w.r.t. the surrogate loss, not the KL. The Hessian-vector product is computed w.r.t. the KL. It should not be too sensitive to numerical precision.
  2. There's no reason for dKL/dparams to be zero. What MeanKLBefore logs is just a sanity check, since the policy hasn't changed yet. When it's significantly nonzero, it sometimes indicates a bug somewhere.
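
To illustrate point 1 with a toy example: for a 1-D Gaussian with fixed std, KL(old || new) as a function of the new mean has zero gradient at the old mean, but its second derivative, which is what the Hessian-vector product picks up, is 1/std^2, not zero. A minimal sketch (a generic illustration, not rllab code):

    # KL(N(mu0, sigma) || N(mu, sigma)) = (mu - mu0)**2 / (2 * sigma**2) in 1-D.
    mu0, sigma = 0.7, 0.5
    kl = lambda mu: (mu - mu0) ** 2 / (2 * sigma ** 2)

    eps = 1e-5
    grad = (kl(mu0 + eps) - kl(mu0 - eps)) / (2 * eps)               # ~0: the KL is flat at the old mean
    hess = (kl(mu0 + eps) - 2 * kl(mu0) + kl(mu0 - eps)) / eps ** 2  # ~1/sigma**2: the curvature is not zero
    print(grad, hess, 1 / sigma ** 2)

So the step direction comes from the surrogate-loss gradient together with this KL curvature; neither requires MeanKLBefore to be exactly zero.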