philschmid / deep-learning-pytorch-huggingface

MIT License

Fix Flash Attention forward for Llama-2 70b #30

Closed: davidmrau closed this 11 months ago

davidmrau commented 11 months ago

This pull request fixes #28 when training Llama-2 70b with LoRA and flash attention.
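
For context on why 70b needs a separate fix: Llama-2 70b uses grouped-query attention (8 key/value heads shared across 64 query heads), while the 7b and 13b models have one key/value head per query head, which the original flash-attention patch appears to assume. Below is a minimal sketch of the usual GQA handling, modeled on the repeat_kv helper in Transformers; the names and shapes are illustrative and not the exact code in utils/llama_patch.py.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# inside the patched attention forward, before handing q/k/v to the flash kernel:
# n_rep = num_attention_heads // num_key_value_heads  # 64 // 8 = 8 for Llama-2 70b
# key_states = repeat_kv(key_states, n_rep)
# value_states = repeat_kv(value_states, n_rep)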

davidmrau commented 11 months ago

@philschmid can you have a look?

philschmid commented 11 months ago

~Can you please explain why this should fix flash attention for GQA?~ Oh man, GitHub only showed me the last commit... I'll take a look.

davidmrau commented 11 months ago

> LGTM. Can you share a script to quickly reproduce it?

here you go: d91f6ddf2a0a05e92ebc865d8fe1e42b9eac7849

davidmrau commented 11 months ago

actually, we don't want to have this in the repo so here:

from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import TrainingArguments
from trl import SFTTrainer
from random import randrange
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Load dataset from the hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])
# dataset size: 15011

def format_instruction(sample):
    return f"""### Instruction:
Use the Input below to create an instruction, which could have been used to generate the input using an LLM.

### Input:
{sample['response']}

### Response:
{sample['instruction']}
"""

use_flash_attention = True
# replace LlamaAttention with flash attention
# (requires an Ampere or newer GPU, i.e. compute capability >= 8)
if use_flash_attention and torch.cuda.get_device_capability()[0] >= 8:
    from utils.llama_patch import replace_attn_with_flash_attn
    print("Using flash attention")
    replace_attn_with_flash_attn()

# Hugging Face model id
#model_id = "NousResearch/Llama-2-70b-hf" # non-gated
model_id = "meta-llama/Llama-2-70b-hf" # gated

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=False, device_map="auto")
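# pretraining_tp=1 uses the standard linear forward; values > 1 emulate the tensor-parallel pretraining setup with slower, sliced matmuls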
model.config.pretraining_tp = 1

# Validate that the model is using flash attention, by comparing doc strings
if use_flash_attention:
    from utils.llama_patch import forward
    assert model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, "Model is not using flash attention"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# LoRA config based on QLoRA paper
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
)

args = TrainingArguments(
    output_dir="llama-70-int4-dolly",
    num_train_epochs=3,
    per_device_train_batch_size=6 if use_flash_attention else 4,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    tf32=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
#    disable_tqdm=True # disable tqdm since with packing the logged values are incorrect
)

# prepare model for training
model = prepare_model_for_kbit_training(model)

# Upcast layers for flash attention
if use_flash_attention:
    from utils.llama_patch import upcast_layer_for_flash_attention
    torch_dtype = torch.bfloat16 if args.bf16 else torch.float16 if args.fp16 else torch.float32
    model = upcast_layer_for_flash_attention(model, torch_dtype)

model = get_peft_model(model, peft_config)
max_seq_length = 2048 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction,
    args=args,
)

# train
trainer.train()
philschmid commented 10 months ago

Needed to revert this since training for 7B became 10x worse and inference wasn't working anymore.

davidmrau commented 10 months ago

For me, inference with the 7B model works just the same.

davidmrau commented 10 months ago

I will look into it next week when I have more time

philschmid commented 10 months ago

Yeah, inference worked; that was due to flash attention being loaded. But running the example with the suggested patch from this branch results in poor performance. I will also try to see where it is coming from if I have time.