werner-duvaud / muzero-general

MuZero
https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation
MIT License

Policy target after MCTS should be in the form of probabilities #193

Open 2M-kotb opened 2 years ago

2M-kotb commented 2 years ago

This issue appears only in the continuous-action version of the MuZero implementation.

When computing the child visits, we need to divide by sum_visits so that the stored policy target is a probability distribution.

But it seems you forgot to divide by sum_visits. Here is the current implementation:


sum_visits = sum(child.visit_count for child in root.children.values())

self.child_visits.append(
    numpy.array([child.visit_count for child in root.children.values()])
)

I think the correct version is the following:

sum_visits = sum(child.visit_count for child in root.children.values())

self.child_visits.append(
    numpy.array([child.visit_count / sum_visits for child in root.children.values()])
)
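For illustration, here is a minimal, self-contained sketch of the point above (the Node class and the numbers are made up stand-ins for the repository's MCTS node, not its actual code): dividing by sum_visits turns the raw visit counts into a valid probability distribution.

import numpy

# Hypothetical stand-in for the repository's MCTS node, for illustration only:
# each child just records how often the search visited it.
class Node:
    def __init__(self, visit_count):
        self.visit_count = visit_count

# Suppose the search visited three sampled actions 30, 50 and 20 times.
children = {0: Node(30), 1: Node(50), 2: Node(20)}

sum_visits = sum(child.visit_count for child in children.values())

# Without the division the stored target is the raw counts, which do not sum to 1
# and are therefore not a valid policy target.
raw_counts = numpy.array([child.visit_count for child in children.values()])

# With the division the target is a proper probability distribution.
policy_target = numpy.array(
    [child.visit_count / sum_visits for child in children.values()]
)

print(raw_counts)           # [30 50 20]
print(policy_target)        # [0.3 0.5 0.2]
print(policy_target.sum())  # 1.0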
FYQ0919 commented 2 years ago

Yes, I also noticed this problem; it causes the KL loss to be less than 0.

The target_policy_action is computed from the action values, so the current KL loss is not computed between two proper probability distributions.

log_prob = dist.log_prob(target_policy_action[:, i, :]).sum(1)
policy_loss += torch.exp(log_prob) * (log_prob - log_target)
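To see why the loss can go below zero, here is a toy numpy example (the numbers are made up and this is not the repository's loss code): between two normalized distributions the KL term is always non-negative, but if the target is left unnormalized the same expression can become negative.

import numpy

# Network policy and a properly normalized visit-count target over three actions
# (made-up numbers for illustration).
p = numpy.array([0.3, 0.5, 0.2])
q = numpy.array([0.25, 0.55, 0.20])

# KL(p || q) between two probability distributions is always >= 0.
kl_proper = numpy.sum(p * (numpy.log(p) - numpy.log(q)))
print(kl_proper)  # small non-negative value

# If the target is left as raw visit counts (not divided by sum_visits),
# the same expression is no longer a KL divergence and can become negative.
raw_counts = numpy.array([30.0, 50.0, 20.0])
kl_broken = numpy.sum(p * (numpy.log(p) - numpy.log(raw_counts)))
print(kl_broken)  # negative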