ruian1 opened this issue 4 weeks ago
I tried it: 4-bit really does fail to run. Even if the check at model initialization is skipped, it still fails later at runtime; the model seems incompatible with 4-bit quantization. 8-bit works, though. Is it possible to train with 8-bit quantization?
Thanks, 8-bit works for training internvl-chat-v1_5. The other model I want to train is internvl2-8b; it also loads in 8-bit, but a different bug appears during training.
Training command:
CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft --model_type internvl2-8b --model_id_or_path /root/.cache/modelscope/hub/OpenGVLab/InternVL2-8B/ --dataset ./output/jsonl/train_dataset.jsonl --max_length 4096 --use_flash_attn true --gradient_checkpointing true --learning_rate 1e-6 --num_train_epochs=3 --gradient_accumulation_steps 64 --preprocess_num_proc 48 --quantization_bit 8 --dtype bf16
...
Train: 0%| | 0/7014 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
File "/root/projects/ms-swift/swift/cli/sft.py", line 5, in <module>
sft_main()
File "/root/projects/ms-swift/swift/utils/run_utils.py", line 32, in x_main
result = llm_x(args, **kwargs)
File "/root/projects/ms-swift/swift/llm/sft.py", line 405, in llm_sft
trainer.train(training_args.resume_from_checkpoint)
File "/root/projects/ms-swift/swift/trainers/mixin.py", line 538, in train
res = super().train(resume_from_checkpoint, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1948, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2289, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3328, in training_step
loss = self.compute_loss(model, inputs)
File "/root/projects/ms-swift/swift/trainers/trainers.py", line 179, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1577, in forward
return self.base_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 188, in forward
return self.model.forward(*args, **kwargs)
File "/root/projects/ms-swift/swift/llm/utils/model.py", line 4093, in wrapper
return forward_func(
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 103, in forward
vit_embeds = self.extract_feature(pixel_values)
File "/root/projects/ms-swift/swift/llm/utils/model.py", line 4316, in _new_extract_feature
return extract_feature(pixel_values).to(pixel_values.device).to(pixel_values.dtype)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 181, in extract_feature
vit_embeds = self.vision_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 419, in forward
encoder_outputs = self.encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 350, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/root/projects/ms-swift/swift/llm/utils/model.py", line 6224, in <lambda>
lambda *args, use_reentrant=_use_reentrant, **kwargs: _old_checkpoint(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 488, in checkpoint
ret = function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 296, in forward
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 252, in forward
x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 244, in _flash_attn
context, _ = self.inner_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 66, in forward
assert qkv.dtype in [torch.float16, torch.bfloat16]
AssertionError
@tastelikefeet
Found the cause of this 8-bit bug (assert qkv.dtype in [torch.float16, torch.bfloat16]) and opened a Hugging Face PR ( https://huggingface.co/OpenGVLab/InternVL2-8B/discussions/13 )
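For reference, the failure mode is a dtype mismatch: the 8-bit quantized path upcasts activations to float32 (see the MatMul8bitLt warning above), which then trips the half-precision assert in the vision tower's flash-attention module. The helper below is a hypothetical sketch of the kind of guard such a fix needs, not the actual PR diff; the function name and placement are assumptions:

```python
import torch

def ensure_half_precision(qkv: torch.Tensor,
                          target_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Cast qkv back to half precision if a quantized layer upcast it to
    float32, so the downstream
    `assert qkv.dtype in [torch.float16, torch.bfloat16]` passes.
    (Hypothetical helper; the real fix lives in modeling_intern_vit.py.)"""
    if qkv.dtype not in (torch.float16, torch.bfloat16):
        qkv = qkv.to(target_dtype)
    return qkv
```

Applied just before the `self.inner_attn(...)` call in `_flash_attn`, a cast like this would keep the flash-attention kernel's dtype contract intact.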
Describe the bug (what the bug is and how to reproduce it, ideally with screenshots)
The goal is to fine-tune internvl2-8b, using internvl-chat-v1_5 since it is a provided example.
When fine-tuning internvl with the command below to enable quantization, the bug below appears (the fine-tuning code works without any quantization parameters). The same bug occurs when fine-tuning internvl2-8b.
The error comes from https://github.com/modelscope/ms-swift/blob/main/swift/llm/utils/model.py#L4265; model.language_model.output is (output): Linear4bit(in_features=6144, out_features=92553, bias=False)
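To check which layers ended up quantized (and hence why the output head shows up as a Linear4bit), one can walk the module tree. This is a generic inspection sketch, assuming bitsandbytes' usual class names `Linear4bit` / `Linear8bitLt`:

```python
import torch.nn as nn

def find_quantized_linears(model: nn.Module) -> list:
    """Return names of submodules whose class indicates a
    bitsandbytes-quantized linear layer (Linear4bit / Linear8bitLt)."""
    return [name for name, module in model.named_modules()
            if type(module).__name__ in ("Linear4bit", "Linear8bitLt")]
```

Running this on the loaded model would show whether `language_model.output` was quantized, which is useful when deciding which modules to exclude from quantization.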