microsoft / Phi-3CookBook

This is a cookbook for getting started with Phi-3, a family of open AI models developed by Microsoft. Phi-3 models are the most capable and cost-effective small language models (SLMs) available, outperforming models of the same size and the next size up across a variety of language, reasoning, coding, and math benchmarks.
MIT License

Flash Attention supports only fp16 and bf16 data type for Phi-3-small-128K fine-tuning using QLoRA #127

Closed. ArpitSharma7 closed this issue 1 month ago.

ArpitSharma7 commented 1 month ago

Please provide us with the following information:

This issue is for a: (mark with an x)

- [x] bug report -> please search issues before submitting
- [ ] feature request
- [ ] documentation issue or request
- [ ] regression (a behavior that used to work and stopped in a new release)

Minimal steps to reproduce

import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from transformers import Trainer
from peft import LoraConfig, PeftModel
from datasets import Dataset
import pandas as pd
from peft import prepare_model_for_kbit_training

import datasets
from datasets import Dataset
from datasets import load_dataset, concatenate_datasets
import numpy as np
import ast
import sys

import warnings
warnings.filterwarnings("ignore")

model_name = "microsoft/Phi-3-small-128k-instruct"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# BitsAndBytesConfig int-4 config

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load model and tokenizer

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    cache_dir="/data",
    trust_remote_code=True,
)
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir="/data",
    trust_remote_code=True,
    use_fast=True,
)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'right'

df_raw = pd.read_csv("train.csv")
df = df_raw[['Question', 'Query', 'Schema']]

df_sorted = df.sort_values(by='Query', key=lambda x: x.str.len())

dataset = Dataset.from_pandas(df_sorted[['Question', 'Query', 'Schema']])

def prepare_dialogue_mistral(example):
    question = example["Question"]
    response = example["Query"]
    context = example['Schema']

    prompt_file = "prompt_phi3.md"

    with open(prompt_file, "r") as f:
        prompt = f.read()

    prompt = prompt.format(
        user_question=question, table_metadata_string=context, sql_query=response
    )
    example['text'] = prompt

    return example

dataset_formatted = dataset.map(prepare_dialogue_mistral, num_proc=4, remove_columns=[ 'Question', 'Query','Schema'])

test_df = pd.read_csv("val.csv")
test_df = test_df[['Question', 'Query', 'Schema']]

test_dataset = Dataset.from_pandas(test_df[['Question', 'Query', 'Schema']])
test_dataset_formatted = test_dataset.map(prepare_dialogue_mistral, num_proc=4, remove_columns=['Question', 'Query', 'Schema'])

from datasets import Dataset, DatasetDict

dataset_split = DatasetDict({"train": dataset_formatted, "test": test_dataset_formatted})

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

peft_config = LoraConfig(
    target_modules=["query_key_value", "dense", "down_proj", "up_proj", "lm_head"],
    lora_alpha=128,
    lora_dropout=0.05,
    r=128,
    bias="none",
    task_type="CAUSAL_LM",
)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="phi3-small_res",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=False,
    optim="adamw_torch",
    learning_rate=5e-06,
    bf16=True,
    max_grad_norm=0.3,
    logging_steps=560,
    evaluation_strategy="steps",
    save_strategy='epoch',
    warmup_ratio=0.01,
    lr_scheduler_type="cosine",
)

from trl import SFTTrainer

max_seq_length = 4096
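(Note: the `collator` passed to the trainer below is not defined in this snippet; the traceback further down shows it is a `DataCollatorForCompletionOnlyLM` built from `response_template_ids`. A minimal sketch to make the snippet self-contained; the response template string here is a placeholder, not the actual marker from the notebook:)

from trl import DataCollatorForCompletionOnlyLM

# Placeholder template; the real notebook defines its own response marker.
response_template = "### Answer:"
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)
collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)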

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_split['train'],
    eval_dataset=dataset_split['test'],
    data_collator=collator,
    dataset_text_field="text",
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=False,
    args=args,
)

trainer.train()

Any log messages given by the failure

RuntimeError Traceback (most recent call last) Cell In[2], line 159 140 collator = DataCollatorForCompletionOnlyLM( 141 response_template_ids, tokenizer=tokenizer 142 ) 144 trainer = SFTTrainer( 145 model=model, 146 train_dataset=dataset_split['train'], (...) 156 # callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)] 157 ) --> 159 trainer.train()

File /data/phi3-small-venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:451, in SFTTrainer.train(self, *args, *kwargs) 448 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: 449 self.model = self._trl_activate_neftune(self.model) --> 451 output = super().train(args, **kwargs) 453 # After training we make sure to retrieve back the original forward pass method 454 # for the embedding layer by removing the forward post hook. 455 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:

File /data/phi3-small-venv/lib/python3.10/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs) 1883 hf_hub_utils.enable_progress_bars() 1884 else: -> 1885 return inner_training_loop( 1886 args=args, 1887 resume_from_checkpoint=resume_from_checkpoint, 1888 trial=trial, 1889 ignore_keys_for_eval=ignore_keys_for_eval, 1890 )

File /data/phi3-small-venv/lib/python3.10/site-packages/transformers/trainer.py:2216, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval) 2213 self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 2215 with self.accelerator.accumulate(model): -> 2216 tr_loss_step = self.training_step(model, inputs) 2218 if ( 2219 args.logging_nan_inf_filter 2220 and not is_torch_xla_available() 2221 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) 2222 ): 2223 # if loss is nan or inf simply add the average of previous logged losses 2224 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /data/phi3-small-venv/lib/python3.10/site-packages/transformers/trainer.py:3238, in Trainer.training_step(self, model, inputs) 3235 return loss_mb.reduce_mean().detach().to(self.args.device) 3237 with self.compute_loss_context_manager(): -> 3238 loss = self.compute_loss(model, inputs) 3240 del inputs 3241 torch.cuda.empty_cache()

File /data/phi3-small-venv/lib/python3.10/site-packages/transformers/trainer.py:3264, in Trainer.compute_loss(self, model, inputs, return_outputs) 3262 else: 3263 labels = None -> 3264 outputs = model(**inputs) 3265 # Save past state if it exists 3266 # TODO: this needs to be fixed and made cleaner later. 3267 if self.args.past_index >= 0:

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, kwargs) 1530 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, *kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(args, **kwargs) 1543 try: 1544 result = None

File /data/phi3-small-venv/lib/python3.10/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32..forward(*args, kwargs) 818 def forward(*args, *kwargs): --> 819 return model_forward(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.call(self, *args, kwargs) 806 def call(self, *args, *kwargs): --> 807 return convert_to_fp32(self.model_forward(args, kwargs))

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator..decorate_autocast(*args, kwargs) 13 @functools.wraps(func) 14 def decorate_autocast(*args, *kwargs): 15 with autocast_instance: ---> 16 return func(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/peft/peft_model.py:1577, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, kwargs) 1575 with self._enable_peft_forward_hooks(kwargs): 1576 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} -> 1577 return self.base_model( 1578 input_ids=input_ids, 1579 attention_mask=attention_mask, 1580 inputs_embeds=inputs_embeds, 1581 labels=labels, 1582 output_attentions=output_attentions, 1583 output_hidden_states=output_hidden_states, 1584 return_dict=return_dict, 1585 **kwargs, 1586 ) 1588 batch_size = _get_batch_size(input_ids, inputs_embeds) 1589 if attention_mask is not None: 1590 # concat prompt attention mask

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, kwargs) 1530 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, *kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(args, **kwargs) 1543 try: 1544 result = None

File /data/phi3-small-venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:188, in BaseTuner.forward(self, *args, kwargs) 187 def forward(self, *args: Any, *kwargs: Any): --> 188 return self.model.forward(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module..new_forward(module, *args, kwargs) 167 output = module._old_forward(*args, *kwargs) 168 else: --> 169 output = module._old_forward(args, kwargs) 170 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py:956, in Phi3SmallForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict) 953 return_dict = return_dict if return_dict is not None else self.config.use_return_dict 955 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) --> 956 outputs = self.model( 957 input_ids=input_ids, 958 attention_mask=attention_mask, 959 position_ids=position_ids, 960 past_key_values=past_key_values, 961 inputs_embeds=inputs_embeds, 962 use_cache=use_cache, 963 output_attentions=output_attentions, 964 output_hidden_states=output_hidden_states, 965 return_dict=return_dict, 966 ) 968 hidden_states = outputs[0] 969 logits = self.lm_head(hidden_states)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, kwargs) 1530 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, *kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(args, **kwargs) 1543 try: 1544 result = None

File /data/phi3-small-venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module..new_forward(module, *args, kwargs) 167 output = module._old_forward(*args, *kwargs) 168 else: --> 169 output = module._old_forward(args, kwargs) 170 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py:849, in Phi3SmallModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) 846 all_hidden_states += (hidden_states,) 848 if self.gradient_checkpointing and self.training: --> 849 layer_outputs = self._gradient_checkpointing_func( 850 decoder_layer.call, 851 hidden_states, 852 attention_mask, 853 position_ids, 854 past_key_values, 855 output_attentions, 856 use_cache, 857 ) 858 else: 859 layer_outputs = decoder_layer( 860 hidden_states, 861 attention_mask=attention_mask, (...) 865 use_cache=use_cache, 866 )

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/_compile.py:24, in _disable_dynamo..inner(*args, kwargs) 20 @functools.wraps(fn) 21 def inner(*args, *kwargs): 22 import torch._dynamo ---> 24 return torch._dynamo.disable(fn, recursive)(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.call.._fn(*args, *kwargs) 449 prior = set_eval_frame(callback) 450 try: --> 451 return fn(args, **kwargs) 452 finally: 453 set_eval_frame(prior)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline..inner(*args, kwargs) 34 @functools.wraps(fn) 35 def inner(*args, *kwargs): ---> 36 return fn(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:487, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, kwargs) 482 if context_fn is not noop_context_fn or debug is not False: 483 raise ValueError( 484 "Passing context_fn or debug is only supported when " 485 "use_reentrant=False." 486 ) --> 487 return CheckpointFunction.apply(function, preserve, args) 488 else: 489 gen = _checkpoint_without_reentrant_generator( 490 function, preserve, context_fn, determinism_check, debug, args, kwargs 491 )

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, *kwargs) 595 if not torch._C._are_functorch_transforms_active(): 596 # See NOTE: [functorch vjp and autograd interaction] 597 args = _functorch.utils.unwrap_dead_wrappers(args) --> 598 return super().apply(args, **kwargs) # type: ignore[misc] 600 if not is_setup_ctx_defined: 601 raise RuntimeError( 602 "In order to use an autograd.Function with functorch transforms " 603 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context " 604 "staticmethod. For more details, please see " 605 "https://pytorch.org/docs/master/notes/extending.func.html" 606 )

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:262, in CheckpointFunction.forward(ctx, run_function, preserve_rng_state, args) 259 ctx.save_for_backward(tensor_inputs) 261 with torch.no_grad(): --> 262 outputs = run_function(*args) 263 return outputs

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, kwargs) 1530 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, *kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(args, **kwargs) 1543 try: 1544 result = None

File /data/phi3-small-venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module..new_forward(module, *args, kwargs) 167 output = module._old_forward(*args, *kwargs) 168 else: --> 169 output = module._old_forward(args, kwargs) 170 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py:671, in Phi3SmallDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache, **kwargs) 668 hidden_states = self.input_layernorm(hidden_states) 670 # Self Attention --> 671 hidden_states, self_attn_weights, present_key_values = self.self_attn( 672 hidden_states=hidden_states, 673 attention_mask=attention_mask, 674 position_ids=position_ids, 675 past_key_values=past_key_values, 676 output_attentions=output_attentions, 677 use_cache=use_cache, 678 ) 679 hidden_states = residual + hidden_states 681 # Fully Connected

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, kwargs) 1530 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(args, kwargs)

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, *kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(args, **kwargs) 1543 try: 1544 result = None

File /data/phi3-small-venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module..new_forward(module, *args, kwargs) 167 output = module._old_forward(*args, *kwargs) 168 else: --> 169 output = module._old_forward(args, kwargs) 170 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py:624, in Phi3SmallSelfAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache, **kwargs) 616 attn_function_output = self._apply_blocksparse_attention( 617 q=query_states, 618 k=expanded_key_states, (...) 621 return_attention_probs=output_attentions 622 ) 623 else: --> 624 attn_function_output = self._apply_dense_attention( 625 q=query_states, 626 k=expanded_key_states, 627 v=expanded_value_states, 628 attention_mask=attention_mask, 629 return_attention_probs=output_attentions 630 ) 632 attn_weights = None 633 if output_attentions:

File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py:441, in Phi3SmallSelfAttention._apply_dense_attention(self, q, k, v, attention_mask, return_attention_probs) 439 max_seqlen_q, max_seqlen_k = max_seq_lens 440 flat_kv = torch.cat((flat_k.unsqueeze(1), flat_v.unsqueeze(1)), dim=1) --> 441 attn_output_unpad = flash_attn_varlen_kvpacked_func( 442 q=flat_q, 443 kv=flat_kv, 444 cu_seqlens_q=cu_seqlens_q, 445 cu_seqlens_k=cu_seqlens_k, 446 max_seqlen_q=max_seqlen_q, 447 max_seqlen_k=max_seqlen_k, 448 dropout_p=attention_dropout_prob, 449 softmax_scale=self.softmax_scale, 450 causal=causal, 451 return_attn_probs=return_attention_probs 452 ) 453 attention_output = pad_input( 454 attn_output_unpad, indices_q, batch_size, query_length 455 ) 456 else:

File /data/phi3-small-venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:978, in flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs) 907 def flash_attn_varlen_kvpacked_func( 908 q, 909 kv, (...) 920 return_attn_probs=False, 921 ): 922 """dropout_p should be set to 0.0 during evaluation 923 If K, V are already stacked into 1 tensor, this function will be faster than 924 calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation (...) 976 pattern (negative means that location was dropped, nonnegative means it was kept). 977 """ --> 978 return FlashAttnVarlenKVPackedFunc.apply( 979 q, 980 kv, 981 cu_seqlens_q, 982 cu_seqlens_k, 983 max_seqlen_q, 984 max_seqlen_k, 985 dropout_p, 986 softmax_scale, 987 causal, 988 window_size, 989 alibi_slopes, 990 deterministic, 991 return_attn_probs, 992 )

File /data/phi3-small-venv/lib/python3.10/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, *kwargs) 595 if not torch._C._are_functorch_transforms_active(): 596 # See NOTE: [functorch vjp and autograd interaction] 597 args = _functorch.utils.unwrap_dead_wrappers(args) --> 598 return super().apply(args, **kwargs) # type: ignore[misc] 600 if not is_setup_ctx_defined: 601 raise RuntimeError( 602 "In order to use an autograd.Function with functorch transforms " 603 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context " 604 "staticmethod. For more details, please see " 605 "https://pytorch.org/docs/master/notes/extending.func.html" 606 )

File /data/phi3-small-venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:432, in FlashAttnVarlenKVPackedFunc.forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax) 430 if softmax_scale is None: 431 softmax_scale = q.shape[-1] ** (-0.5) --> 432 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( 433 q, 434 kv[:, 0], 435 kv[:, 1], 436 cu_seqlens_q, 437 cu_seqlens_k, 438 max_seqlen_q, 439 max_seqlen_k, 440 dropout_p, 441 softmax_scale, 442 causal=causal, 443 window_size=window_size, 444 alibi_slopes=alibi_slopes, 445 return_softmax=return_softmax and dropout_p > 0, 446 block_table=None, 447 ) 448 ctx.save_for_backward( 449 q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state 450 ) 451 ctx.dropout_p = dropout_p

File /data/phi3-small-venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:86, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, block_table) 84 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x 85 q, k, v = [maybe_contiguous(x) for x in (q, k, v)] ---> 86 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( 87 q, 88 k, 89 v, 90 None, 91 cu_seqlens_q, 92 cu_seqlens_k, 93 None, 94 block_table, 95 alibi_slopes, 96 max_seqlen_q, 97 max_seqlen_k, 98 dropout_p, 99 softmax_scale, 100 False, 101 causal, 102 window_size[0], 103 window_size[1], 104 return_softmax, 105 None, 106 ) 107 # if out.isnan().any() or softmax_lse.isnan().any(): 108 # breakpoint() 109 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type

Expected/desired behavior

Training currently fails to start with the RuntimeError above; it should start and run.

OS and Version?

Linux 6.5.0-45-generic

azd version?

None, not using azd

Versions

numpy==1.26.4
torch==2.3.0
ninja==1.11.1.1
transformers==4.41.1
bitsandbytes==0.41.3.post1
tiktoken==0.6.0
triton==2.3.0
flash-attn==2.5.8

Mention any other details that might be useful


Thanks! We'll be in touch soon.

leestott commented 1 month ago

@ArpitSharma7 can you confirm which sample from the cookbook you were using? If this is a general issue, please log it in the Hugging Face discussions.

ArpitSharma7 commented 1 month ago

@leestott It's from this notebook: Finetuning/Phi-3-finetune-qlora-python.ipynb. However, that notebook is based on Phi-3-mini; I just replaced the model with Phi-3-small-128k-instruct.

skytin1004 commented 1 month ago

Hi @ArpitSharma7,

The error might be due to compatibility issues with the torch version. This problem is discussed in this issue. Have you checked if the torch version you are using is 2.2 or below?
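(If it helps, a quick way to print the relevant versions and GPU capabilities, not part of the original notebook:)

import torch, transformers, flash_attn

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("flash-attn:", flash_attn.__version__)
print("CUDA:", torch.version.cuda, "| bf16 supported:", torch.cuda.is_bf16_supported())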

ArpitSharma7 commented 1 month ago

@skytin1004 I have torch 2.3.0 installed because the Hugging Face model page for Phi-3-small-128K lists triton 2.3.0 as a requirement, which is only compatible with PyTorch 2.3.0. In the issue you linked they were using a nightly build of 2.3.0, while I have the stable release, so I'm not sure whether that is the problem. Also, if you are able to run the QLoRA fine-tuning script with the Phi-3-small-128K model, please let me know the library versions you used so I can confirm whether this is an environment issue.

skytin1004 commented 1 month ago

@ArpitSharma7 I've recently reconfigured my environment to use torch version 2.3.1 and confirmed that flash-attention2 works well with the following setup. If flash-attention2 still does not work properly in your environment, I recommend using eager instead.

Environment Setup:

!pip install torch==2.3.1
!pip install bitsandbytes==0.43.1
!pip install transformers==4.4.1
!pip install peft==0.12.0
!pip install accelerate==0.33.0
!pip install datasets==2.19.1
!pip install trl==0.8.6
!pip install flash_attn==2.6.3

GPU: A100

CUDA Version: 12.1.105


If flash-attention2 does not work:

Use eager instead of flash-attention2 by replacing:

if torch.cuda.is_bf16_supported():
  compute_dtype = torch.bfloat16
  attn_implementation = 'flash_attention_2'

with

if torch.cuda.is_bf16_supported():
  compute_dtype = torch.bfloat16
  attn_implementation = 'eager'

superctj commented 5 days ago

model_id = "microsoft/Phi-3-small-8k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

I am able to get around the error by passing in torch_dtype and attn_implementation when initiating the model (assuming using the transformers library).
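(A quick sanity check after loading; in recent transformers versions the chosen attention backend is recorded on the config as `_attn_implementation`:)

print(model.dtype)                        # expect torch.bfloat16
print(model.config._attn_implementation)  # expect "flash_attention_2"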

skytin1004 commented 3 days ago

model_id = "microsoft/Phi-3-small-8k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

I am able to get around the error by passing in torch_dtype and attn_implementation when initiating the model (assuming using the transformers library).

Hi @superctj,

Thank you for sharing your solution. Currently, the guide uses the following logic to determine which compute dtype and attention implementation to use:

if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

This logic checks if the GPU supports bfloat16 and sets the attention implementation to flash_attention_2 accordingly. If bfloat16 is not supported, it falls back to using float16 and sdpa.
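(For context, a sketch of how those two values are then typically passed through to model loading, reusing the `bnb_config`/`model_name` pattern from the repro above; this is an illustration, not the exact guide code:)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,     # bf16 or fp16, chosen above
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,                # keep model dtype consistent with compute dtype
    attn_implementation=attn_implementation,  # 'flash_attention_2' or 'sdpa'
    trust_remote_code=True,
)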

It seems that passing torch_dtype and attn_implementation directly when initializing the model works well in your case. Do you have any recommendations for improving the current logic, especially for handling different GPU configurations?

superctj commented 2 days ago

Hi @skytin1004, I had problems running Phi-3-small-8k-instruct and Phi-3.5-mini-instruct from the transformers library on an A40 GPU (see a similar issue here). After fixing this error, I saw warnings about numerical differences when not using flash attention, so I followed the instructions in the Hugging Face documentation to enable flash attention.
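(For anyone on similar hardware, a small check, not from the original comment, to confirm the GPU supports bf16 before opting into flash attention:)

import torch

print(torch.cuda.get_device_name(0))   # e.g. "NVIDIA A40"
print(torch.cuda.is_bf16_supported())  # Ampere-class GPUs such as the A40 report True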