Closed SpursLipu closed 10 months ago
Tagging @kashif and @younesbelkada.
i have to check the dpo_qwen.py
script but seems like some type mis-match between fp16/bf16 ... perhaps you can check if the model is cast to float16/bf16 properly?
Hey @SpursLipu That model uses a custom FA-2 implementation: https://huggingface.co/Qwen/Qwen-14B-Chat/blob/main/modeling_qwen.py#L83 I suggest to open an issue on the Hub repo directly
i have to check the
dpo_qwen.py
script but seems like some type mis-match between fp16/bf16 ... perhaps you can check if the model is cast to float16/bf16 properly?
my dpo_qwen.py please check
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from fastchat.conversation import get_conv_template
from trl import DPOTrainer
@dataclass
class ScriptArguments:
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
model_name_or_path: Optional[str] = field(
default="gpt2",
metadata={"help": "the model name"}
)
dataset: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset path"})
trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "trust_remote_code"})
learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
report_to: Optional[str] = field(
default=None,
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
def preprocess(dataset: str, split: str, silent: bool = False, cache_dir: str = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset(dataset, split=split, cache_dir=cache_dir)
conv = get_conv_template("qwen-7b-chat")
def split_prompt_and_responses(sample) -> Dict[str, str]:
chosen = sample["chosen"].split("\n\nAssistant: ")[-1]
rejected = sample["rejected"].split("\n\nAssistant: ")[-1]
prompt = sample["chosen"][len("\n\nHuman: "): sample["chosen"].rfind("\n\nAssistant: ")]
prompt = prompt.replace("\n\nAssistant: ", conv.sep + conv.roles[1] + '\n')
prompt = prompt.replace("\n\nHuman: ", conv.sep + conv.roles[0] + '\n')
prompt = conv.roles[0] + '\n' + prompt + conv.sep + conv.roles[1] + '\n'
return {
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
return dataset.map(split_prompt_and_responses)
if __name__ == "__main__":
global local_rank
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map={"": Accelerator().local_process_index},
trust_remote_code=script_args.trust_remote_code,
load_in_4bit=True)
if script_args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map={"": Accelerator().local_process_index},
trust_remote_code=script_args.trust_remote_code,
load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(
script_args.model_name_or_path,
trust_remote_code=script_args.trust_remote_code,
pad_token='<|endoftext|>',
eos_token='<|im_end|>',
bos_token='<|im_start|>')
train_dataset = preprocess(script_args.dataset, "train")
eval_dataset = preprocess(script_args.dataset, "test")
training_args = TrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
remove_unused_columns=False,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
logging_first_step=True,
logging_steps=10, # match results in blog post
eval_steps=500,
output_dir="./test",
optim="adamw_torch",
warmup_steps=150,
report_to=script_args.report_to,
bf16=True,
gradient_checkpointing=script_args.gradient_checkpointing,
)
local_rank = training_args.local_rank
peft_config = LoraConfig(
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=[
"q_proj",
"v_proj",
"k_proj",
"out_proj",
"fc_in",
"fc_out",
"wte",
],
bias="none",
task_type="CAUSAL_LM",
)
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
max_length=script_args.max_length,
max_prompt_length=script_args.max_prompt_length,
generate_during_eval=False,
)
dpo_trainer.train()
Can you pass BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
in from_pretrained:
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map={"": Accelerator().local_process_index},
trust_remote_code=script_args.trust_remote_code,
quantization_config=quantization_config)
But I am really not sure this will solve your bug, I just suspect that there might be some weird interaction between the compute dtype and FA-2 on their repository
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
I want to use dpo ft qwen-chat-14b, but I meet the error. The input(q, k, y) type of flash-attention in qwen has to be set as float16 or bfloat16, but in dpo_trainer the type is float32. If I turn off the flash-attention this error will not occur. But training become very slow. How to solve this problem?
Traceback (most recent call last): File "/mnt/afs/smartbrain/FastChat/fastchat/rlhf/dpo_qwen.py", line 215, in
dpo_trainer.train()
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 594, in compute_loss
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 545, in get_batch_metrics
) = self.concatenated_forward(model, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 511, in concatenated_forward
all_logits = model(
^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 632, in forward
return model_forward(*args, *kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 620, in call
return convert_to_fp32(self.model_forward(args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/peft_model.py", line 918, in forward
return self.base_model(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, *kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 94, in forward
return self.model.forward(args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 1108, in forward
transformer_outputs = self.transformer(
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, *kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 938, in forward
outputs = block(
^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, *kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 639, in forward
attn_outputs = self.attn(
^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, *kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 546, in forward
context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 174, in forward
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
AssertionError
0%| | 0/127611 [00:01<?, ?it/s]