huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Training CodeLLaMa-7b with FSDP causes loss 0 error #27121

Closed: TomasAndersonFang closed this issue 10 months ago

TomasAndersonFang commented 10 months ago

System Info

- `transformers` version: 4.34.0
- Platform: Linux-4.18.0-477.21.1.el8_8.x86_64-x86_64-with-glibc2.28
- Python version: 3.10.12
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.4.0
- Accelerate version: 0.23.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@ArthurZucker

Reproduction

My script:

# coding=utf-8
# Implements parameter-efficient or full parameters supervised fine-tuning for LLaMa model.
# This code is inspired by
# https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py and https://www.mlexpert.io/machine-learning/tutorials/alpaca-fine-tuning

import transformers
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    DataCollatorForSeq2Seq,
    Trainer,
    Seq2SeqTrainer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig,
)

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)

import torch
import os
import evaluate
import functools
from datasets import load_dataset
# import bitsandbytes as bnb
import logging
import json
import copy
from typing import Dict, Optional, Sequence
from dataclasses import dataclass, field

# Lora settings
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
]

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="elinas/llama-7b-hf-transformers-4.29")

@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Dataset path or builder name passed to datasets.load_dataset (e.g. 'json')."})
    train_file: str = field(default=None, metadata={"help": "Path to the training data."})
    eval_file: str = field(default=None, metadata={"help": "Path to the evaluation data."})
    cache_path: str = field(default=None, metadata={"help": "Path to the cache directory."})
    num_proc: int = field(default=4, metadata={"help": "Number of processes to use for data preprocessing."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    # adam_beta1: float = field(default=0.9)
    # adam_beta2: float = field(default=0.95)
    model_max_length: int = field(
        default=1024,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    is_lora: bool = field(default=True, metadata={"help": "Whether to use LORA."})

def tokenize(text, tokenizer, max_seq_len=1024, add_eos_token=True):
    result = tokenizer(
        text,
        truncation=False,
        max_length=max_seq_len,
        padding=False,
        return_tensors=None,
    )

    # If the tokenized length exceeds the max_seq_len, return None
    if len(result["input_ids"]) >= max_seq_len:
        return None

    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < max_seq_len
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    # if add_eos_token and len(result["input_ids"]) >= max_seq_len:
    #     result["input_ids"][max_seq_len - 1] = tokenizer.eos_token_id
    #     result["attention_mask"][max_seq_len - 1] = 1

    result["labels"] = result["input_ids"].copy()
    return result

def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.is_lora:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=data_args.cache_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            ),
        )
        model = prepare_model_for_kbit_training(model)

        config = LoraConfig(
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            target_modules=LORA_TARGET_MODULES,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, config)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=torch.float16,
            cache_dir=data_args.cache_path,
            trust_remote_code=True,
        )
    model.config.use_cache = False

    def print_trainable_parameters(model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
        )
    if training_args.is_lora:
        print_trainable_parameters(model)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=data_args.cache_path,
        model_max_length=training_args.model_max_length,
        padding_side="left",
        trust_remote_code=True,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.unk_token

    # Load dataset

    def generate_and_tokenize_prompt(sample):

        input_text = sample["input"]
        target_text = sample["output"] + tokenizer.eos_token
        full_text = input_text + target_text

        tokenized_full_text = tokenize(full_text, tokenizer, max_seq_len=training_args.model_max_length)

        if tokenized_full_text is None:
            # Return a null sample if the tokenized length exceeds the max_seq_len
            return {"input_ids": [], "attention_mask": [], "labels": []}

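        # Tokenize the prompt without appending EOS so that input_len equals the prompt prefix
        # of full_text; an appended EOS would also mask the first response token.
        tokenized_input_text = tokenize(
            input_text, tokenizer, max_seq_len=training_args.model_max_length, add_eos_token=False
        )
        input_len = len(tokenized_input_text["input_ids"])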
        tokenized_full_text["labels"] = [-100] * input_len + tokenized_full_text["labels"][input_len:]
        return tokenized_full_text

    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
    if data_args.eval_file is not None:
        data_files["eval"] = data_args.eval_file

    dataset = load_dataset(data_args.data_path, data_files=data_files)
    train_dataset = dataset["train"]
    eval_dataset = dataset["eval"]

    def print_dataset_length(dataset, name):
        print(f"Number of samples in {name} dataset after filtering: {len(dataset)}")

    train_dataset = train_dataset.map(generate_and_tokenize_prompt, num_proc=data_args.num_proc)
    eval_dataset = eval_dataset.map(generate_and_tokenize_prompt, num_proc=data_args.num_proc)
    # Filter null samples
    train_dataset = train_dataset.filter(lambda sample: len(sample["input_ids"]) > 0)
    eval_dataset = eval_dataset.filter(lambda sample: len(sample["input_ids"]) > 0)

    print_dataset_length(train_dataset, "train")
    print_dataset_length(eval_dataset, "eval")

    data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

    # Evaluation metrics
    def compute_metrics(eval_preds, tokenizer):
        metric = evaluate.load('exact_match')
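        preds, labels = eval_preds
        # In case the model returns more than the prediction logits
        if isinstance(preds, tuple):
            preds = preds[0]
        # Trainer passes logits of shape (batch, seq_len, vocab); reduce them to token ids
        if preds.ndim == 3:
            preds = preds.argmax(axis=-1)
        # Replace any -100 padding so the ids can be decoded
        preds[preds == -100] = tokenizer.pad_token_id

        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=False)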

        # Replace -100s in the labels as we can't decode them
        labels[labels == -100] = tokenizer.pad_token_id
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        # Some simple post-processing
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        return {'exact_match': result['exact_match']} 

    compute_metrics_fn = functools.partial(compute_metrics, tokenizer=tokenizer)

    # Training
    trainer = Trainer(
        model=model, 
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,  
        args=training_args,
        data_collator=data_collator,
        compute_metrics=compute_metrics_fn,
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
    tokenizer.save_pretrained(save_directory=training_args.output_dir)

if __name__ == "__main__":
    main()
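
The prompt-masking step in `generate_and_tokenize_prompt` sets the label of every prompt token to -100 so that only the response contributes to the loss. The minimal sketch below uses made-up token ids (hypothetical values, not real CodeLlama ids) and assumes LLaMA-style special tokens BOS=1 and EOS=2. It shows why the prompt length has to be measured without an appended EOS: otherwise the first response token is masked as well.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss (its default ignore_index)
BOS, EOS = 1, 2      # assumed LLaMA-style special-token ids; illustrative only


def mask_prompt(full_ids, prompt_len):
    """Copy full_ids as labels, replacing the first prompt_len positions with IGNORE_INDEX."""
    labels = list(full_ids)
    return [IGNORE_INDEX] * prompt_len + labels[prompt_len:]


# Hypothetical token ids: prompt -> [10, 11, 12], response -> [20, 21]
prompt_ids = [BOS, 10, 11, 12]
full_ids = [BOS, 10, 11, 12, 20, 21, EOS]

# Prompt length measured without an appended EOS: only the prompt is masked
good = mask_prompt(full_ids, len(prompt_ids))
# Prompt tokenized with an extra EOS appended: one response token is masked too
bad = mask_prompt(full_ids, len(prompt_ids + [EOS]))

print(good)  # [-100, -100, -100, -100, 20, 21, 2]
print(bad)   # [-100, -100, -100, -100, -100, 21, 2]
```

This approach also assumes the prompt tokenizes identically on its own and as a prefix of the full text, which SentencePiece-based tokenizers do not always guarantee at the prompt/response boundary.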

Commands used to launch the script:

accelerate launch --config_file "/proj/berzelius-2023-175/users/x_senfa/apr_ft/configs/fsdp_config.yaml" /proj/berzelius-2023-175/users/x_senfa/apr_ft/llama2_sft.py \
    --model_name_or_path  \
    --data_path  \
    --output_dir  \
    --train_file train_data.jsonl \
    --eval_file test_data.jsonl \
    --is_lora False \
    --model_max_length 1024 \
    --cache_path  \
    --do_train \
    --do_eval False \
    --fp16 True \
    --bf16 False \
    --num_train_epochs 2 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --eval_steps 10 \
    --save_steps 1200 \
    --learning_rate 5e-4 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10

Accelerate config

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Log

  0%|          | 0/11084 [00:00<?, ?it/s]
You're using a CodeLlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CodeLlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CodeLlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CodeLlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

  0%|          | 1/11084 [00:01<5:18:57,  1.73s/it]
  0%|          | 2/11084 [00:03<4:29:55,  1.46s/it]
  0%|          | 3/11084 [00:04<5:01:43,  1.63s/it]
  0%|          | 4/11084 [00:07<5:40:13,  1.84s/it]
  0%|          | 5/11084 [00:09<5:55:33,  1.93s/it]
  0%|          | 6/11084 [00:11<6:24:14,  2.08s/it]
  0%|          | 7/11084 [00:14<6:56:29,  2.26s/it]
  0%|          | 8/11084 [00:17<7:43:48,  2.51s/it]
  0%|          | 9/11084 [00:20<8:13:53,  2.68s/it]
  0%|          | 10/11084 [00:23<8:49:02,  2.87s/it]

{'loss': 3.8593, 'learning_rate': 0.0004999999899580808, 'epoch': 0.0}

  0%|          | 11/11084 [00:26<8:42:57,  2.83s/it]
  0%|          | 12/11084 [00:27<7:31:00,  2.44s/it]
  0%|          | 13/11084 [00:30<7:26:25,  2.42s/it]
  0%|          | 14/11084 [00:32<7:08:13,  2.32s/it]
  0%|          | 15/11084 [00:34<6:53:57,  2.24s/it]
  0%|          | 16/11084 [00:36<6:44:50,  2.19s/it]
  0%|          | 17/11084 [00:38<6:21:09,  2.07s/it]
  0%|          | 18/11084 [00:40<6:25:54,  2.09s/it]
  0%|          | 19/11084 [00:43<7:26:38,  2.42s/it]
  0%|          | 20/11084 [00:45<6:57:20,  2.26s/it]

{'loss': 0.0, 'learning_rate': 0.0004999999899580808, 'epoch': 0.0}

  0%|          | 21/11084 [00:48<7:25:37,  2.42s/it]
  0%|          | 22/11084 [00:50<7:26:41,  2.42s/it]
  0%|          | 23/11084 [00:53<7:55:51,  2.58s/it]
  0%|          | 24/11084 [00:56<8:18:13,  2.70s/it]
  0%|          | 25/11084 [00:58<8:03:54,  2.63s/it]
  0%|          | 26/11084 [01:01<7:34:18,  2.47s/it]
  0%|          | 27/11084 [01:03<7:44:08,  2.52s/it]
  0%|          | 28/11084 [01:06<7:58:40,  2.60s/it]
  0%|          | 29/11084 [01:09<8:11:45,  2.67s/it]
  0%|          | 30/11084 [01:11<7:26:06,  2.42s/it]

{'loss': 0.0, 'learning_rate': 0.0004999999899580808, 'epoch': 0.01}

  0%|          | 31/11084 [01:13<7:27:43,  2.43s/it]
  0%|          | 32/11084 [01:16<7:56:56,  2.59s/it]
  0%|          | 33/11084 [01:19<7:50:14,  2.55s/it]
  0%|          | 34/11084 [01:21<7:46:59,  2.54s/it]
  0%|          | 35/11084 [01:26<9:37:39,  3.14s/it]
  0%|          | 36/11084 [01:27<8:27:02,  2.75s/it]
  0%|          | 37/11084 [01:31<8:48:13,  2.87s/it]
  0%|          | 38/11084 [01:33<8:05:37,  2.64s/it]
  0%|          | 39/11084 [01:35<7:53:22,  2.57s/it]
  0%|          | 40/11084 [01:39<8:54:45,  2.91s/it]

{'loss': 0.0, 'learning_rate': 0.0004999999899580808, 'epoch': 0.01}

  0%|          | 41/11084 [01:41<8:38:49,  2.82s/it]
  0%|          | 42/11084 [01:43<7:15:42,  2.37s/it]
  0%|          | 43/11084 [01:46<7:53:19,  2.57s/it]
  0%|          | 44/11084 [01:49<8:11:36,  2.67s/it]
  0%|          | 45/11084 [01:52<8:23:12,  2.74s/it]
  0%|          | 46/11084 [01:54<8:18:13,  2.71s/it]
  0%|          | 47/11084 [01:56<7:49:05,  2.55s/it]
  0%|          | 48/11084 [01:58<7:11:24,  2.35s/it]
  0%|          | 49/11084 [02:01<7:18:50,  2.39s/it]
  0%|          | 50/11084 [02:03<7:17:48,  2.38s/it]

{'loss': 0.0, 'learning_rate': 0.0004999999899580808, 'epoch': 0.01}

  0%|          | 51/11084 [02:06<7:48:02,  2.55s/it]
  0%|          | 52/11084 [02:08<7:40:46,  2.51s/it]
  0%|          | 53/11084 [02:11<7:34:10,  2.47s/it]
  0%|          | 54/11084 [02:13<7:42:31,  2.52s/it]
  0%|          | 55/11084 [02:16<8:00:33,  2.61s/it]
  1%|          | 56/11084 [02:20<8:42:51,  2.84s/it]
  1%|          | 57/11084 [02:23<9:09:47,  2.99s/it]
  1%|          | 58/11084 [02:26<9:00:00,  2.94s/it]
  1%|          | 59/11084 [02:28<8:21:44,  2.73s/it]
  1%|          | 60/11084 [02:30<8:03:29,  2.63s/it]

{'loss': 0.0, 'learning_rate': 0.0004999999899580808, 'epoch': 0.01}

  1%|          | 61/11084 [02:33<8:15:02,  2.69s/it]
  1%|          | 62/11084 [02:36<8:11:34,  2.68s/it]
  1%|          | 63/11084 [02:38<8:04:32,  2.64s/it]
  1%|          | 64/11084 [02:40<6:46:28,  2.21s/it]
  1%|          | 65/11084 [02:43<7:35:48,  2.48s/it]
  1%|          | 66/11084 [02:45<7:01:26,  2.30s/it]
  1%|          | 67/11084 [02:47<7:11:32,  2.35s/it]
  1%|          | 68/11084 [02:50<7:23:52,  2.42s/it]
  1%|          | 69/11084 [02:54<8:48:35,  2.88s/it]
  1%|          | 70/11084 [02:57<9:28:29,  3.10s/it]

{'loss': 0.0, 'learning_rate': 0.0004999999899580808, 'epoch': 0.01}

Expected behavior

I don't know why the loss drops to 0 so quickly; I suspect something is wrong.
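
One clue in the log: `learning_rate` stays at 0.0004999999899580808 from step 10 through step 70, even though the run uses a cosine schedule. The sketch below reimplements a cosine decay by hand in plain Python rather than calling transformers. It assumes the Trainer's cosine schedule with zero warmup (the command sets no warmup) over the 11084 optimizer steps shown by the progress bar. Under those assumptions, the logged value matches the learning rate after exactly one scheduler step. That is consistent with, though it does not prove, most optimizer steps being skipped, as happens when the fp16 gradient scaler finds inf/NaN gradients and the Trainer then does not advance the scheduler.

```python
import math

BASE_LR = 5e-4                     # --learning_rate from the launch command
TOTAL_STEPS = 11084                # total optimizer steps shown by the progress bar
LOGGED_LR = 0.0004999999899580808  # value printed at steps 10 through 70


def cosine_lr(step, base_lr=BASE_LR, total=TOTAL_STEPS, warmup=0):
    """Linear warmup followed by a single half-cosine decay to zero (assumed schedule)."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (1, 2, 10, 70):
    matches = math.isclose(cosine_lr(step), LOGGED_LR, rel_tol=0.0, abs_tol=1e-15)
    print(step, cosine_lr(step), matches)  # only step 1 prints True
```

If the schedule had advanced normally, the value logged at step 70 would already be visibly lower than the one logged at step 10.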

amyeroberts commented 10 months ago

cc @pacman100 @muellerz

TomasAndersonFang commented 10 months ago

@amyeroberts Sorry, I actually solved this problem. It was caused by fp16 combined with a large learning rate. When fine-tuning LLaMA with LoRA, both are fine, but for full-parameter fine-tuning I had to use bf16 and a smaller learning rate (I use 5e-6; 5e-5 also works but is sometimes unstable).
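
A common explanation for why switching from fp16 to bf16 helps (not confirmed in this thread) is dynamic range. fp16 has 5 exponent bits, so values above 65504 overflow to inf, and its smallest positive value is about 6e-8, so tinier magnitudes round to zero. bf16 keeps float32's 8-bit exponent and gives up mantissa precision instead. The sketch below uses only the standard library: `struct`'s `'e'` format packs IEEE half precision. bf16 is approximated by truncating a float32 to its top 16 bits, whereas real hardware rounds to nearest. The helper names are hypothetical.

```python
import math
import struct


def to_fp16(x):
    """Round-trip x through IEEE half precision; out-of-range values become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the fp16 range; a float32 -> float16
        # cast on a GPU would overflow to inf instead, so mimic that here
        return math.copysign(math.inf, x)


def to_bf16_truncated(x):
    """Approximate bfloat16 by keeping the top 16 bits (sign, exponent, 7 mantissa bits) of float32."""
    high = struct.pack(">f", x)[:2]  # big-endian: the first two bytes are the high bits
    return struct.unpack(">f", high + b"\x00\x00")[0]


for value in (1e-8, 1e-3, 70000.0):
    print(f"{value:>10g}  fp16={to_fp16(value):<12g} bf16~={to_bf16_truncated(value):g}")
```

In this sketch the tiny value vanishes and the large one overflows in fp16, while both stay finite and nonzero in the bf16 approximation. The cost is coarser precision, for example 70000 becomes 69632.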

amyeroberts commented 10 months ago

@TomasAndersonFang thanks for replying and detailing what the issue was!