Thanks @Maxtoq for the interest!
If we can enumerate all $x \in \mathcal{X}$, then we can indeed compute:
$$ D_{\text{KL}}(p \| q) = \sum_{x \in \mathcal{X}} p(x) \log \frac{p(x)}{q(x)} $$
as you mentioned. Note that this is an equality. There is no approximation.
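A minimal sketch of that exact computation (the distributions below are made up for illustration, not taken from the repo):

```python
import numpy as np

# Toy categorical distributions over a small, fully enumerable X = {0, 1, 2}
p = np.array([0.5, 0.3, 0.2])   # hypothetical p(x)
q = np.array([0.4, 0.4, 0.2])   # hypothetical q(x)

# D_KL(p || q) = sum_x p(x) * [log p(x) - log q(x)]  -- exact, no approximation
kl_exact = np.sum(p * (np.log(p) - np.log(q)))
print(kl_exact)
```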
When $\mathcal{X}$ is very large, enumerating all $x$ can be expensive. Instead, we can approximate the KL by rewriting it as an expectation:
$$ D_{\text{KL}}(p \| q) = \mathbb{E}_{x \sim p(\cdot)} [\log p(x) - \log q(x)] $$
By drawing a number of samples $x_i \sim p(\cdot)$, we can approximate the expectation via an average of $\log p(x_i) - \log q(x_i)$.
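A minimal sketch of that sampled estimate, reusing the same toy distributions as above (again just illustrative, not repo code):

```python
import numpy as np

# Same toy distributions as in the previous snippet
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])

# Monte Carlo estimate: draw x_i ~ p(.) and average log p(x_i) - log q(x_i)
rng = np.random.default_rng(0)
xs = rng.choice(len(p), size=100_000, p=p)   # samples x_i ~ p(.)

kl_estimate = np.mean(np.log(p[xs]) - np.log(q[xs]))
print(kl_estimate)   # converges to the exact KL as the sample count grows

# With a single sample per position (as in a per-token penalty during
# RLHF-style rollouts), the estimate reduces to log p(x) - log q(x) for the
# sampled token, i.e. just the difference of log-probs.
```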
Okaaaay! Thank you, I finally understand!
Hi there,
I have a question about the way you (and other implementations) compute the KL divergence penalty, by only taking the difference between log_probs and ref_log_probs: https://github.com/vwxyzjn/lm-human-preference-details/blob/ba1a240567366ed58e182dde897dd1523a44ffd0/lm_human_preference_details/train_policy_accelerate.py#L628
From my understanding, the KL divergence is defined as the sum over $x$ of $p(x)$ times the difference between the log-probs:
$$ D_{\text{KL}}(p \| q) = \sum_{x \in \mathcal{X}} p(x) \log \frac{p(x)}{q(x)} $$
Why is it that you don't compute it this way? I get the impression that this is linked to the fact that you compute the "per-token KL", but I don't really understand what this refers to. Could you explain this a bit, please?
Thank you!