haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Usage] Failed for training with RuntimeError: weight should have at least three dimensions #381

Open lg123666 opened 1 year ago

lg123666 commented 1 year ago

Describe the issue

Issue:

Command:

python llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path /mnt/pretrained_params/llava_llama2/chinese-alpaca-2-7b \
    --version "llava_llama_2" \
    --data_path /mnt/public_dataset/LLAVA/llava_instruct_150k.jsonl \
    --vision_tower /mnt/files/models/llava/clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter /mnt/llava_llama2/projector/llava-336px-pretrain-llama-2-7b-chat/mm_projector.bin \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/ \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 1 \
    --lazy_preprocess True \
    --report_to wandb --loader exllama

Log:

File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step loss = self.compute_loss(model, inputs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss outputs = model(inputs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, *kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward return model_forward(args, kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in call return convert_to_fp32(self.model_forward(*args, kwargs)) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast return func(*args, *kwargs) File "/home/users/tom/projects/LLaVA/mllm/llava/model/language_model/llava_llama.py", line 75, in forward input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) File "/home/users/tom/projects/LLaVA/mllm/llava/model/llava_arch.py", line 102, in prepare_inputs_labels_for_multimodal image_features = self.encode_images(images) File "/home/users/tom/projects/LLaVA/mllm/llava/model/llava_arch.py", line 82, in encode_images image_features = self.get_model().get_vision_tower()(images) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(args, kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, kwargs) File "/home/users/tom/projects/LLaVA/mllm/llava/model/multimodal_encoder/clip_encoder.py", line 49, in forward image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, *kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 941, in forward return self.vision_model( File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(args, kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 866, in forward hidden_states = self.embeddings(pixel_values) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 195, in forward patch_embeds = self.patch_embedding(pixel_values) # shape = [, width, grid, grid] File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(args, kwargs) File "/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 463, in forward return self._conv_forward(input, self.weight, self.bias) File 
"/home/users/tom/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward return F.conv2d(input, weight, bias, self.stride, RuntimeError: weight should have at least three dimensions wandb: Waiting for W&B process to finish... (failed 1). wandb: You can sync this run to the cloud by running:ped) wandb: wandb sync /home/users/tom/projects/LLaVA/mllm/wandb/offline-run-20230821_104204-z0nnbr9q wandb: Find logs at: ./wandb/offline-run-20230821_104204-z0nnbr9q/logs

Screenshots: However, running clip-vit-large-patch14 with the official code works fine (see attached screenshot).
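(For context, the kind of standalone check referred to above is presumably something like the stock transformers usage sketched below; the local model path is taken from the training command and the printed shape assumes the 336px checkpoint, so treat this as an illustration rather than the exact script in the screenshot.)

import torch
from PIL import Image
from transformers import CLIPVisionModel, CLIPImageProcessor

model_path = "/mnt/files/models/llava/clip-vit-large-patch14-336"  # local checkpoint from the command above
vision_model = CLIPVisionModel.from_pretrained(model_path)
processor = CLIPImageProcessor.from_pretrained(model_path)

image = Image.new("RGB", (336, 336))  # dummy image, just to exercise the forward pass
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = vision_model(**inputs, output_hidden_states=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 577, 1024]) when the weights load correctly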

Gregory1994 commented 1 year ago

Same problem when using a V100, but when I use an A100 everything works well. I guess the problem is caused by DeepSpeed, which converts the CLIP model's type to "meta" rather than bf16 and never loads the CLIP weights.
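If that is what is happening, one quick way to confirm it is to inspect the vision tower's parameters before training starts. This is a minimal sketch, not code from the repo; it assumes `model` is the LlavaLlamaForCausalLM instance built in train.py. A weight that was never materialized shows up on the meta device or as an empty placeholder instead of the 4-D patch-embedding kernel, which is exactly why F.conv2d complains that the weight has fewer than three dimensions.

# Check whether any CLIP vision-tower parameter was left unmaterialized by DeepSpeed/ZeRO-3.
vision_tower = model.get_model().get_vision_tower()
for name, p in vision_tower.named_parameters():
    if p.is_meta or p.numel() == 0:
        print(f"not materialized: {name}  device={p.device}  shape={tuple(p.shape)}")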

lg123666 commented 1 year ago

Same problem when using a V100, but when I use an A100 everything works well. I guess the problem is caused by DeepSpeed, which converts the CLIP model's type to "meta" rather than bf16 and never loads the CLIP weights.

Have you managed to run it successfully without DeepSpeed? I want to debug this, so getting through at least one iteration is crucial for me.

1106301825 commented 1 year ago

Have you solved it yet? I am having the same problem.

1106301825 commented 1 year ago

@lg123666

lg123666 commented 1 year ago

@lg123666

"--deepspeed", "scripts/zero2.json" probably help you

If you want to debug the model's forward pass, comment out the call to replace_llama_attn_with_flash_attn() (or delete it), then run everything on CPU in full precision:

model.to('cpu')                            # keep the whole model on CPU so you can step through it
model.float()                              # fp32, since flash-attention and the bf16 path need a GPU
model.eval()
model.model.mm_projector.to(model.device)  # the separately loaded projector must sit on the same device
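Once the model is on CPU, the failing step can be reproduced in isolation by pushing a dummy batch through the vision tower, which is the same call that blows up in the traceback (llava_arch.encode_images -> CLIPVisionTower.forward). The snippet below is a sketch under the assumption that `model` is the object built by the training script and the vision tower is clip-vit-large-patch14-336.

import torch

# Run only the vision tower, i.e. the call that fails in encode_images.
vision_tower = model.get_model().get_vision_tower()
dummy_images = torch.zeros(1, 3, 336, 336, dtype=vision_tower.dtype, device=vision_tower.device)

with torch.no_grad():
    image_features = vision_tower(dummy_images)

# With clip-vit-large-patch14-336 and properly loaded weights this prints
# torch.Size([1, 576, 1024]); if the weights were never loaded, the same
# "weight should have at least three dimensions" RuntimeError is raised here.
print(image_features.shape)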