huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Flash attention error #973

Closed: SpursLipu closed this issue 10 months ago

SpursLipu commented 11 months ago

I want to use DPO to fine-tune Qwen-Chat-14B, but I hit the error below. Qwen's flash-attention implementation requires the inputs (q, k, v) to be float16 or bfloat16, but in dpo_trainer they arrive as float32. If I turn flash-attention off the error does not occur, but training becomes very slow. How can I solve this problem?

Traceback (most recent call last):
  File "/mnt/afs/smartbrain/FastChat/fastchat/rlhf/dpo_qwen.py", line 215, in <module>
    dpo_trainer.train()
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 594, in compute_loss
    loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 545, in get_batch_metrics
    ) = self.concatenated_forward(model, batch)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 511, in concatenated_forward
    all_logits = model(
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 632, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 620, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/peft_model.py", line 918, in forward
    return self.base_model(
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 94, in forward
    return self.model.forward(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 1108, in forward
    transformer_outputs = self.transformer(
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 938, in forward
    outputs = block(
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 639, in forward
    attn_outputs = self.attn(
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 546, in forward
    context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
  File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 174, in forward
    assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
AssertionError
  0%|          | 0/127611 [00:01<?, ?it/s]

lvwerra commented 11 months ago

Tagging @kashif and @younesbelkada.

kashif commented 11 months ago

I have to check the dpo_qwen.py script, but it seems like there is some dtype mismatch between fp16/bf16 ... perhaps you can check whether the model is cast to float16/bf16 properly?
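
For example, something like this (a hypothetical diagnostic snippet; summarize_dtypes is not part of trl, and model is assumed to be the already-loaded Qwen model):

from collections import Counter

def summarize_dtypes(model):
    # Tally parameter dtypes: any float32 entries here are candidates
    # for tripping the flash-attention dtype assertion.
    counts = Counter(p.dtype for p in model.parameters())
    for dtype, n in counts.items():
        print(f"{dtype}: {n} parameter tensors")
    return counts

# usage, with model as loaded in dpo_qwen.py:
# summarize_dtypes(model)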

younesbelkada commented 11 months ago

Hey @SpursLipu, that model uses a custom FA-2 implementation: https://huggingface.co/Qwen/Qwen-14B-Chat/blob/main/modeling_qwen.py#L83. I suggest opening an issue on the Hub repo directly.
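
For context: with trust_remote_code=True the attention code is loaded from the model repository itself rather than from transformers, which is why the fix belongs there. A small way to confirm which code is actually in use (requires network access; the printed module path is indicative):

from transformers import AutoConfig

# With trust_remote_code=True the config and model classes come from the
# files shipped in the Hub repo (configuration_qwen.py, modeling_qwen.py),
# not from the transformers package itself.
config = AutoConfig.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
print(type(config).__module__)  # lives under transformers_modules.*, not transformers.*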

SpursLipu commented 11 months ago

> I have to check the dpo_qwen.py script, but it seems like there is some dtype mismatch between fp16/bf16 ... perhaps you can check whether the model is cast to float16/bf16 properly?

Here is my dpo_qwen.py, please check:

import os
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from fastchat.conversation import get_conv_template
from trl import DPOTrainer

@dataclass
class ScriptArguments:
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    model_name_or_path: Optional[str] = field(
        default="gpt2",
        metadata={"help": "the model name"}
    )
    dataset: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset path"})
    trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "trust_remote_code"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})

    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )

def preprocess(dataset: str, split: str, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset(dataset, split=split, cache_dir=cache_dir)
    conv = get_conv_template("qwen-7b-chat")
    def split_prompt_and_responses(sample) -> Dict[str, str]:
        chosen = sample["chosen"].split("\n\nAssistant: ")[-1]
        rejected = sample["rejected"].split("\n\nAssistant: ")[-1]

        prompt = sample["chosen"][len("\n\nHuman: "): sample["chosen"].rfind("\n\nAssistant: ")]
        prompt = prompt.replace("\n\nAssistant: ", conv.sep + conv.roles[1] + '\n')
        prompt = prompt.replace("\n\nHuman: ", conv.sep + conv.roles[0] + '\n')
        prompt = conv.roles[0] + '\n' + prompt + conv.sep + conv.roles[1] + '\n'
        return {
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected,
        }

    return dataset.map(split_prompt_and_responses)

if __name__ == "__main__":
    global local_rank
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        load_in_4bit=True)

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        load_in_4bit=True)

    tokenizer = AutoTokenizer.from_pretrained(
        script_args.model_name_or_path,
        trust_remote_code=script_args.trust_remote_code,
        pad_token='<|endoftext|>',
        eos_token='<|im_end|>',
        bos_token='<|im_start|>')

    train_dataset = preprocess(script_args.dataset, "train")

    eval_dataset = preprocess(script_args.dataset, "test")

    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir="./test",
        optim="adamw_torch",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=True,
        gradient_checkpointing=script_args.gradient_checkpointing,
    )
    local_rank = training_args.local_rank
    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_length=script_args.max_length,
        max_prompt_length=script_args.max_prompt_length,
        generate_during_eval=False,
    )
    dpo_trainer.train()

younesbelkada commented 11 months ago

Can you pass BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) to from_pretrained:

    from transformers import BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        quantization_config=quantization_config)

But I am really not sure this will solve your bug; I just suspect there is some odd interaction between the compute dtype and FA-2 in their repository.
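
Since the TrainingArguments in your script set bf16=True while the model is loaded in torch.float16, a bf16-consistent variant might also be worth trying (a sketch under the assumption that keeping a single half dtype end to end helps; not a confirmed fix):

    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # match bf16=True in TrainingArguments
    )
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,  # script_args as defined in dpo_qwen.py
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        quantization_config=quantization_config,
    )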

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.