LLaVA-VL / LLaVA-NeXT


How do we turn off Flash attention in LLaVA-NeXT? #224

Open Bleking opened 2 months ago

Bleking commented 2 months ago

Since my server's GPUs do not seem to be Ampere or newer, I have been trying to disable Flash attention.
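
For reference, this is a quick way to confirm the hardware; FlashAttention 2 needs compute capability 8.0 (Ampere) or newer:

import torch

# FlashAttention 2 requires compute capability >= 8.0 (Ampere),
# e.g. (7, 0) on V100 or (7, 5) on T4 will not work, (8, 0) on A100 will.
major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability: {major}.{minor}")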

First, I simply copied the train_xformers.py and llama_xformers_attn_monkey_patch.py files into my directory so that I can use xformers on LLaVA-NeXT as well, instead of train_mem.py.
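
If it helps, my understanding is that the patch only takes effect when it is applied before the model is built, which is what train_xformers.py does. A sketch following the original LLaVA repo's version (in my copy, train() actually comes from train_5img.py):

# Sketch of train_xformers.py, following the original LLaVA repo.
# The patch must be installed before any LLaMA module is instantiated.
from llava.train.llama_xformers_attn_monkey_patch import (
    replace_llama_attn_with_xformers_attn,
)

replace_llama_attn_with_xformers_attn()

from llava.train.train import train  # in my setup this is train_5img.train

if __name__ == "__main__":
    train()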

Second, I removed the 'attn_implementation' argument, hoping that would completely disable Flash attention.
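
In hindsight, removing the argument alone may not be enough: as far as I can tell, transformers then falls back to its default implementation, which can still resolve to flash_attention_2 when the flash-attn package is installed. Passing it explicitly at load time should be safer, roughly:

from transformers import AutoModelForCausalLM

# "eager" is the plain PyTorch attention; "sdpa" uses
# torch.nn.functional.scaled_dot_product_attention and also runs
# on pre-Ampere GPUs. The model name is just the one from my script.
model = AutoModelForCausalLM.from_pretrained(
    "mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated",
    attn_implementation="eager",
)

In the training script this would mean forwarding attn_implementation="eager" (or "sdpa") to the from_pretrained call rather than dropping the argument.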

However, neither change prevents the "RuntimeError: FlashAttention only supports Ampere GPUs or newer." error.

Does anyone know how else I can try disabling Flash attention?

I will share below my shell script for fine-tuning siglip A4.

LLM_VERSION="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"

############### Pretrain ################

PROMPT_VERSION=plain

BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"

CKPT_PATH=$LLM_VERSION

deepspeed llava/train/train_xformers.py \
    --lora_enable True --lora_r 16 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed scripts/zero3_offload_new.json \
    --model_name_or_path ${CKPT_PATH} \
    --version ${PROMPT_VERSION} \
    --data_path ./playground/floorplan_vqa_1000.json \
    --image_folder /home/work/testdataset1/LLaVA/playground/data/floorplan_data/ \
    --pretrain_mm_mlp_adapter="/home/work/testdataset1/LLaVA-NeXT/checkpoints/projectors/llavanext-google_siglip-so400m-patch14-384-mylesgoose_Meta-Llama-3.1-8B-Instruct-goose-abliterated-mlp2x_gelu-pretrain_blip558k_plain/checkpoint-1500/mm_projector.bin" \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres \
    --image_grid_pinpoints "[(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]" \
    --mm_patch_merge_type spatial_unpad \
    --fp16 True \
    --bf16 False \
    --output_dir "./checkpoints/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 False \
    --model_max_length 1024 \
    --gradient_checkpointing True \
    --dataloader_num_workers 0 \
    --lazy_preprocess True \
    --report_to wandb \
    --torch_compile False \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True \
    --run_name llavanext-siglip-400m-Meta-Llama-3.1-8B-pretrain_blip558k_plain

Error message (the identical traceback is raised on rank0 and rank1; repeated torch.nn.modules.module dispatch frames trimmed):

/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
  return fn(*args, **kwargs)
Traceback (most recent call last):
  File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_xformers.py", line 13, in <module>
    train()
  File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_5img.py", line 1672, in train
    trainer.train()
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
    return self.base_model(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/work/testdataset1/LLaVA-NeXT/llava/model/language_model/llava_llama.py", line 109, in forward
    return super().forward(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1139, in forward
    outputs = self.model(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 930, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
    outputs = run_function(*args)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 500, in forward
    attn_output = _flash_attention_forward(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 214, in _flash_attention_forward
    attn_output = flash_attn_func(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
    return FlashAttnFunc.apply(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.
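
One thing I notice in the traceback: the failure goes through transformers' own modeling_llama.py (line 500 into _flash_attention_forward), not through the xformers patch. If I read the newer transformers code correctly, the attention class is chosen at model build time from config._attn_implementation, so a LlamaFlashAttention2 module gets constructed and patching LlamaAttention.forward never touches it. A quick sanity check that the monkey patch at least landed on the class it targets (debugging sketch; attribute names are from my reading of the code):

import transformers
from llava.train.llama_xformers_attn_monkey_patch import (
    replace_llama_attn_with_xformers_attn,
)

replace_llama_attn_with_xformers_attn()

# If the patch landed, LlamaAttention.forward should come from the patch
# module rather than from transformers itself. Note this says nothing about
# LlamaFlashAttention2, which is the class my traceback actually goes through.
fwd = transformers.models.llama.modeling_llama.LlamaAttention.forward
print(fwd.__module__)

If that is right, it would also explain why explicitly forcing attn_implementation="eager" matters: only then does the model use the LlamaAttention class that the patch overrides.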

My thanks to "mylesgoose" for providing his pretrained model.