jondurbin / bagel

A bagel, with everything.

The name of this project was shamelessly stolen from Everything Everywhere All at Once.


Data selection

The first step in the process is creating a dataset. In this case, we're actually creating a composite dataset, consisting of both supervised fine-tuning (SFT) data and direct preference optimization (DPO) data.

All instruction data (that is, everything that is neither plain text, like Project Gutenberg and items from Cinematika, nor DPO data) is converted into ShareGPT format so it's easier to work with.
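For readers unfamiliar with ShareGPT format, here is a minimal sketch of what one converted record might look like. It assumes the common ShareGPT layout (a "conversations" list of turns with "from" set to "system", "human" or "gpt" and a "value" field); the helper name to_sharegpt is hypothetical, and the exact field names bagel uses live in bagel/data_sources/*.py and may differ.

```python
def to_sharegpt(instruction, response, system=None):
    """Hypothetical helper: wrap a single instruction/response pair as a ShareGPT-style record."""
    turns = []
    if system:
        # Optional system prompt goes first.
        turns.append({"from": "system", "value": system})
    turns.append({"from": "human", "value": instruction})
    turns.append({"from": "gpt", "value": response})
    return {"conversations": turns}


print(to_sharegpt("What is 2 + 2?", "4"))
```

Multi-turn data simply appends more human/gpt turns to the same list, which is what makes a single format convenient for every downstream prompt template.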

See the corresponding code in bagel/data_sources/*.py for the full implementation of each data source.

Deduplication is done by creating a uuid v5 of the instruction/text, then only adding items not previously seen, where datasets are loaded in descending order of the confidence score I assign them. This means that if an instruction is in data source "Foo" with confidence score 4 as well as in data source "Bar" with confidence score 2, only the entry from "Foo" will be taken.
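The confidence-ordered deduplication can be sketched as follows. This is an illustration, not the project's code: the function name dedupe and the (confidence, name, items) tuple layout are hypothetical, and uuid.NAMESPACE_OID is an assumed namespace, since the README doesn't say which one is used.

```python
import uuid


def dedupe(sources):
    """Hypothetical sketch: sources is a list of (confidence, name, texts) tuples.

    Sources are visited highest-confidence first, so the first copy of a
    duplicated instruction that we keep comes from the most trusted source.
    """
    seen = set()
    kept = []
    # sorted() is stable, so sources with equal confidence keep their given order.
    for confidence, name, texts in sorted(sources, key=lambda s: s[0], reverse=True):
        for text in texts:
            key = uuid.uuid5(uuid.NAMESPACE_OID, text)  # deterministic id for identical text
            if key in seen:
                continue
            seen.add(key)
            kept.append((name, text))
    return kept


print(dedupe([(2, "Bar", ["a", "b"]), (4, "Foo", ["a", "c"])]))
```

Because uuid v5 is a hash of the exact string, this only catches verbatim duplicates; near-duplicates with different whitespace or wording survive this step.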

SFT data sources

Only train splits are used, and decontamination by cosine similarity is performed at the end as a sanity check against common benchmarks. If you don't know the difference between train and test, please learn.
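A brute-force sketch of cosine-similarity decontamination, assuming each training item and each benchmark item has already been embedded as a vector (which embedding model bagel uses isn't stated here). The helper names and the 0.95 threshold are hypothetical.

```python
import math


def cosine(a, b):
    # Cosine similarity between two equal-length vectors; 0.0 if either is all zeros.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def decontaminate(items, item_vecs, benchmark_vecs, threshold=0.95):
    """Keep only items whose embedding is below `threshold` similarity to every benchmark item."""
    kept = []
    for item, vec in zip(items, item_vecs):
        if all(cosine(vec, b) < threshold for b in benchmark_vecs):
            kept.append(item)
    return kept


# "x" is nearly identical to the benchmark vector and is dropped; "y" is kept.
print(decontaminate(["x", "y"], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.01]]))
```

This compares every item against every benchmark item, which is fine for a sanity check but scales poorly; that is the motivation for using approximate nearest neighbor search (faiss) for the larger pass described below.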

DPO data sources

Only the train splits are used (if a split was provided), and an additional decontamination pass is performed using approximate nearest neighbor search (via faiss).

Prompt formatting

Sticking with the theme of the bagel, I didn't want to use a single prompt format, so I used four: vicuna, llama-2, alpaca, and chat-ml (sorta). I also didn't want to randomly select a single prompt format for each item; instead, hoping each instruction would generalize better when seen in a variety of prompt formats, I converted each instruction into every prompt format.

This means each epoch of the fine-tune is effectively 4 epochs over the underlying instructions. So, for the fine-tunes, I would recommend doing only 1 epoch (or 0.75 epochs). I am testing with a single epoch using a relatively low learning rate.
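The epoch arithmetic, as a small sketch: since every instruction is written out once per format, the number of passes over each instruction is the epoch count times four. The helper names and the 100,000-instruction figure are hypothetical.

```python
NUM_PROMPT_FORMATS = 4  # alpaca, vicuna, chat-ml, llama-2


def expanded_rows(num_instructions):
    # Each instruction appears once per prompt format in the training data.
    return num_instructions * NUM_PROMPT_FORMATS


def passes_per_instruction(epochs):
    # How many times training sees each underlying instruction.
    return epochs * NUM_PROMPT_FORMATS


def epochs_for_passes(target_passes):
    # Inverse: the epoch count that yields a desired number of passes.
    return target_passes / NUM_PROMPT_FORMATS


print(expanded_rows(100_000))     # 400000
print(passes_per_instruction(1))  # 4
print(epochs_for_passes(3))       # 0.75
```

So the recommended 0.75 epochs still shows each instruction about three times, once in three of the four formats on average.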

Alpaca (sort of)

Below is an instruction that describes a task.  Write a response that appropriately completes the request.

### Instruction:
{system prompt, if provided}
{instruction}

### Response:

The main difference here is that, because of the dataset formatting and the variety of data sources, it would have been much too tedious to add an ### Input: block, so the inputs are just in the instruction section.

Vicuna

{system prompt, if provided, randomly defaulting to "A chat between a user and an unbiased, uncensored assistant."}
USER: {instruction}
ASSISTANT: 

ChatML

This format is digital cancer, but it's common, so I included it.

{bos}<|im_start|>{role}
{text}<|im_end|>

Llama-2 chat

[INST] <<SYS>>
{system}
<</SYS>>

{instruction} [/INST]
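A sketch of rendering one instruction/response pair into all four templates above. The function names are hypothetical, and several details are assumptions not stated in the templates: "<s>" as the {bos} token, applying the ChatML template once per message joined by newlines, a 50% rate for the random Vicuna default, dropping the <<SYS>> block when there is no system prompt, and a single space before the Llama-2 response.

```python
import random

VICUNA_DEFAULT_SYSTEM = "A chat between a user and an unbiased, uncensored assistant."
# Copied verbatim (including the double space) from the Alpaca template above.
ALPACA_PREAMBLE = (
    "Below is an instruction that describes a task.  "
    "Write a response that appropriately completes the request."
)
BOS = "<s>"  # Assumption: Llama/Mistral-style BOS token for the {bos} placeholder.


def alpaca(instruction, response, system=None):
    # System prompt (if any) goes in the instruction section; there is no ### Input: block.
    body = f"{system}\n{instruction}" if system else instruction
    return f"{ALPACA_PREAMBLE}\n\n### Instruction:\n{body}\n\n### Response:\n{response}"


def vicuna(instruction, response, system=None, rng=random):
    # With no system prompt, randomly fall back to the default (50% is an assumed rate).
    if system is None and rng.random() < 0.5:
        system = VICUNA_DEFAULT_SYSTEM
    prefix = f"{system}\n" if system else ""
    return f"{prefix}USER: {instruction}\nASSISTANT: {response}"


def chatml(instruction, response, system=None):
    # Assumption: the {bos}<|im_start|>{role} ... <|im_end|> template is applied per message.
    turns = [("system", system)] if system else []
    turns += [("user", instruction), ("assistant", response)]
    return "\n".join(f"{BOS}<|im_start|>{role}\n{text}<|im_end|>" for role, text in turns)


def llama2(instruction, response, system=None):
    # The <<SYS>> block is omitted when there is no system prompt (an assumption).
    sys_block = f"<<SYS>>\n{system}\n<</SYS>>\n\n" if system else ""
    return f"[INST] {sys_block}{instruction} [/INST] {response}"


def all_formats(instruction, response, system=None, rng=random):
    # Every instruction is rendered in all four formats rather than one sampled format.
    return [
        alpaca(instruction, response, system),
        vicuna(instruction, response, system, rng),
        chatml(instruction, response, system),
        llama2(instruction, response, system),
    ]


for text in all_formats("What is 2 + 2?", "4", system="Be concise."):
    print(text, end="\n\n")
```

Passing an explicit rng (for example random.Random(42)) makes the Vicuna default-system choice reproducible across dataset builds.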

Fine-tuning

First, you need to prepare the dataset as input-output pairs for the SFT phase, and as prompt/chosen/rejected triples for the DPO phase:

python -m bagel.data

This produces an SFT parquet and a DPO parquet, which you can use to build a model.

SFT phase

An example for mistral-7b:

Note: I actually used my fork of qlora's train.py for this, but I'm porting it to a minified version here, which has not been tested yet!

export BASE_DIR=/workspace
export WANDB_API_KEY=[redacted]
export WANDB_PROJECT=bagel-7b-v0.1

# Run the SFT phase.
accelerate launch bagel/tune/sft.py \
  --model_name_or_path $BASE_DIR/mistral-7b \
  --final_output_dir $BASE_DIR/$WANDB_PROJECT \
  --output_dir $BASE_DIR/$WANDB_PROJECT-workdir \
  --num_train_epochs 1 \
  --logging_steps 1 \
  --save_strategy steps \
  --save_steps 200 \
  --save_total_limit 5 \
  --data_seed 42 \
  --evaluation_strategy steps \
  --eval_dataset_size 0.0006 \
  --eval_steps 200 \
  --max_new_tokens 4096 \
  --dataloader_num_workers 3 \
  --logging_strategy steps \
  --remove_unused_columns False \
  --do_train \
  --full_finetune \
  --bf16 \
  --bits 16 \
  --optim adamw_torch \
  --lr_scheduler_type linear \
  --dataset $BASE_DIR/bagel/bagel-input-output-v0.1.parquet \
  --dataset_format input-output \
  --model_max_len 4096 \
  --per_device_train_batch_size 8 \
  --learning_rate 3.5e-7 \
  --warmup_ratio 0.005 \
  --adam_beta2 0.999 \
  --max_grad_norm 0.3 \
  --weight_decay 0.001 \
  --seed 42 \
  --report_to wandb \
  --gradient_checkpointing True \
  --gradient_accumulation_steps 4 \
  --skip_excess_length False \
  --ddp_find_unused_parameters False \
  --use_flash_attention_2 \
  --group_by_length True \
  --deepspeed deepspeed.json

DeepSpeed configuration (deepspeed.json):

{
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "allgather_bucket_size": 5e8
  }
}
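The "auto" fields are placeholders. Assuming bagel/tune/sft.py hands --deepspeed to the Hugging Face Trainer (as qlora's train.py does), those fields are filled in from the training arguments. The sketch below mimics that derivation to show what the SFT command resolves to; resolve_auto is a hypothetical helper, and world_size=8 is an assumed GPU count.

```python
import copy

# The deepspeed.json above, as a Python dict.
DEEPSPEED_CONFIG = {
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "allgather_bucket_size": 5e8,
    },
}


def resolve_auto(config, per_device_train_batch_size, gradient_accumulation_steps,
                 max_grad_norm, world_size):
    """Hypothetical helper: replace "auto" values with ones derived from training args."""
    resolved = copy.deepcopy(config)  # leave the original config untouched
    derived = {
        "train_micro_batch_size_per_gpu": per_device_train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        # Global batch = micro batch x accumulation steps x number of GPUs.
        "train_batch_size": per_device_train_batch_size * gradient_accumulation_steps * world_size,
        "gradient_clipping": max_grad_norm,
    }
    for key, value in derived.items():
        if resolved.get(key) == "auto":
            resolved[key] = value
    return resolved


# Flags from the SFT command: --per_device_train_batch_size 8,
# --gradient_accumulation_steps 4, --max_grad_norm 0.3.
sft = resolve_auto(DEEPSPEED_CONFIG, 8, 4, 0.3, world_size=8)
print(sft["train_batch_size"], sft["gradient_clipping"])  # 256 0.3
```

Leaving these as "auto" keeps a single source of truth: the same deepspeed.json works for the DPO command, whose batch flags differ.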

DPO phase

An example of the DPO phase for mistral-7b (requires running the SFT phase first):

export BASE_DIR=/mnt/data
export WANDB_API_KEY=[redacted]
export WANDB_PROJECT=bagel-dpo-7b-v0.1

accelerate launch bagel/tune/dpo.py \
  --model_name_or_path bagel-7b-v0.1 \
  --learning_rate 3e-7 \
  --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 4 \
  --max_length 4096 \
  --max_prompt_length 1024 \
  --max_target_length 3092 \
  --num_train_epochs 3 \
  --report_to wandb \
  --gradient_checkpointing true \
  --use_flash_attention_2 true \
  --dataset $BASE_DIR/bagel/bagel-dpo-v0.1.parquet \
  --eval_steps 5 \
  --eval_dataset_size 0.03 \
  --workdir $BASE_DIR/$WANDB_PROJECT-workdir \
  --output_dir $BASE_DIR/$WANDB_PROJECT \
  --deepspeed deepspeed.json \
  --save_steps 25 \
  --save_total_limit 5
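A quick arithmetic check of the DPO length flags, as a sketch with hypothetical helper names. With --max_length 4096 and --max_prompt_length 1024, a prompt that uses its full budget leaves 3072 tokens for the response. The --max_target_length of 3092 is slightly larger, so the two limits can only both be reached when the prompt is shorter than its budget. How bagel/tune/dpo.py truncates when the sum exceeds --max_length isn't shown here, so treat this as a sanity check rather than a description of the script.

```python
def response_budget(max_length, max_prompt_length):
    # Tokens left for the response when the prompt uses its full budget.
    return max_length - max_prompt_length


def fits(max_length, max_prompt_length, max_target_length):
    # True if a full-length prompt plus a full-length response fits in max_length.
    return max_prompt_length + max_target_length <= max_length


# Values from the DPO command above.
print(response_budget(4096, 1024))  # 3072
print(fits(4096, 1024, 3092))       # False (1024 + 3092 = 4116)
```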