vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

Questions about `left_padding_to_right_padding` #6

Closed: liutianlin0121 closed this issue 1 year ago

liutianlin0121 commented 1 year ago

Hi Costa,

Thanks for sharing this awesome implementation! It has been tremendously helpful for my own work.

I noticed a few uses of the function left_padding_to_right_padding in reward modeling and policy training. I'm not entirely clear on its purpose, and it would be great if you could help me understand. First, the function left_padding_to_right_padding appears to implement "right_padding_to_left_padding" instead, contrary to its name: the pad_id tokens are added to the left-hand side of the non-padding tokens, as sketched below.
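
For concreteness, here is a minimal sketch of the behavior I am describing (the implementation and toy pad_id value here are my own illustration, not the repo's code):

import torch

pad_id = 50257  # toy pad id, for illustration only
query = torch.tensor([
    [23073, pad_id, pad_id],  # right-padded
])
# Despite its name, the function moves the pads to the LEFT:
converted = torch.stack([
    torch.cat([row[row == pad_id], row[row != pad_id]])
    for row in query
])
print(converted)  # tensor([[50257, 50257, 23073]])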

I also note that there is a comment

got to convert to right padding, otherwise transformers has weird issues even with position_ids

Could this be connected to the implementation of AutoModelForCausalLMWithRewardHead? The implementation there extracts the reward corresponding to the last token with reward = reward[:, -1]. As a result, if we pad on the right, the reward will correspond to a padding token, causing problems. I'm not sure whether this is the weird issue you referred to. But if it is, the lines in GPT2ForSequenceClassification that select the final non-padding tokens, instead of simply the final tokens, may be helpful; a sketch follows.
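
Roughly, that selection looks like the following (paraphrased from memory rather than copied verbatim from transformers, and assuming all padding sits at the end of each row):

import torch

def reward_at_last_non_padding_token(reward, input_ids, pad_token_id):
    # reward: (batch, seq_len); input_ids: (batch, seq_len), right-padded.
    # Index of the last non-padding token in each row (valid only when all
    # padding is on the right).
    sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
    return reward[torch.arange(input_ids.shape[0]), sequence_lengths]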

vwxyzjn commented 1 year ago

Hi Tianlin,

Thanks for raising this issue.

left_padding_to_right_padding appears to implement "right_padding_to_left_padding" instead, contrary to its name?

You're absolutely spot on. Would you like to make a PR changing the function names?

Could this be connected to the implementation of AutoModelForCausalLMWithRewardHead? The implementation there extracts the reward corresponding to the last token with reward = reward[:, -1]. For this reason, if we pad right, then the reward will correspond to a padding token, causing problems.

When calculating the logits, OAI's code masks out the padding tokens properly: it finds the token indices corresponding to the padding tokens (lm_human_preferences/language/model.py#L296-L297) and then adjusts their position indices accordingly (lm_human_preferences/language/model.py#L320). Running OAI's code on the left-padded example gives:

all_logits [[[ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
    -35.36577 ]
  [ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
    -35.36577 ]
  [ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
    -35.36577 ]
  [-111.303955 -110.94471  -112.90624  ... -113.13064  -113.7788
   -109.17345 ]
  [-111.51512  -109.61077  -114.90231  ... -118.43514  -111.56671
   -112.12478 ]
  [-122.69775  -121.84468  -128.27417  ... -132.28055  -130.39604
   -125.707756]]] (1, 6, 50257)

Here is a snippet that achieves the same thing with left padding.

import torch
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
pad_id = tokenizer.pad_token_id
query = torch.tensor([
    [pad_id, pad_id, 23073],  # manually left-padded, so padding_side is not used here
])
response = torch.tensor([
    [11, 339, 561],
])
temperature = 1.0

response = response.long()
context_length = query.shape[1]
query_response = torch.cat((query, response), 1)
pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
def forward(policy, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
    input_ids = query_responses.clone()
    input_ids[~attention_mask] = 0  # dummy token id for padding positions; attention_mask already ignores them
    return policy(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )
output = forward(pretrained_model, query_response, tokenizer)
logits = output.logits
logits /= temperature
print(logits)

"""
tensor([[[ -26.9395,  -26.4709,  -30.0456,  ...,  -33.2208,  -33.2884,
           -27.4360],
         [ -27.1677,  -26.7330,  -30.2386,  ...,  -33.6813,  -33.6931,
           -27.5928],
         [ -35.2869,  -34.2875,  -38.1608,  ...,  -41.5958,  -41.0821,
           -35.3658],
         [-111.3040, -110.9447, -112.9062,  ..., -113.1306, -113.7788,
          -109.1734],
         [-111.5152, -109.6108, -114.9024,  ..., -118.4352, -111.5668,
          -112.1248],
         [-122.6978, -121.8447, -128.2742,  ..., -132.2805, -130.3961,
          -125.7078]]], grad_fn=<DivBackward0>)
"""

extracts the reward corresponding to the last token with reward = reward[:, -1]

Because of the interaction shown above, the logits at the non-padding positions match OAI's exactly (the rows at the pad positions differ, but they are never used), so extracting the last token's reward is the same under OAI's setting and ours.
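
To spell that out with a toy example (the linear reward head below is a stand-in I made up for illustration; it is not the actual AutoModelForCausalLMWithRewardHead):

import torch

torch.manual_seed(0)
hidden_states = torch.randn(1, 6, 768)           # (batch, seq_len, hidden_size)
reward_head = torch.nn.Linear(768, 1)            # stand-in scalar reward head
reward = reward_head(hidden_states).squeeze(-1)  # (batch, seq_len)

# With left padding, the token at index -1 is always a real (non-padding)
# token, so this picks up the reward of the last actual token:
last_reward = reward[:, -1]
print(last_reward.shape)  # torch.Size([1])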

liutianlin0121 commented 1 year ago

Thank you for your reply! I'm happy to submit a PR soon for the name change.

Here is a snippet that achieves the same thing with left padding.

I fully agree! Everything works fine with your current implementation, which uses reward = reward[:, -1] + left padding. I just wanted to point out that the weird issue caused by right padding, mentioned in the function's comment, could be related to reward = reward[:, -1]. That is, one can use either (1) reward = reward[:, -1] + left padding or (2) reward = reward[:, last_non_padding_token_idx] + right padding. Or perhaps the weird issue of right padding you mentioned persists even after applying reward = reward[:, last_non_padding_token_idx]?

In any case, I find the current implementation of reward = reward[:, -1] with left padding to be an elegant solution. Thank you for the clarification! I will submit a PR shortly.

vwxyzjn commented 1 year ago

Or perhaps the weird issue of right padding you mentioned persists even after applying reward = reward[:, last_non_padding_token_idx]?

I can give a repro of the right padding issue when I get back to my computer 🙂

vwxyzjn commented 1 year ago

Here is the repro

import torch
import transformers

def right_padding_to_left_padding(query, pad_id):
    # Move all pad tokens in each row to the left of the non-padding tokens.
    return torch.tensor([
        [pad_id] * (row == pad_id).sum() + [x for x in row if x != pad_id]
        for row in query
    ])

def repro(left_padding=True):
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    pad_id = tokenizer.pad_token_id
    query = torch.tensor([
        [23073, pad_id, pad_id],  # manually right-padded
    ])
    print(f"left_padding={left_padding}")
    if left_padding:
      query = right_padding_to_left_padding(query, pad_id)
    response = torch.tensor([
        [11, 339, 561],
    ])
    temperature = 1.0

    response = response.long()
    context_length = query.shape[1]
    query_response = torch.cat((query, response), 1)
    pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    def forward(policy, query_responses, tokenizer):
        attention_mask = query_responses != tokenizer.pad_token_id
        position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
        input_ids = query_responses.clone()
        input_ids[~attention_mask] = 0
        return policy(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
            output_hidden_states=True,
        )
    output = forward(pretrained_model, query_response, tokenizer)
    logits = output.logits
    logits /= temperature
    print(logits)

repro()
repro(left_padding=False)

The output is:

left_padding=True
tensor([[[ -26.9395,  -26.4709,  -30.0456,  ...,  -33.2208,  -33.2884,
           -27.4360],
         [ -27.1677,  -26.7330,  -30.2386,  ...,  -33.6813,  -33.6931,
           -27.5928],
         [ -35.2869,  -34.2875,  -38.1608,  ...,  -41.5958,  -41.0821,
           -35.3658],
         [-111.3040, -110.9447, -112.9062,  ..., -113.1306, -113.7788,
          -109.1734],
         [-111.5152, -109.6108, -114.9024,  ..., -118.4352, -111.5668,
          -112.1248],
         [-122.6978, -121.8447, -128.2742,  ..., -132.2805, -130.3961,
          -125.7078]]], grad_fn=<DivBackward0>)
left_padding=False
tensor([[[ -35.2869,  -34.2875,  -38.1608,  ...,  -41.5958,  -41.0821,
           -35.3658],
         [ -48.6581,  -48.7161,  -52.8530,  ...,  -56.9125,  -55.6985,
           -49.6360],
         [ -48.6581,  -48.7161,  -52.8530,  ...,  -56.9125,  -55.6985,
           -49.6360],
         [-111.3040, -110.9447, -112.9062,  ..., -113.1306, -113.7788,
          -109.1734],
         [-111.5152, -109.6108, -114.9024,  ..., -118.4352, -111.5668,
          -112.1248],
         [-122.6978, -121.8447, -128.2742,  ..., -132.2805, -130.3961,
          -125.7078]]], grad_fn=<DivBackward0>)

For reference, the ground truth from running OpenAI's code is:

all_logits [[[ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
    -35.36577 ]
  [ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
    -35.36577 ]
  [ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
    -35.36577 ]
  [-111.303955 -110.94471  -112.90624  ... -113.13064  -113.7788
   -109.17345 ]
  [-111.51512  -109.61077  -114.90231  ... -118.43514  -111.56671
   -112.12478 ]
  [-122.69775  -121.84468  -128.27417  ... -132.28055  -130.39604
   -125.707756]]] (1, 6, 50257)

Notice that the values we are trying to get are -35.28693, -111.303955, and -111.51512: the first entries of the logits at the positions that predict the three response tokens. With left padding we get exactly these, but with right padding we get -48.6581, -111.3040, -111.5152. The logit at the position that predicts the first response token is wrong, because under right padding that prediction is made from a padding position rather than from the last real query token.
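
As a footnote, here is a minimal sketch of how the response logits would typically be sliced out, reusing logits and context_length from inside repro() above (the slicing itself is my illustration, not a line from the repo):

# logits: (batch, query_len + response_len, vocab_size); row i predicts
# the token at position i + 1, so the rows that score the response tokens
# run from context_length - 1 up to the second-to-last row.
response_logits = logits[:, context_length - 1 : -1]
# With left padding these rows all sit at real-token positions; with right
# padding the first of them sits at a padding position, which is why its
# values (-48.6581, ...) disagree with the ground truth (-35.28693, ...).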