abacusai / smaug

Apache License 2.0
56 stars 3 forks

Will you release the code? #2

Open zhanghaoie opened 7 months ago

TobiasLee commented 7 months ago

my implementation based on the DPOTrainer in trl for your reference:

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ):
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        if reference_free:
            ref_logratios = 0

        logits = pi_logratios - ref_logratios
        # add regularization here
        positive_reg = reference_chosen_logps - policy_chosen_logps

        losses = - ( F.logsigmoid(self.beta * logits) - self.dpop_lambda * torch.clamp(
            positive_reg, min=0 # lambda * max(0, ratio)
        ))
        chosen_rewards = (
            self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        )

        return losses, chosen_rewards, rejected_rewards

Update After Discussion

Thanks for the discussion @tenggyut @LuJunru @FeiWang96. A mistake was identified in my previous implementation and I have now corrected it.
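
Written as an equation, the corrected snippet above computes the following loss (writing $\pi_\theta$ for the policy, $\pi_{\mathrm{ref}}$ for the reference model, and $y_w$ / $y_l$ for the chosen / rejected responses):

$$
\mathcal{L} = -\log \sigma\left(\beta \left[\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right]\right) + \lambda \cdot \max\left(0,\ \log \frac{\pi_{\mathrm{ref}}(y_w \mid x)}{\pi_\theta(y_w \mid x)}\right)
$$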

FeiWang96 commented 6 months ago

my implementation based on the DPOTrainer in trl for your reference:

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ):
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        if reference_free:
            ref_logratios = 0

        logits = pi_logratios - ref_logratios
        # add regularization here
        positive_reg = reference_chosen_logps - policy_chosen_logps

        losses = -F.logsigmoid(self.beta * logits) - self.dpop_lambda * torch.clamp(
            positive_reg, min=0 # lambda * max(0, ratio)
        )
        chosen_rewards = (
            self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        )

        return losses, chosen_rewards, rejected_rewards

Is it + instead of - according to the paper?

losses = -F.logsigmoid(self.beta * logits) + self.dpop_lambda * torch.clamp(positive_reg, min=0)
TobiasLee commented 6 months ago

Notice that in the paper the ratio is defined as $\log \frac{p_{\mathrm{win}}}{p_{\mathrm{ref}}}$. Here my positive_reg = reference_chosen_logps - policy_chosen_logps, which corresponds to $\log \frac{p_{\mathrm{ref}}}{p_{\mathrm{win}}}$, so the sign is - to keep it consistent with the paper.

LuJunru commented 6 months ago

Notice that in the paper the ratio is defined as $\log \frac{p_{\mathrm{win}}}{p_{\mathrm{ref}}}$. Here my positive_reg = reference_chosen_logps - policy_chosen_logps, which corresponds to $\log \frac{p_{\mathrm{ref}}}{p_{\mathrm{win}}}$, so the sign is - to keep it consistent with the paper.

losses = (
                -(F.logsigmoid(self.beta * logits) - self.dpop_lambda * torch.relu(reference_chosen_logps - policy_chosen_logps))
            )

This is my implementation. By the way, I think there are differences between equation 3 and further equations in the appendix.

tenggyut commented 6 months ago

Notice that in the paper the ratio is defined as $\log \frac{p_{\mathrm{win}}}{p_{\mathrm{ref}}}$. Here my positive_reg = reference_chosen_logps - policy_chosen_logps, which corresponds to $\log \frac{p_{\mathrm{ref}}}{p_{\mathrm{win}}}$, so the sign is - to keep it consistent with the paper.

I think it should be +. self.dpop_lambda * torch.relu(reference_chosen_logps - policy_chosen_logps) is always non-negative, so how can it act as a penalty if subtracting it only ever leaves the original loss unchanged or makes it smaller? Besides, the idea of the paper is to keep the policy from drifting away from the reference on the chosen (good) response; that is, policy_chosen_logps should not be smaller than reference_chosen_logps, and if it is, it should be penalized.

So I think it's a mistake in the paper. And if lambda is set to 50, the loss will be way too high; I think that may be another mistake.
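
As a toy sanity check of this point (illustrative numbers, not from the paper): when the policy's chosen log-prob drops below the reference's, the ReLU term is positive, so adding it increases the loss, whereas subtracting it would reduce the loss.

    import torch

    # Illustrative values only: the policy has drifted below the reference
    # on the chosen response, so the penalty should kick in.
    reference_chosen_logps = torch.tensor([-10.0])
    policy_chosen_logps = torch.tensor([-12.0])

    penalty = torch.relu(reference_chosen_logps - policy_chosen_logps)
    print(penalty)  # tensor([2.]) -> positive, so "+ lambda * penalty" raises the loss

    # If the policy stays at or above the reference, the penalty vanishes.
    policy_chosen_logps = torch.tensor([-9.0])
    print(torch.relu(reference_chosen_logps - policy_chosen_logps))  # tensor([0.])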

LuJunru commented 6 months ago

I actually trained with my implementation, and I can share my observations just for reference. Compared with what I reproduced with the original DPO (results), I found the DPOP loss was dramatically large at the beginning, but then dropped to a normal level after several steps. However, the final results of DPOP were around 3% behind those of DPO.

Update: I ran a new experiment with the official loss released by the author below (https://github.com/abacusai/smaug/issues/2#issuecomment-2075172150). The DPOP results were better than the previous version, but still 1% behind the DPO results.

tenggyut commented 6 months ago

Have you actually tried generating text with your trained DPOP model? Because if - is used, I find it hard to understand...

LuJunru commented 6 months ago

Have you actually tried generating text with your trained DPOP model?

I included 3 generative benchmarks: GSM8K, BBH and IFEval. By the way, the lambda I used was exactly 50.

tenggyut commented 6 months ago

F.logsigmoid(self.beta * logits)

My mistake, I didn't mean your implementation; I meant @FeiWang96's. Yours has a - outside the bracket, which is exactly my point.

LuJunru commented 6 months ago

Have you actually tried generating text with your trained DPOP model? Because if - is used, I find it hard to understand...

I put the - inside the brackets, so after distributing the outer minus it is equivalent to +.
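
Spelled out, the two forms agree because the leading minus distributes over the bracket:

$$
-\big(\log \sigma(\beta z) - \lambda \max(0, r)\big) = -\log \sigma(\beta z) + \lambda \max(0, r),
$$

where $z$ is the DPO logits term (pi_logratios - ref_logratios) and $r = \log \frac{\pi_{\mathrm{ref}}(y_w \mid x)}{\pi_\theta(y_w \mid x)}$, i.e. positive_reg.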

arkapal3 commented 6 months ago

Hi there - author of DPOP here. Thanks for taking an interest in our paper, and apologies for taking a while to notice this issue.

First, I see that there is actually a bracketing error in our paper which has resulted in some confusion. In eqn(3), the penalty term, lambda * max(0, reference_chosen_logps - policy_chosen_logps), should actually be inside the logsigmoid.

So following the TRL loss implementation:

        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        penalty_term = torch.maximum(torch.zeros_like(policy_chosen_logps), reference_chosen_logps - policy_chosen_logps)
        # `lambda` is a reserved keyword in Python, so the coefficient is named dpop_lambda here
        logits += -self.dpop_lambda * penalty_term

        losses = (
            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        )

The idea here is that when policy_chosen_logps is less than reference_chosen_logps, this penalty is a positive value, and with the negative sign we are then trying to reduce this penalty.

We will update the manuscript to fix the bracketing.
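
For anyone who wants to try this, here is a minimal self-contained sketch of the loss with the penalty inside the logsigmoid, following the snippet above; the standalone function signature, the default values, and the name dpop_lambda are my own choices for illustration, not from any released implementation:

    import torch
    import torch.nn.functional as F

    def dpop_loss(
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        beta: float = 0.1,
        dpop_lambda: float = 50.0,
        label_smoothing: float = 0.0,
    ):
        # standard DPO logits: difference of policy and reference log-ratios
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios

        # DPOP penalty: positive only when the policy's chosen log-prob drops below the reference's
        penalty = torch.clamp(reference_chosen_logps - policy_chosen_logps, min=0)
        logits = logits - dpop_lambda * penalty

        # same (label-smoothed) logsigmoid loss as in trl's DPOTrainer
        losses = (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )
        chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards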

LuJunru commented 6 months ago

logits += -self.dpop_lambda * penalty_term

Thank you for providing the official implementation! If my understanding is correct, compared with the original eqn(3) this additionally applies the logsigmoid (with beta) to your penalty_term. I'll run an updated experiment to see if it works.
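
For reference, the comparison being made here is between applying the penalty outside the logsigmoid (as in the earlier snippets in this thread) and inside it (as in the official code above):

$$
-\log \sigma(\beta z) + \lambda \max(0, r) \quad \text{vs.} \quad -\log \sigma\big(\beta\,(z - \lambda \max(0, r))\big),
$$

with $z$ and $r$ as in the earlier comments.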

arkapal3 commented 6 months ago

Yes good spot, that is another minor error in our manuscript 👍 We will fix that as well. This may have some repercussions for reported values of lambda, though we have since actually run testing over a wider range of lambda in [5, 500] and found the results to be largely insensitive to its value (at least, within this range).

RikkiXu commented 6 months ago

Hi~ I tried the new code, but unfortunately both the chosen and rejected rewards increased, and the final result on MT-Bench was much lower than DPO. My base model is Mistral-7B, and the preference dataset is HuggingFaceH4/ultrafeedback_binarized (I set lambda=50).

spttt commented 4 months ago

We will update the manuscript to fix the bracketing.

Is there any new update?

crwhite14 commented 3 months ago

Is there any new update?

Yes, the version of the paper updated on July 3 has the bracketing fixed: https://arxiv.org/pdf/2402.13228