CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

Increasing max_new_tokens in the generation arguments leads to errors #553

Open wise-east opened 1 year ago

wise-east commented 1 year ago

🐛 Describe the bug

Here's my TRLConfig:

# Imports added for completeness; these paths follow trlx 0.7.x
# (PPOConfig lives in trlx.models.modeling_ppo).
from trlx.data.configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig

default_config = TRLConfig(
    train=TrainConfig(
        seq_length=512,
        epochs=10000,
        total_steps=10000,
        batch_size=8,
        checkpoint_interval=10000,
        eval_interval=500,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
        checkpoint_dir="checkpoints/ppo_hh",
    ),
    model=ModelConfig(model_path="tiiuae/falcon-7b-instruct", num_layers_unfrozen=2),
    tokenizer=TokenizerConfig(tokenizer_path="tiiuae/falcon-7b-instruct", truncation_side="left"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1e-6)),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=64,
        chunk_size=4,
        ppo_epochs=4,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=1,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1,
        scale_reward="running",
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs=dict(
            max_new_tokens=128,
            top_k=0,
            top_p=1.0,
            do_sample=True,
        ),
    ),
)
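
For context, I'm passing this config to trlx.train the same way the stock hh example does; roughly like the sketch below, where reward_fn, prompts, and eval_prompts are placeholders for my own setup rather than part of the config above:

trainer = trlx.train(
    reward_fn=reward_fn,          # placeholder: reward scoring function
    prompts=prompts,              # placeholder: training prompts
    eval_prompts=eval_prompts,    # placeholder: evaluation prompts
    config=default_config,
)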

Simply changing max_new_tokens from 128 to 256 leads to an error:

 File "/home/ec2-user/trlx/examples/hh/ppo_hh.py", line 263, in <module>
    main(hparams)
  File "/home/ec2-user/trlx/examples/hh/ppo_hh.py", line 252, in main
    trlx.train(
  File "/home/ec2-user/trlx/trlx/trlx.py", line 129, in train
    trainer.learn()
  File "/home/ec2-user/trlx/trlx/trainer/accelerate_base_trainer.py", line 521, in learn
    self.prepare_learning()
  File "/home/ec2-user/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 234, in prepare_learning
    self.make_experience(self.config.method.num_rollouts)
  File "/home/ec2-user/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 283, in make_experience
    samples = self.generate(batch["input_ids"], batch["attention_mask"])
  File "/home/ec2-user/trlx/trlx/trainer/accelerate_base_trainer.py", line 263, in generate
    return self.accelerator.unwrap_model(self.model).generate(
Traceback (most recent call last):
  File "/home/ec2-user/trlx/trlx/models/modeling_ppo.py", line 353, in generate
    return self.base_model.generate(*args, **kwargs)
  File "/home/ec2-user/trlx/examples/hh/ppo_hh.py", line 263, in <module>
  File "/opt/conda/envs/vllm/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/vllm/lib/python3.9/site-packages/transformers/generation/utils.py", line 1476, in generate
    and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
    main(hparams)
  File "/home/ec2-user/trlx/examples/hh/ppo_hh.py", line 252, in main
IndexError: index -1 is out of bounds for dimension 1 with size 0
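
For what it's worth, the IndexError itself comes from indexing the last column of the prompt tensor inside transformers' generate(), which only fails if the tokenized prompts end up with zero length. A minimal sketch that reproduces the same exception (not taken from my run; the tensor shape is an assumption):

import torch

# Hypothetical batch of 8 prompts that were truncated down to zero tokens.
inputs_tensor = torch.empty((8, 0), dtype=torch.long)
pad_token_id = 0

try:
    # Mirrors the check in transformers/generation/utils.py shown in the traceback.
    torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0
except IndexError as e:
    print(e)  # index -1 is out of bounds for dimension 1 with size 0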

Any help is appreciated!

Which trlX version are you using?

0.7.0

Additional system and package information

Python=3.9

maxreciprocate commented 1 year ago

Hello there @wise-east! You seem to be working with modified code to support the Falcon architecture; have you perhaps resolved #532 on your own?

I tested your config with the only change being to replace the model tiiuae/falcon-7b-instruct with reciprocate/tiny-llama, keeping max_new_tokens=256, and I couldn't reproduce this error, so it may be Falcon-specific. Also, which version of transformers are you using?
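
For reference, the swap amounts to something like the following in your config (pointing the tokenizer at the same model is my assumption, just so the snippet is self-consistent):

    model=ModelConfig(model_path="reciprocate/tiny-llama", num_layers_unfrozen=2),
    tokenizer=TokenizerConfig(tokenizer_path="reciprocate/tiny-llama", truncation_side="left"),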

wise-east commented 1 year ago

Yes, I adapted modeling_ppo.py by adding a FalconModelBranch based on what was done for other branches.


from transformers.models.falcon import modeling_falcon

class FalconModelBranch(ModelBranch):

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = modeling_falcon._make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = modeling_falcon._expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_shape: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,        
    ): 
        """
        Reference: https://huggingface.co/tiiuae/falcon-7b-instruct/blob/main/modelling_RW.py
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = hidden_states.shape[:2]

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.decoder_blocks))

        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = modeling_falcon.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.final_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

where modeling_falcon comes from https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py.
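
For completeness, here's roughly how I point trlx's branch dispatch at the new class; this is a sketch, and the function and architecture names are my paraphrase rather than the exact code in modeling_ppo.py:

def hf_get_branch_class(config):
    # Sketch: map the model's architecture string to the frozen-branch class.
    # FalconModelBranch is the class defined above.
    falcon_supported_archs = ["FalconForCausalLM", "RWForCausalLM"]
    if config.architectures and config.architectures[0] in falcon_supported_archs:
        return FalconModelBranch
    raise ValueError(f"Unsupported architecture: {config.architectures}")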

The output of pip show transformers:

Name: transformers
Version: 4.32.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /opt/conda/envs/vllm/lib/python3.9/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, trl, trlx

maxreciprocate commented 1 year ago

Oh, that's good! However, I'm seeing a somewhat different version of _expand_mask (https://github.com/huggingface/transformers/blob/aea761499f4b1193f2706f471442da6f9df65d65/src/transformers/models/falcon/modeling_falcon.py#L197), which doesn't take tgt_length as an argument on 4.32.1, and also on master. If it's not too much work, could you open a PR with your changes? 🙏
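
For comparison, the helper at that link looks roughly like this (paraphrased; the linked file is authoritative):

def _expand_mask(mask, past_key_values_length):
    # Expands attention_mask from [batch_size, seq_length] to
    # [batch_size, 1, seq_length, seq_length + past_key_values_length];
    # note there is no tgt_length parameter, so the call in FalconModelBranch
    # above would need adjusting for 4.32.1.
    ...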