vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

Question about KL divergence computation #25

Closed · Maxtoq closed this issue 1 year ago

Maxtoq commented 1 year ago

Hi there,

I have a question about the way you (and other implementations) compute the KL divergence penalty, by only taking the difference between log_probs and ref_log_probs: https://github.com/vwxyzjn/lm-human-preference-details/blob/ba1a240567366ed58e182dde897dd1523a44ffd0/lm_human_preference_details/train_policy_accelerate.py#L628
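For context, my understanding of what that line does is roughly the following (a minimal sketch, not the exact repo code; `logprobs` and `ref_logprobs` stand for the per-token log-probabilities of the sampled response under the current policy and the frozen reference model, and the 0.15 coefficient is just an arbitrary illustration):

```python
import torch

# Illustrative per-token log-probs of the sampled response tokens,
# shape [batch, response_length], under the policy being trained and
# under the frozen reference model.
logprobs = torch.randn(4, 16)
ref_logprobs = torch.randn(4, 16)

# Per-token KL penalty: just the difference of log-probs at the sampled
# tokens, with no sum over the whole vocabulary.
kl = logprobs - ref_logprobs      # shape [batch, response_length]
non_score_reward = -0.15 * kl     # kl_coef * kl, folded into the reward
```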

From my understanding, the KL divergence is defined as the sum, over all possible outcomes, of the probability times the difference between log-probs:

$$ D_{\text{KL}}(p \| q) = \sum_{x} p(x) \left[ \log p(x) - \log q(x) \right] $$

Why don't you compute it this way? I get the impression this is linked to the fact that you compute a "per-token KL", but I don't really understand what that refers to. Could you explain this a bit, please?

Thank you!

vwxyzjn commented 1 year ago

See http://joschu.net/blog/kl-approx.html

liutianlin0121 commented 1 year ago

Thanks @Maxtoq for the interest!

If we can enumerate all $x \in \mathcal{X}$, then we can indeed compute:

$$ D_{\text{KL}}(p \| q) = \sum_{x \in \mathcal{X}} p(x) \log[p(x)/q(x)] $$

as you mentioned. Note that this is an equality. There is no approximation.
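For a small alphabet you can do this enumeration directly. A minimal sketch with two made-up distributions (the numbers are arbitrary, just for illustration):

```python
import torch

# Two made-up distributions over a 5-symbol alphabet.
p = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])
q = torch.tensor([0.30, 0.30, 0.20, 0.10, 0.10])

# Exact KL: enumerate every x and sum p(x) * [log p(x) - log q(x)].
exact_kl = (p * (p.log() - q.log())).sum()
print(exact_kl)  # a single non-negative scalar
```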

When $\mathcal{X}$ is very large, enumerating all $x$ can be expensive. Instead, we can approximate the KL by rewriting it as an expectation:

$$ D_{\text{KL}}(p \| q) = \mathbb{E}_{x \sim p(\cdot)} [\log p(x) - \log q(x)] $$

By drawing a number of samples $x_i \sim p(\cdot)$, we can approximate this expectation with the average of $\log p(x_i) - \log q(x_i)$ over the samples.
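This single-sample term $\log p(x_i) - \log q(x_i)$ is exactly the per-token `logprobs - ref_logprobs` difference you pointed to: each sampled token contributes one such term, which is why it is called a per-token KL. Continuing the toy example above, a minimal sketch of the sample-based estimate:

```python
import torch

p = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])
q = torch.tensor([0.30, 0.30, 0.20, 0.10, 0.10])

# Draw x_i ~ p and average log p(x_i) - log q(x_i).
samples = torch.multinomial(p, num_samples=100_000, replacement=True)
estimate = (p.log()[samples] - q.log()[samples]).mean()

# The estimate approaches the exact KL as the number of samples grows.
exact = (p * (p.log() - q.log())).sum()
print(estimate, exact)
```

Note that individual terms can be negative, but their expectation is the (non-negative) KL; the blog post linked above also discusses estimators with lower variance.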

Maxtoq commented 1 year ago

Okaaaay! Thank you, I finally understand!