OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

Incompatibility with Qwen #303

Closed · Ricardokevins closed this issue 1 month ago

Ricardokevins commented 1 month ago

When using PPO for training, I noticed that within the actor's generate function, it locates the position of the EOS token and sets the subsequent tokens to the EOS token. However, even when the padding side is set to left, the far right end of the sequence can still contain padding tokens or the EOS token, which triggers an internal assert in the Qwen model.

Actor.py

def process_sequences(self, sequences: torch.Tensor, input_len, eos_token_id, pad_token_id):
    attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long)
    seq_length = attention_mask.size(1)

    # The following code is equivalent to:
    #
    # for i in range(attention_mask.size(0)):
    #     for t in reversed(range(seq_length)):
    #         if attention_mask[i][t] > 0.5:
    #             attention_mask[i][min(t + 1, seq_length - 1)] = True
    #             sequences[i][min(t + 1, seq_length - 1)] = eos_token_id
    #             break
    #
    eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1)
    attention_mask.scatter_(dim=1, index=eos_indices, value=1)
    sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id)

    # in RL, state_i (current token) + action_i (next token) -> state_i+1 (next token)
    state_seq = sequences[:, input_len - 1 : -1]
    # we only calculate the loss of state_i != eos | pad
    action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id)
    return sequences, attention_mask, action_mask
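
To make the failure mode concrete, here is a minimal self-contained sketch (not taken from the repo) that applies the same masking logic to a toy left-padded batch; the token ids are made up for illustration. When the response stops before the maximum length, the right end of attention_mask stays 0, which is exactly what the Qwen2 flash-attention check below interprets as right padding.

import torch

eos_token_id, pad_token_id = 2, 0

# Toy batch: left-padded prompt [5, 6], response [7] that stops early, then right padding.
# Layout: [pad, pad, prompt, prompt, response, eos, pad, pad]
sequences = torch.tensor([[0, 0, 5, 6, 7, 2, 0, 0]])

# Same logic as process_sequences above.
attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long)
seq_length = attention_mask.size(1)
eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1)
attention_mask.scatter_(dim=1, index=eos_indices, value=1)
sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id)

print(attention_mask)         # tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
print(attention_mask[:, -1])  # tensor([0]) -> sum != batch_size, so Qwen2 treats it as right padding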

modeling_qwen.py

if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
    is_padding_right = attention_mask[:, -1].sum().item() != batch_size
    if is_padding_right:
        raise ValueError(
            "You are attempting to perform batched generation with padding_side='right'"
            " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
            " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
        )
hijkzzz commented 1 month ago

just disable flash_attention_2
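
For reference, a minimal sketch of what disabling flash_attention_2 could look like when the model is loaded through Hugging Face transformers; the model id and arguments here are illustrative, not the exact OpenRLHF launch configuration.

from transformers import AutoModelForCausalLM

# Fall back to the default "eager" attention implementation instead of
# "flash_attention_2", so the Qwen2 right-padding check is never executed.
# The model id below is a placeholder.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-7B",
    attn_implementation="eager",
    torch_dtype="auto",
)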

Ricardokevins commented 1 month ago

> just disable flash_attention_2

Yes, but it may slow down training.

Luckily, Llama3 works fine.