microsoft / LMOps

General technology for enabling AI capabilities w/ LLMs and MLLMs
https://aka.ms/GeneralAI
MIT License

[MiniLLM] mismatch between formula and implementation $(\nabla \mathcal{L})_\text{Long}$? [old_logprobs not found in paper] #265

Closed: lancerts closed this 2 months ago

lancerts commented 2 months ago

In the paper: [screenshot of the formula]. In the code:

  def _pg_loss(
      self,
      logprobs: TensorType["batch_size", "response_size"],
      old_logprobs: TensorType["batch_size", "response_size"],
      advantages: TensorType["batch_size", "response_size"],
      mask: TensorType["batch_size", "response_size"],
      w: TensorType["batch_size", "response_size"],
  ):
      """PPO objective function.
      References:
      - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
      """
      n = mask.sum()

      log_ratio = (logprobs - old_logprobs) * mask
      ratio = torch.exp(log_ratio.float())            
      ratio = ratio * w

Is old_logprobs referring to the teacher-mixed sampling distribution $\widetilde{p}$, or is it something else?

t1101675 commented 2 months ago

old_logprobs refers to $q_\theta$ (as log-probabilities) with no gradient taken through it, while logprobs is the same $q_\theta$ through which gradients are taken. Differentiating ratio therefore gives $\frac{\nabla q_\theta}{q_\theta}$. Hence ratio * w corresponds to $\rho_t(\theta) = \frac{q_\theta}{\widetilde{p}}$, where only the $\theta$ in the numerator receives gradients.
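A minimal pure-Python sketch of this stop-gradient trick, under stated assumptions: a hypothetical one-parameter model in which the sampled token has probability $q_\theta = \sigma(\theta)$, and w assumed to hold $q_\theta / \widetilde{p}$ without gradients (the excerpt above does not show how w is computed). Instead of PyTorch autograd, "no gradient" is modeled by freezing those terms at a fixed theta_old, and derivatives are approximated by central finite differences. The check is that ratio * w has value $\rho_t(\theta)$ and gradient $\rho_t(\theta) \nabla \log q_\theta = \frac{\nabla q_\theta}{\widetilde{p}}$.

```python
import math

# Hypothetical one-parameter model: the sampled token has probability
# q_theta = sigmoid(theta). Not taken from the MiniLLM code base.
def log_q(theta: float) -> float:
    return -math.log1p(math.exp(-theta))  # log(sigmoid(theta))


def surrogate(theta: float, theta_old: float, p_tilde: float) -> float:
    """Mimics `ratio * w` from _pg_loss for a single token.

    Terms that receive no gradient (old_logprobs, and w under the
    assumption w = q_theta / p_tilde) are evaluated at the frozen theta_old.
    """
    logprobs = log_q(theta)                 # gradients flow through this term
    old_logprobs = log_q(theta_old)         # frozen: plays the role of a detached copy
    ratio = math.exp(logprobs - old_logprobs)
    w = math.exp(old_logprobs) / p_tilde    # assumed importance weight, also frozen
    return ratio * w


def finite_diff(f, x: float, h: float = 1e-6) -> float:
    """Central finite-difference approximation of df/dx."""
    return (f(x + h) - f(x - h)) / (2 * h)


theta0, p_tilde = 0.3, 0.4                  # arbitrary illustrative values
rho = math.exp(log_q(theta0)) / p_tilde     # rho_t(theta) = q_theta / p_tilde

value = surrogate(theta0, theta0, p_tilde)
grad_surrogate = finite_diff(lambda th: surrogate(th, theta0, p_tilde), theta0)
grad_expected = rho * finite_diff(log_q, theta0)  # rho_t * grad log q_theta
grad_direct = finite_diff(lambda th: math.exp(log_q(th)) / p_tilde, theta0)

print(f"value of ratio * w: {value:.6f} (rho_t = {rho:.6f})")
print(f"gradient: {grad_surrogate:.6f} vs rho_t * grad log q: {grad_expected:.6f}")
print(f"gradient of q_theta / p_tilde: {grad_direct:.6f}")
```

At theta == theta_old the ratio is exactly 1, so ratio * w carries the value of $\rho_t(\theta)$ while its gradient matches differentiating only the numerator $q_\theta$. In PyTorch the same effect would presumably come from old_logprobs and w carrying no autograd history; the sketch just makes the frozen point explicit.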

lancerts commented 2 months ago

Thanks for the detailed explanation.