PhiForCausalLM does not support Flash Attention 2.0 #28381

gmittal commented 9 months ago
import torch
from transformers import AutoModelForCausalLM, AutoModel

model = AutoModelForCausalLM.from_pretrained(


ValueError: PhiForCausalLM does not support Flash Attention 2.0 yet. Please open an issue on GitHub to request support for this architecture:

rootonchair commented 9 months ago

Hi, I would like to work on this issue

NielsRogge commented 9 months ago

Support for Phi-2 is still WIP, you can follow the progress here:

susnato commented 9 months ago

Hi @gmittal, Flash Attention is already implemented for Phi, PR

It seems that you are using the hub version of phi-2. Please use it from the library to properly enable Flash Attention. For now, microsoft/phi-2 does not have the correct order of the weights to be used with the library model, so please use susnato/phi-2 instead.

First update to the latest transformers version -

pip install -U transformers

then run -

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("susnato/phi-2", 

tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")

inputs = tokenizer('''def print_prime(n):
   """
   Print all primes between 1 and n
   """''', return_tensors="pt", return_attention_mask=False)

outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]

Let me know if this works or not.
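
For readers following along, here is a minimal sketch of how the load step could be written so that Flash Attention 2 is only requested when a GPU is actually available. The helper names `build_load_kwargs` and `load_phi` are hypothetical, and the sketch assumes a transformers release that accepts `attn_implementation` (older releases used `use_flash_attention_2=True`) as well as an installed `flash-attn` package. It is not necessarily the exact call from the comment above.

```python
def build_load_kwargs(want_flash_attention: bool, cuda_available: bool) -> dict:
    """Build keyword arguments for from_pretrained.

    Flash Attention 2 is only requested when a CUDA device is present,
    because the kernels cannot run on CPU.
    """
    kwargs = {}
    if want_flash_attention and cuda_available:
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


def load_phi(model_id: str = "susnato/phi-2", want_flash_attention: bool = True):
    """Load the model and move it to the GPU when one is available."""
    # Imported lazily so build_load_kwargs can be used without torch installed.
    import torch
    from transformers import AutoModelForCausalLM

    cuda_available = torch.cuda.is_available()
    kwargs = build_load_kwargs(want_flash_attention, cuda_available)
    if cuda_available:
        # Flash Attention 2 only supports half precision (fp16 / bf16).
        kwargs["torch_dtype"] = torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    return model.to("cuda") if cuda_available else model
```

Calling `load_phi()` on a machine without a GPU falls back to the default attention implementation instead of raising.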

nakranivaibhav commented 9 months ago

I would like to work on this issue

NicolasMejiaPetit commented 9 months ago

Using the HF alignment notebook, the DPO script gives me this error regardless of the transformers version (I already force-updated with pip). When I remove flash attention from the YAML, it works (after a bit of code adjustment). The strange part is that I am able to fine-tune with one of my SFT scripts using flash attention.

gugarosa commented 9 months ago

Hello everyone!

This should be fixed in transformers If not using that version, please make sure that trust_remote_code=True when loading the model and it should work out-of-the-box with flash-attention 2.

NielsRogge commented 9 months ago

Thanks! Closing as this was fixed in

NicolasMejiaPetit commented 9 months ago

I installed from source, so now I am on transformers 4.37.dev0, and I am still getting the incompatibility error, even with trust_remote_code set to True.

C:\Users\PC\Documents\Code-Trainer\FineTune>py --model_name_or_path C:\Users\PC\Documents\NEWGEN\text-generation-webui-main\models\dolphin-2_6-phi-2 --data_path MiniCoderW.json --output_dir C:\Users\PC\Documents\NEWGEN\text-generation-webui-main\models\TrainedPhi --num_train_epochs 3 --model_max_length 1024 --per_device_train_batch_size 1 --evaluation_strategy "no" --save_strategy "steps" --save_steps 1000 --save_total_limit 10 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 10 --lr_scheduler_type "cosine" --report_to "tensorboard" --bf16 False --dataloader_num_workers 12 --optim paged_adamw_8bit

ValueError: PhiForCausalLM does not support Flash Attention 2.0 yet. Please request to add support where the model is hosted, on its model hub page:\Users\PC\Documents\NEWGEN\text-generation-webui-main\models\dolphin-2_6-phi-2/discussions/new or in the Transformers GitHub repo:

Here is the script I am using:

`import copy
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from transformers import Trainer
from datasets import load_dataset


def build_instruction_prompt(instruction: str):
    return ''' You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.





@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="deepseek-ai/deepseek-coder-6.7b-instruct")

@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]

    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [ for tokenized in tokenized_list

    return dict(

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]

    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [torch.tensor(x) for x in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = [torch.tensor(x) for x in labels]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        return dict(

def train_tokenize_function(examples, tokenizer):
    sources = [
        build_instruction_prompt(instruction)
        for instruction in examples['instruction']
    ]
    targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict

def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.local_rank == 0:

    tokenizer = transformers.AutoTokenizer.from_pretrained(
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
    print("BOS Token", tokenizer.bos_token, tokenizer.bos_token_id)
    print("EOS Token", tokenizer.eos_token, tokenizer.eos_token_id)

    if training_args.local_rank == 0:
        print("Load tokenizer from {} over.".format(model_args.model_name_or_path))

    model = transformers.AutoModelForCausalLM.from_pretrained(

    if training_args.local_rank == 0:
        print("Load model from {} over.".format(model_args.model_name_or_path))

    raw_train_datasets = load_dataset(

    train_dataset =
        load_from_cache_file=True, # not args.overwrite_cache
        desc="Running Encoding",
        fn_kwargs={ "tokenizer": tokenizer }

    if training_args.local_rank == 0:
        print("Training dataset samples:", len(train_dataset))
        for index in random.sample(range(len(train_dataset)), 3):
            print(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
            print(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

if __name__ == "__main__":
    train()
`

NielsRogge commented 9 months ago

Hi @NickWithBotronics, if you set trust_remote_code=True, then the code from the hub is used (in the case of microsoft/phi-2, that's defined here), rather than the implementation defined natively in the Transformers library.

Hence it's recommended to convert the weights from the microsoft/phi-2 repo to a native one, which will work with Flash Attention 2. One can leverage the conversion script for that.

@ArthurZucker should we host the converted phi-2 weights as part of the Microsoft organization? Currently one will get a lot of mismatched keys when doing the following:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(

due to the model in Transformers using a single matrix for queries, keys and values, whereas the code on the hub uses separate matrices.
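
To illustrate why the two layouts produce mismatched keys, here is a small sketch in plain Python. It is not the conversion script. The key names are hypothetical, and it assumes the fused matrix simply stacks Q, K and V row-wise in that order (a real checkpoint may order or interleave them differently, which is why the conversion script should be used).

```python
from typing import Iterable, List, Tuple

Matrix = List[List[float]]


def split_fused_qkv(fused: Matrix) -> Tuple[Matrix, Matrix, Matrix]:
    """Split a fused (3*h, d) projection into three (h, d) blocks.

    Assumption: Q, K and V are stacked row-wise in that order.
    """
    if len(fused) % 3 != 0:
        raise ValueError("fused matrix must have a row count divisible by 3")
    h = len(fused) // 3
    return fused[:h], fused[h:2 * h], fused[2 * h:]


def diff_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) parameter names, as a loader would report them."""
    checkpoint, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - checkpoint)      # expected by the model, absent from the checkpoint
    unexpected = sorted(checkpoint - model)   # present in the checkpoint, unknown to the model
    return missing, unexpected


# Hypothetical key names, for illustration only.
fused_layout = ["layers.0.attn.qkv.weight"]
split_layout = [
    "layers.0.attn.q.weight",
    "layers.0.attn.k.weight",
    "layers.0.attn.v.weight",
]
missing, unexpected = diff_keys(fused_layout, split_layout)
```

Loading a checkpoint saved in one layout into a model that expects the other leaves every attention projection either missing or unexpected, which is what the mismatched-key warnings reflect.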

NicolasMejiaPetit commented 9 months ago

Thank you <3 !!!! That fixed that error (using the new, converted HF format). Now onto a new error that I think is due to my script. :(
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with'cuda').
AttributeError: 'PhiConfig' object has no attribute 'attention_dropout'

C:\Users\PC\Documents\Code-Trainer\FineTune> "

Edit: fixed it by downloading the latest: Generation_config.json, Config.json,, and
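
A quick way to catch this kind of stale-config problem before training is to check the local config file for the fields the library code reads. The sketch below is an illustration only. The helper name `missing_config_fields` is hypothetical, and `attention_dropout` is the only field taken from the error above. Any other required names would be assumptions.

```python
import json
from typing import Iterable, List


def missing_config_fields(config_json: str, required: Iterable[str] = ("attention_dropout",)) -> List[str]:
    """Return the required field names that are absent from a config.json payload."""
    config = json.loads(config_json)
    return [name for name in required if name not in config]


# Example with an inline payload; in practice, read config.json from the local model directory.
old_config = '{"hidden_size": 2560}'
print(missing_config_fields(old_config))
```

A non-empty result suggests the local config predates the library version in use and should be refreshed from the model repository.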

NicolasMejiaPetit commented 9 months ago

While I got it working, the training loss was way off. It started at 6 and went down to 2 (after 3 epochs), but when I used the old config without flash attention it went from 0.6 to ~0.29 (also 3 epochs). Same dataset, same setup, same model; just different config files and flash attention. I saw someone else report the same thing on Twitter.

ArthurZucker commented 9 months ago

Can you open a separate issue for this, with a reproducible snippet?

NicolasMejiaPetit commented 9 months ago

Gotcha, I’ll move to this ticket #28488