UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
return fn(*args, **kwargs)
/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
return fn(*args, **kwargs)
[rank1]: (traceback identical to the main traceback below; elided)
Traceback (most recent call last):
File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_xformers.py", line 13, in
train()
File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_5img.py", line 1672, in train
trainer.train()
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
return inner_training_loop(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
loss = self.compute_loss(model, inputs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
outputs = model(inputs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, *kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
loss = self.module(*inputs, *kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
return self.base_model(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, *kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
return self.model.forward(*args, kwargs)
File "/home/work/testdataset1/LLaVA-NeXT/llava/model/language_model/llava_llama.py", line 109, in forward
return super().forward(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1139, in forward
outputs = self.model(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, *kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 930, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
return disable_fn(*args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
return CheckpointFunction.apply(function, preserve, args)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(args, kwargs) # type: ignore[misc]
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
outputs = run_function(args)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(args, *kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, *kwargs)
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 500, in forward
attn_output = _flash_attention_forward(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 214, in _flash_attention_forward
attn_output = flash_attn_func(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
return FlashAttnFunc.apply(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(args, kwargs) # type: ignore[misc]
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.
[rank0]: (traceback identical to the one above; elided)
My thanks to "mylesgoose" for providing me with his pretrained model.
Since my server's GPUs do not appear to be Ampere or newer, I have been trying to disable FlashAttention.
First, I copied the train_xformers.py and llama_xformers_attn_monkey_patch.py files into my LLaVA-NeXT directory so that I could train with xformers instead of train_mem.py.
Second, I removed the 'attn_implementation' argument, hoping that would disable FlashAttention entirely. However, neither change stops the "RuntimeError: FlashAttention only supports Ampere GPUs or newer." error (perhaps the training script still falls back to a flash_attention_2 default when the flag is absent?).
Does anyone know how else I can try to disable FlashAttention? For instance, is explicitly forcing a different attention backend at model-load time, as in the sketch below, the right direction?
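To make the question concrete, here is a minimal sketch of what I mean. The attn_implementation keyword is the standard transformers loading option; whether and where LLaVA-NeXT threads it through to from_pretrained is my assumption, and the model id is just the one from my shell script:

# Hypothetical sketch (not LLaVA-NeXT's actual loading code): force the
# PyTorch SDPA backend instead of FlashAttention 2 when loading the LLM.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # "eager" would also avoid the flash-attn CUDA kernels
)
print(model.config._attn_implementation)  # expect "sdpa", not "flash_attention_2"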
I will share my shell file for finetuning siglip A4 below.
LLM_VERSION="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
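# Note: the "${VAR//\//_}" expansions above replace every "/" in the HF model
# ids with "_" so they can be embedded in the run name below.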
############### Pretrain ################
PROMPT_VERSION=plain
BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
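# With the values above, this should print:
# BASE_RUN_NAME: llavanext-google_siglip-so400m-patch14-384-mylesgoose_Meta-Llama-3.1-8B-Instruct-goose-abliterated-mlp2x_gelu-pretrain_blip558k_plain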
CKPT_PATH=$LLM_VERSION