modelscope / ms-swift

Use PEFT or Full-parameter to finetune 400+ LLMs or 100+ MLLMs. (LLM: Qwen2.5, Llama3.2, GLM4, Internlm2.5, Yi1.5, Mistral, Baichuan2, DeepSeek, Gemma2, ...; MLLM: Qwen2-VL, Qwen2-Audio, Llama3.2-Vision, Llava, InternVL2, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL, Phi3.5-Vision, ...)
https://swift.readthedocs.io/zh-cn/latest/Instruction/index.html
Apache License 2.0

Finetuning Qwen2VL yields an error when FlashAttention is enabled and text-only data is mixed in #2246

Open VietDunghacker opened 1 month ago

VietDunghacker commented 1 month ago

Describe the bug

When training a Qwen2VL model with Flash Attention enabled (--use-flash-attention true) on mixed data (both image and text-only samples), the code raises the following error:

[rank0]:   File "/home/***/ms-swift/swift/cli/sft.py", line 5, in <module>
[rank0]:     sft_main()
[rank0]:   File "/home/***/ms-swift/swift/utils/run_utils.py", line 32, in x_main
[rank0]:     result = llm_x(args, **kwargs)
[rank0]:   File "/home/***/ms-swift/swift/llm/sft.py", line 542, in llm_sft
[rank0]:     return trainer_train(args, model, template, train_dataset, val_dataset, callbacks=callbacks, msg=msg)
[rank0]:   File "/home/***/ms-swift/swift/llm/sft.py", line 492, in trainer_train
[rank0]:     trainer.train(training_args.resume_from_checkpoint)
[rank0]:   File "/home/***/ms-swift/swift/trainers/mixin.py", line 480, in train
[rank0]:     res = super().train(resume_from_checkpoint, *args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/transformers/trainer.py", line 2085, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/transformers/trainer.py", line 2421, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/transformers/trainer.py", line 3524, in training_step
[rank0]:     loss = self.compute_loss(model, inputs)
[rank0]:   File "/home/***/ms-swift/swift/trainers/trainers.py", line 161, in compute_loss
[rank0]:     outputs = model(**inputs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank0]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank0]:   File "/home/***/ms-swift/swift/llm/utils/template.py", line 344, in _pre_forward_hook
[rank0]:     res_extra.append(self._post_encode(module, d))
[rank0]:   File "/home/***/ms-swift/swift/llm/utils/template.py", line 1649, in _post_encode
[rank0]:     image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1074, in forward
[rank0]:     hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/***/ms-swift/swift/llm/utils/utils.py", line 441, in _new_forward
[rank0]:     layer_ret = torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs)
[rank0]:   File "/home/***/ms-swift/swift/llm/utils/model.py", line 6759, in <lambda>
[rank0]:     lambda *args, use_reentrant=_use_reentrant, **kwargs: _old_checkpoint(
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 488, in checkpoint
[rank0]:     ret = function(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 431, in forward
[rank0]:     hidden_states = hidden_states + self.attn(
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 376, in forward
[rank0]:     attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1124, in flash_attn_varlen_func
[rank0]:     return FlashAttnVarlenFunc.apply(
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 620, in forward
[rank0]:     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
[rank0]:   File "/home/***/abc_env/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 90, in _flash_attn_varlen_forward
[rank0]:     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
[rank0]: RuntimeError: cu_seqlens_q must be on CUDA

When I disable Flash Attention, the code runs smoothly. I also noticed that when I remove the text-only data and keep Flash Attention enabled, the error does not occur. I believe a related issue was reported in https://github.com/modelscope/ms-swift/issues/2147 and fixed recently, but has that fix been tested with Flash Attention?
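For readers unfamiliar with the failing argument: cu_seqlens is the array of cumulative sequence boundaries that flash_attn_varlen_func uses to separate the packed patch sequences. As far as I understand the Qwen2-VL vision tower in transformers, it builds this array from image_grid_thw with tensor operations, so the result would live on whatever device image_grid_thw is on. The pure-Python sketch below only illustrates the arithmetic; the function name cumulative_seqlens is hypothetical and this is not the library's code.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def cumulative_seqlens(grid_thw: Sequence[Tuple[int, int, int]]) -> List[int]:
    """Build varlen-attention boundaries from (t, h, w) patch grids.

    Hypothetical illustration: each of the t temporal slices of an input
    contributes one sequence of h * w patches, and the result is the
    running total of those lengths with a leading 0.
    """
    lengths = [h * w for t, h, w in grid_thw for _ in range(t)]
    return [0] + list(accumulate(lengths))


# One image with a 4x6 patch grid, then a 2-frame clip with 2x2 grids.
print(cumulative_seqlens([(1, 4, 6), (2, 2, 2)]))  # [0, 24, 28, 32]
```

If the real computation follows this shape, a CPU-resident image_grid_thw would produce a CPU-resident cu_seqlens, which matches the error message above.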

Your hardware and system info

torch: 2.4.0
flash-attention: 2.6.3


Jintao-Huang commented 1 month ago

What version of transformers is it?

Jintao-Huang commented 1 month ago

I am using version 4.45.2, and it works fine.

VietDunghacker commented 1 month ago

I tried several versions of transformers, including 4.45.2, and the bug still occurred. For more context, I finetuned Qwen2VL with LoRA, targeting all possible linear layers (including those in the vision model). The environment is 1x A100 80GB. When I rolled the code back to commit b654118003a963ef55b088aad44f834b54a6a641, it ran smoothly. I believe something is off in the post_encode method for Qwen2VL, since the training time doubled after I pulled the latest commit, which fixed a bug in that method.

NeosXu commented 1 week ago

I also encountered this issue, and I believe the cause might lie in the following lines of code: https://github.com/modelscope/ms-swift/blob/acd17e5a7d6a1f0073a48af164b1cf9ad5a1a561/swift/llm/utils/template.py#L1651-L1655 After I replaced them with the code below, which moves media_inputs['image_grid_thw'] to the device, the issue no longer occurred.

# Move both vision inputs to the same device as the text tokens.
device = input_ids.device
pixel_values = media_inputs['pixel_values'].to(device)
image_grid_thw = media_inputs['image_grid_thw'].to(device)

# Cast to the vision tower's dtype, then run the vision encoder.
pixel_values = pixel_values.type(model.visual.get_dtype())
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
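
The same idea can be written as a small helper that moves every tensor-like entry in one pass, so a newly added key is not forgotten. This is only a sketch: move_media_inputs and FakeTensor are hypothetical names, not part of ms-swift, and the stand-in class exists only so the example runs without torch or a GPU. With real inputs, any torch.Tensor would be handled through its .to() method.

```python
from typing import Any, Dict


def move_media_inputs(media_inputs: Dict[str, Any], device: Any) -> Dict[str, Any]:
    """Return a copy of media_inputs with tensor-like values moved to device.

    Hypothetical helper: any value exposing a .to() method (such as a
    torch.Tensor) is moved; everything else is passed through unchanged.
    """
    return {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in media_inputs.items()
    }


class FakeTensor:
    """Stand-in for torch.Tensor that only tracks which device it is on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


media_inputs = {"pixel_values": FakeTensor(), "image_grid_thw": FakeTensor()}
moved = move_media_inputs(media_inputs, "cuda:0")
print({key: value.device for key, value in moved.items()})
```
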

Jintao-Huang commented 1 week ago

Thank you, I will check it.

Jintao-Huang commented 1 week ago

Please tell me the version of accelerate.

NeosXu commented 1 week ago

accelerate: 1.1.1
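
Since the versions of torch, transformers, accelerate and flash-attention were each requested separately in this thread, a short script like the one below could report them all at once. It is a minimal sketch that assumes the packages are installed under their usual PyPI distribution names (torch, transformers, accelerate, flash-attn, ms-swift); collect_versions is a hypothetical helper name.

```python
from importlib import metadata

# Distribution names as assumed from PyPI; adjust if your environment differs.
PACKAGES = ["torch", "transformers", "accelerate", "flash-attn", "ms-swift"]


def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or 'not installed'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```

Pasting the printed lines into the "Your hardware and system info" section of a report would cover the questions asked above.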