Closed davidmrau closed 11 months ago
@philschmid can you have a look?
~Can you please explain why this should fix flash attention for GQA?~ Oh man, GitHub only showed the last commit... I'll take a look.
LGTM. Can you share a script to quickly reproduce it?
here you go: d91f6ddf2a0a05e92ebc865d8fe1e42b9eac7849
LGTM. Can you share a script to quickly reproduce it?
here you go: d91f6dd
actually, we don't want to have this in the repo so here:
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import TrainingArguments
from trl import SFTTrainer
from random import randrange
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Load dataset from the hub
# (Dolly 15k: human-written instruction/response pairs from Databricks.)
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
print(f"dataset size: {len(dataset)}")
# Print one random sample to eyeball the fields ('instruction', 'response', ...).
print(dataset[randrange(len(dataset))])
# dataset size: 15011
def format_instruction(sample):
    """Render one Dolly sample as an *inverse* instruction prompt.

    The task is flipped on purpose: the model sees the original response as
    input and must reconstruct an instruction that could have produced it.

    Args:
        sample: mapping with at least the keys ``'response'`` and
            ``'instruction'``.

    Returns:
        The formatted prompt string (ends with a trailing newline).
    """
    response_text = sample["response"]
    instruction_text = sample["instruction"]
    return (
        "### Instruction:\n"
        "Use the Input below to create an instruction, which could have been used to generate the input using an LLM.\n"
        "### Input:\n"
        f"{response_text}\n"
        "### Response:\n"
        f"{instruction_text}\n"
    )
# Flash attention requires compute capability >= 8.0 (Ampere or newer).
# BUG FIX: the flag was hard-coded to True, but the monkey patch below is only
# applied when the hardware check passes — on older GPUs the assert further
# down would then fail. Derive the flag from the same hardware check instead.
use_flash_attention = torch.cuda.get_device_capability()[0] >= 8
if use_flash_attention:
    # replace attention with flash attention (project-local llama patch)
    from utils.llama_patch import replace_attn_with_flash_attn
    print("Using flash attention")
    replace_attn_with_flash_attn()

# Hugging Face model id
# model_id = "NousResearch/Llama-2-70b-hf"  # non-gated
model_id = "meta-llama/Llama-2-70b-hf"  # gated

# BitsAndBytesConfig int-4 config (QLoRA: NF4 quantization, double quant,
# bfloat16 compute dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load model and tokenizer.
# use_cache=False because KV caching is incompatible with gradient checkpointing.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    device_map="auto",
)
model.config.pretraining_tp = 1

# Validate that the model is using flash attention, by comparing doc strings
# of the patched forward against the patch module's reference implementation.
if use_flash_attention:
    from utils.llama_patch import forward
    assert model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, "Model is not using flash attention"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Right padding for training (left padding is only needed for generation).
tokenizer.padding_side = "right"
# LoRA config based on QLoRA paper
# (rank 64 with alpha 16, applied to the causal-LM task; no bias terms trained)
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
)
# Training hyperparameters; a larger per-device batch fits when flash
# attention reduces the attention memory footprint.
args = TrainingArguments(
output_dir="llama-70-int4-dolly",
num_train_epochs=3,
per_device_train_batch_size=6 if use_flash_attention else 4,
gradient_accumulation_steps=2,
gradient_checkpointing=True,
# Paged optimizer avoids OOM spikes during optimizer steps with quantized models.
optim="paged_adamw_32bit",
logging_steps=10,
save_strategy="epoch",
learning_rate=2e-4,
bf16=True,
tf32=True,
max_grad_norm=0.3,
warmup_ratio=0.03,
lr_scheduler_type="constant",
# disable_tqdm=True  # disable tqdm since with packing the reported values are incorrect
)
# Prepare the quantized model for k-bit (QLoRA) training.
model = prepare_model_for_kbit_training(model)

# Upcast selected layers so flash attention receives a supported dtype.
if use_flash_attention:
    from utils.llama_patch import upcast_layer_for_flash_attention

    if args.bf16:
        compute_dtype = torch.bfloat16
    elif args.fp16:
        compute_dtype = torch.float16
    else:
        compute_dtype = torch.float32
    model = upcast_layer_for_flash_attention(model, compute_dtype)

# Wrap the base model with the LoRA adapters.
model = get_peft_model(model, peft_config)

max_seq_length = 2048  # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction,
    args=args,
)

# Kick off fine-tuning.
trainer.train()
Needed to revert this since training for the 7B model became 10x worse and inference wasn't working anymore
For me inference with the 7B model works just the same
I will look into it next week when I have more time
Yeah, inference worked — it was due to flash attention being loaded. But running the example with the suggested patch from this branch results in poor performance. I will also try to figure out where it is coming from if I have time.
This pull request fixes #28 when training LLama-2 70b with LoRA and flash attention.