TinyLLaVA / TinyLLaVA_Factory

A Framework of Small-scale Large Multimodal Models
https://arxiv.org/abs/2402.14289
Apache License 2.0

Can TinyLlama 1.5B be LoRA-finetuned on one 24GB GPU? #23

Closed: JalorOo closed this issue 6 months ago

JalorOo commented 6 months ago

I tried LoRA finetuning; it seems TinyLlama only supports full finetuning right now.

baichuanzhou commented 6 months ago

Finetuning TinyLLaVA with Phi is possible with one 24GB GPU.

I used this script to finetune TinyLLaVA-3.1B with LoRA on the pokemon dataset.

Please replace the data paths with your own.

#!/bin/bash

# Assign the arguments to variables
DATA_PATH="/path/to/your/pokemon_blip_captions.json"
IMAGE_PATH="/path/to/your/data/"
OUTPUT_DIR="/path/to/your/TinyLLaVA-3.1B-lora"

deepspeed tinyllava/train/train.py \
    --deepspeed ./scripts/tiny_llava/zero3.json \
    --lora_enable True --lora_r 32 --lora_alpha 64 \
    --model_name_or_path bczhou/TinyLLaVA-3.1B \
    --version phi \
    --data_path $DATA_PATH \
    --image_folder $IMAGE_PATH \
    --vision_tower bczhou/TinyLLaVA-3.1B-SigLIP \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length False \
    --fp16 True \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 False \
    --model_max_length 3072 \
    --gradient_checkpointing True \
    --dataloader_num_workers 15 \
    --lazy_preprocess True \
    --report_to wandb
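
For intuition on what --lora_enable True --lora_r 32 --lora_alpha 64 do, here is a minimal, generic peft sketch (the base checkpoint and target module names below are placeholder assumptions for illustration, not TinyLLaVA's internal wiring, which applies LoRA for you when --lora_enable is set):

# Generic illustration of how a LoRA config shrinks the trainable
# parameter count; model and target_modules are placeholder assumptions.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

lora_config = LoraConfig(
    r=32,            # matches --lora_r 32
    lora_alpha=64,   # matches --lora_alpha 64
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],  # placeholder names
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only a small fraction is trainable

Only the small adapter matrices receive gradients, which is why the job fits in 24GB where a full finetune would not.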

I used this script to convert the pokemon dataset:

import json
import os
import random

import shortuuid
import tqdm
from datasets import load_dataset

ds = load_dataset('lambdalabs/pokemon-blip-captions')
pokemon_data = []

pokemon_image_path = '/path/to/your/data/pokemon/image'
pokemon_data_path = '/path/to/your/pokemon_blip_captions.json'

# Make sure the image directory exists before saving into it.
os.makedirs(pokemon_image_path, exist_ok=True)

description_list = [
    "Describe the image concisely.",
    "Provide a brief description of the given image.",
    "Offer a succinct explanation of the picture presented.",
    "Summarize the visual content of the image.",
    "Give a short and clear explanation of the subsequent image.",
    "Share a concise interpretation of the image provided.",
    "Present a compact description of the photo's key features.",
    "Relay a brief, clear account of the picture shown.",
    "Render a clear and concise summary of the photo.",
    "Write a terse but informative summary of the picture.",
    "Create a compact narrative representing the image presented."
]

# Convert each sample to the LLaVA-style record format expected by train.py.
for sample in tqdm.tqdm(ds['train']):
    uuid = shortuuid.uuid()
    sample_dict = dict()
    sample_dict['id'] = uuid
    sample_dict['image'] = 'pokemon/image/' + uuid + '.jpg'
    # 'image' is already decoded to a PIL image by datasets; save it to disk.
    sample['image'].save(os.path.join(pokemon_image_path, uuid + '.jpg'))
    conversations = [
        {"from": "human", "value": "<image>\n" + random.choice(description_list)},
        {"from": "gpt", "value": sample['text']}
    ]
    sample_dict['conversations'] = conversations
    pokemon_data.append(sample_dict)

with open(pokemon_data_path, 'w') as f:
    json.dump(pokemon_data, f, indent=4)
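
As a quick sanity check before training, you can load the generated JSON and confirm each record has the fields the trainer expects; a minimal sketch, assuming the same placeholder paths as above:

# Sanity-check the converted dataset (paths are placeholder assumptions).
import json
import os

data_root = '/path/to/your/data/'
with open('/path/to/your/pokemon_blip_captions.json') as f:
    records = json.load(f)

sample = records[0]
assert set(sample) == {'id', 'image', 'conversations'}
assert sample['conversations'][0]['value'].startswith('<image>')
# The referenced image file should exist under --image_folder.
assert os.path.exists(os.path.join(data_root, sample['image']))
print(len(records), 'records look OK')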

The experiment took approximately 10 minutes on one 4090. Let me know if you have any more questions! I am writing a tutorial doc on finetuning with LoRA; if you are interested, please contribute.

JalorOo commented 6 months ago

OK, I will try it tomorrow. Thank you!

JalorOo commented 6 months ago

OK, I used the code and it works!

AleNunezArroyo commented 6 months ago

I was already able to do the training with the guide. These are the files it generates: [screenshot]

But when I use the reference code in the repository, I get this error: [screenshots]

What code should I use to run the finetuned model? Thank you.

JalorOo commented 6 months ago

First, you need to merge the LoRA weights into the original model. Next, make sure the folder name where you store the merged weights contains the word "lora". [screenshots]
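
If you don't have a merge script at hand, here is a minimal generic sketch of the merge step using peft (the paths are placeholder assumptions, and a LoRA run may also save non-LoRA parts such as the projector separately, so treat this as an illustration of the idea rather than a drop-in tool):

# Generic LoRA-merge sketch with peft; paths are placeholder assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    'bczhou/TinyLLaVA-3.1B', torch_dtype=torch.float16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('bczhou/TinyLLaVA-3.1B')

# Attach the trained adapter, then fold its weights into the base model.
model = PeftModel.from_pretrained(base, '/path/to/your/TinyLLaVA-3.1B-lora')
model = model.merge_and_unload()

# Keep "lora" in the folder name so the loader takes the LoRA code path.
out_dir = '/path/to/your/TinyLLaVA-3.1B-lora-merged'
model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)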

JalorOo commented 6 months ago

For more detail, you can check my open-source code: https://github.com/Libv-Team/figlang2024

AleNunezArroyo commented 6 months ago

Thank you very much for the information and for the repository, it is already working 😊

JalorOo commented 6 months ago

You are welcome! 😊