MeetKai / functionary

Chat language model that can use tools and interpret the results
MIT License

Deepspeed Stage 3 CPU Offload Training Method #233

Open heesuju opened 1 month ago

heesuju commented 1 month ago

Hello,

I have been trying to fine-tune functionary-small-v2.5 on a custom dataset formatted according to the provided test format.

I only have 24 GB of VRAM available, so I trained with DeepSpeed using ZeRO stage 3 with CPU offloading.

The training params are as follows:

deepspeed functionary/train/train_lora.py \
    --model_name_or_path meetkai/functionary-small-v2.5 \
    --train_data_path dataset/train.jsonl \
    --eval_data_path dataset/test.jsonl \
    --q_lora False \
    --bf16 True \
    --output_dir models/1.0 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --eval_accumulation_steps 1 \
    --evaluation_strategy "epoch" \
    --eval_steps 1000 \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 3 \
    --logging_steps 1 \
    --learning_rate 1e-04 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --tf32 True \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --report_to wandb \
    --packing \
    --deepspeed functionary/train/ds_config/zero3_offload.json

The following is the ds_config I used:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
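As a quick sanity check on a config like the one above, here is a minimal Python sketch that reads the ZeRO stage and the offload devices out of the JSON. The helper name `summarize_zero_offload` is hypothetical and is not part of DeepSpeed or functionary; the embedded JSON is a trimmed copy of the config above containing only the keys the check reads.

```python
import json


def summarize_zero_offload(config: dict) -> dict:
    """Return the ZeRO stage and offload devices found in a DeepSpeed config dict.

    Hypothetical helper for illustration; missing keys come back as None.
    """
    zero = config.get("zero_optimization", {})
    return {
        "stage": zero.get("stage"),
        "optimizer_device": zero.get("offload_optimizer", {}).get("device"),
        "param_device": zero.get("offload_param", {}).get("device"),
    }


# Trimmed copy of the config above (only the keys this check reads).
CONFIG_TEXT = """
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": true},
        "offload_param": {"device": "cpu", "pin_memory": true}
    }
}
"""

summary = summarize_zero_offload(json.loads(CONFIG_TEXT))
print(summary)
```

For this config the summary reports stage 3 with both the optimizer state and the parameters offloaded to the CPU, which matches the setup described in the post.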

image

After training, the following appeared in my checkpoint directory, but I could not merge the adapter weights into the base model because the saved weights have a different size from the current model. Running 'merge_lora_weight.py' produced the following error message:

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
        size mismatch for base_model.model.model.embed_tokens.modules_to_save.default.weight: copying a param with shape torch.Size([128261, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
        size mismatch for base_model.model.lm_head.modules_to_save.default.weight: copying a param with shape torch.Size([128261, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).

Could you tell me what I am doing wrong?

Thank you for your time!

khai-meetkai commented 1 month ago

Hi @heesuju, I think the reason is that you didn't pass the argument --prompt_template_version v2.llama3. The base model, meetkai/functionary-small-v2.5, uses prompt template version v2.llama3, while the default prompt template is v2, so the two don't match. The v2.llama3 template doesn't add new tokens, while v2 does, which causes the shape mismatch when you merge. Please run the training again with --prompt_template_version v2.llama3 added.
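To make the mismatch concrete, here is a minimal Python sketch that parses the two "size mismatch" lines from the traceback above and reports how many more rows the checkpoint tensors have than the current model. The function name `extra_rows` and the regular expression are illustrative assumptions, not part of functionary or PEFT. A positive difference on both `embed_tokens` and `lm_head` is consistent with extra tokens having been added to the vocabulary during training, as described in the reply above.

```python
import re

# Matches the "size mismatch" lines printed when loading a state_dict fails,
# capturing the parameter name and both shapes.
PATTERN = re.compile(
    r"size mismatch for (?P<name>[\w.]+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]+)\]\)"
)


def extra_rows(error_text: str) -> dict:
    """Map each mismatched parameter to (checkpoint rows - current model rows)."""
    result = {}
    for match in PATTERN.finditer(error_text):
        ckpt_rows = int(match.group("ckpt").split(",")[0])
        model_rows = int(match.group("model").split(",")[0])
        result[match.group("name")] = ckpt_rows - model_rows
    return result


# The two mismatch lines copied from the error message in the question.
ERROR_TEXT = """
size mismatch for base_model.model.model.embed_tokens.modules_to_save.default.weight: copying a param with shape torch.Size([128261, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
size mismatch for base_model.model.lm_head.modules_to_save.default.weight: copying a param with shape torch.Size([128261, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
"""

differences = extra_rows(ERROR_TEXT)
for name, diff in differences.items():
    print(f"{name}: checkpoint has {diff} more rows than the current model")
```

In this case both tensors differ by 128261 - 128256 = 5 rows, so the checkpoint was saved with a vocabulary five entries larger than the one the merge script builds.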