Luodian / Otter

🦦 Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following and in-context learning ability.
https://otter-ntu.github.io/
MIT License

Question about training with FSDP config #239

vishaal27 opened this issue 1 year ago (status: Open)

vishaal27 commented 1 year ago

Hi, thanks for your great work!

I have a question about training your model on the LADD split.

I have an A100 machine with 40GB VRAM. I use the following command for training:

export PYTHONPATH=.

accelerate launch --config_file=./pipeline/accelerate_configs/accelerate_config_fsdp.yaml \
pipeline/train/instruction_following.py \
--pretrained_model_name_or_path=luodian/OTTER-LLaMA7B-INIT \
--mimicit_path="./data/LADD_instructions.json" \
--images_path="./data/LA.json" \
--train_config_path="./data/LADD_train.json" \
--external_save_dir="./checkpoints" \
--batch_size=4 \
--num_epochs=9 \
--report_to_wandb \
--wandb_entity=vu27 \
--run_name=OTTER-LLaMA7B-TEST \
--wandb_project=OTTER-LLaMA7B-TEST \
--workers=8 \
--lr_scheduler=cosine \
--learning_rate=1e-5 \
--warmup_steps_ratio=0.01

I haven't changed the FSDP config; it looks like:

compute_environment: LOCAL_MACHINE
distributed_type: no
downcast_bf16: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: false
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 20687

However, when training on a single GPU I run into OOM issues, even though I verified that the accelerate config is being applied (training does run in mixed precision). In one of your issues (https://github.com/Luodian/Otter/issues/182) you mentioned loading the model in lower precision, which I tried with both fp16 and bf16.

A few questions:

1. For your training, did you load the model in fp16/bf16, or did you load it in full precision and rely on accelerate's autocasting? I ask because bf16 training can sometimes be less stable than full-precision training (although only about 1/9th of the model parameters are trained, so this may not be an issue).
2. Were you able to train on a single GPU at all, or did you use multi-GPU setups for all your experiments without testing single-GPU runs?
3. Could you share more details of your final training run (batch sizes, mixed-precision setup, DDP vs. FSDP, gradient accumulation/checkpointing, number of GPUs, etc.) so that I have a reference for my own runs?

Your help would be much appreciated!
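For reference, this is roughly how I tried loading the model in lower precision (a minimal sketch: the import path for OtterForConditionalGeneration is my guess at the repo layout, and torch_dtype is the standard Hugging Face from_pretrained argument):

import torch
# NOTE: import path is an assumption; adjust to wherever modeling_otter lives in this repo.
from otter.modeling_otter import OtterForConditionalGeneration

# Load the weights directly in bf16 so a full fp32 copy never has to fit in 40GB of VRAM.
model = OtterForConditionalGeneration.from_pretrained(
    "luodian/OTTER-LLaMA7B-INIT",
    torch_dtype=torch.bfloat16,
)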

vishaal27 commented 1 year ago

Another question: for fine-tuning from Otter, would you still recommend the same training pipeline, or would you recommend PEFT methods (I believe you have already implemented LoRA support)? Do you have any guidelines on the best setup for fine-tuning Otter?

Luodian commented 1 year ago

Hi Vishaal, may I know if you are training on a single 40GB GPU? Can you decrease to batch_size=1 or try with luodian/OTTER-MPT1B-RPJama-Init?

Luodian commented 1 year ago

I am not sure if a single 40GB GPU will work. As I recall, initializing Otter-LLaMA7B takes around 16GB of memory, and training with batch_size=1 uses around 30-40GB.
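As a rough back-of-envelope check (my own estimate, not a measured number): the 7B language backbone alone in bf16 is about 14GB of weights, which lines up with the ~16GB init figure once the vision encoder and perceiver are added.

# Rough weight-memory estimate; bf16 uses 2 bytes per parameter. Numbers are approximate.
llm_params = 7e9                 # LLaMA-7B backbone
bf16_bytes_per_param = 2
print(f"LLM weights alone: ~{llm_params * bf16_bytes_per_param / 1e9:.0f} GB")  # ~14 GB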

vishaal27 commented 1 year ago

Thanks @Luodian, I was finally able to run it on a larger compute node with four 40GB GPUs and a batch size of 32, though it is still strange that I couldn't get it running on one GPU even in half precision. What ended up working was loading the model in bf16 and using bf16 mixed-precision training. For fine-tuning, would you still recommend this setup, or would you recommend LoRA fine-tuning? Also, do you have a sample training config for fine-tuning with LoRA?
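For anyone hitting the same OOM, the combination that worked for me looks roughly like the sketch below (using the plain accelerate API rather than the exact code in pipeline/train/instruction_following.py; the toy Linear stands in for the Otter model loaded in bf16 as in the earlier snippet):

import torch
from accelerate import Accelerator

# bf16 autocast for forward/backward, matching mixed_precision: bf16 in the accelerate config.
accelerator = Accelerator(mixed_precision="bf16")

# Toy stand-in; in practice this is the Otter model loaded with torch_dtype=torch.bfloat16.
model = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model, optimizer = accelerator.prepare(model, optimizer)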

Luodian commented 1 year ago

We don't have promising results for LoRA fine-tuning. We tried fine-tuning the perceiver + cross_x_attn + LoRA LLM, but didn't get better results than fine-tuning the perceiver + cross_x_attn alone.
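For concreteness, "fine-tuning perceiver + cross_x_attn" just means freezing everything else; a minimal sketch of that selection is below (the name substrings "perceiver" and "gated_cross_attn" are my shorthand and may not match the actual Otter module names exactly):

# Freeze all parameters, then unfreeze only the perceiver resampler and gated cross-attention layers.
# The substrings below are assumptions about module naming, not copied from the Otter source.
def select_trainable_params(model):
    trainable = 0
    for name, param in model.named_parameters():
        param.requires_grad = ("perceiver" in name) or ("gated_cross_attn" in name)
        if param.requires_grad:
            trainable += param.numel()
    print(f"Trainable params: {trainable / 1e6:.1f}M")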

If you want to LoRA-finetune the LLM, you should first convert the Otter init model to a LoRA version using https://github.com/Luodian/Otter/blob/main/otter/converting_otter_to_lora.py

Then you can load it directly, without any other modification to the training procedure.

Luodian commented 1 year ago

If the LoRA LLM is loaded, you will see logs showing how many parameters are LoRA-wrapped, etc.
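To illustrate what those logs correspond to, here is a generic PEFT-style sketch (this is not the repo's converting_otter_to_lora.py script; the toy model and target_modules are placeholders, whereas the real conversion targets the LLM's attention projections):

import torch
from peft import LoraConfig, get_peft_model

# Toy stand-in for the language model inside Otter.
base_model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 32))

# Placeholder target_modules ("0" and "1" are the Sequential's submodule names).
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules=["0", "1"])
lora_model = get_peft_model(base_model, lora_config)

# Prints how many parameters are trainable under LoRA, similar to the logs mentioned above.
lora_model.print_trainable_parameters()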