huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

[FSDP] Training with H100 is much slower than A100 #2675

Closed dmammfl closed 1 month ago

dmammfl commented 2 months ago

System Info

- `Accelerate` version: 0.30.0.dev0
- Platform: Linux-5.15.0-87-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/work/.local/anaconda3/envs/multinode-test/bin/accelerate
- Python version: 3.10.14
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 1007.36 GB
- GPU type: CUDA GPU
- `Accelerate` default config:
        Not found

Reproduction

I'm fine-tuning Llama-2 models with FSDP on two separate servers: one with 4×H100 (80 GB) and one with 4×A100 (80 GB).

Full fine-tuning of the 7B model and PEFT on the 70B model work as expected (H100 is faster than A100), but for full fine-tuning of the 70B model the H100 server is 1.5 to 2x slower than the A100 server.

Here are the total training times on each server:

| Task              | H100    | A100    |
|-------------------|---------|---------|
| Llama 7B Full FT  | 11m 40s | 21m 10s |
| Llama 70B PEFT    | 48m 15s | 1h 45m  |
| Llama 70B Full FT | 23h 17m | 15h 12m |
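As a quick check of the ratios quoted above, the sketch below converts the reported durations to seconds and divides the H100 time by the A100 time. It assumes the first time column is the H100 run and the second the A100 run, which is the only reading consistent with the description; `parse_duration` and `h100_over_a100` are hypothetical helpers.

```python
import re

# Reported wall-clock times, assuming (H100, A100) column order
TIMES = {
    "Llama 7B Full FT": ("11m 40s", "21m 10s"),
    "Llama 70B PEFT": ("48m 15s", "1h 45m"),
    "Llama 70B Full FT": ("23h 17m", "15h 12m"),
}

_UNIT_SECONDS = {"h": 3600, "m": 60, "s": 1}


def parse_duration(text: str) -> int:
    """Convert strings like '1h 45m' or '11m 40s' to seconds."""
    total = 0
    for value, unit in re.findall(r"(\d+)\s*([hms])", text):
        total += int(value) * _UNIT_SECONDS[unit]
    return total


def h100_over_a100(h100: str, a100: str) -> float:
    """Ratio > 1 means the H100 run took longer than the A100 run."""
    return parse_duration(h100) / parse_duration(a100)


if __name__ == "__main__":
    for task, (h100, a100) in TIMES.items():
        print(f"{task}: H100/A100 time ratio = {h100_over_a100(h100, a100):.2f}")
```

Under that reading the ratios come out to about 0.55 and 0.46 for the first two rows and about 1.53 for the 70B full fine-tune.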

The two servers differ in many ways, but it is strange that only 70B full fine-tuning shows the problem. Do you have any idea what could cause this?

My training code is below:

import os, torch, jsonlines

from dataclasses import dataclass, field
from datasets import Dataset, concatenate_datasets
from typing import Optional
from peft import LoraConfig, get_peft_model

from transformers import set_seed, HfArgumentParser, TrainingArguments, AutoTokenizer, AutoModelForCausalLM, \
    BitsAndBytesConfig, EarlyStoppingCallback, Trainer, DataCollatorForSeq2Seq

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    lora_alpha: Optional[int] = field(default=32)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=16)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={
            "help": "comma separated list of target modules to apply LoRA layers to"
        },
    )
    use_nested_quant: Optional[bool] = field(
        default=True,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="bfloat16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient Checkpointing param. Refer the related docs"},
    )
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "Path to a local JSONL file with 'source' and 'target' fields."},
    )
    max_input_seq_length: Optional[int] = field(default=-1)
    max_output_length: Optional[int] = field(default=-1)
    train_test_split_ratio: Optional[float] = field(default=0.1)
    early_stopping_patience: Optional[int] = field(default=3)
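The LoRA fields above feed the `LoraConfig` built later in `main`: `lora_target_modules` is split on commas unless it is the special value `"all-linear"`, and in standard LoRA the low-rank update is scaled by `lora_alpha / r`. A stdlib-only sketch of both, with hypothetical helper names, assuming PEFT's default (non-rsLoRA) scaling:

```python
from typing import List, Union


def resolve_target_modules(spec: str) -> Union[str, List[str]]:
    """Mirror of the script's logic: keep 'all-linear' as-is, otherwise split on commas."""
    if spec == "all-linear":
        return spec
    return spec.split(",")


def lora_scaling(lora_alpha: int, r: int) -> float:
    """Standard LoRA scaling factor for the adapter update (assumes rsLoRA is off)."""
    return lora_alpha / r


if __name__ == "__main__":
    default = "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj"
    print(resolve_target_modules(default))  # list of 7 module names
    print(lora_scaling(32, 16))             # 2.0 with the dataclass defaults
```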

def prepare_dataset_causal_lm(train_path, tokenizer, inputTokenLength, outputTokenLength, splitRatio):
    source_text, target_text = [], []
    chunk_source_text, chunk_target_text = [], []
    chunk_size = 1000
    is_first = True  # the first JSONL record is skipped and never added to the dataset
    with jsonlines.open(train_path) as f:
        for line in f.iter():
            if is_first:
                is_first = False
            else:
                source_text.append(line['source'])
                target_text.append(line['target'])
                if len(source_text) % chunk_size == 0:
                    chunk_source_text.append(source_text)
                    chunk_target_text.append(target_text)
                    source_text = []
                    target_text = []

    chunk_source_text.append(source_text)
    chunk_target_text.append(target_text)

    source_length, target_length = 0, 0
    source_token_cnt, target_token_cnt = 0, 0

    # Scan every chunk to find the longest source/target sequences and the total token counts
    for source_text, target_text in zip(chunk_source_text, chunk_target_text):
        source_longest = tokenizer.batch_encode_plus(source_text, padding=False, truncation=False)
        target_longest = tokenizer.batch_encode_plus(target_text, padding=False, truncation=False)
        for src_ids, tgt_ids in zip(source_longest['input_ids'], target_longest['input_ids']):
            source_token_cnt += len(src_ids)
            source_length = max(source_length, len(src_ids))
            target_token_cnt += len(tgt_ids)
            target_length = max(target_length, len(tgt_ids))

    if inputTokenLength == -1:
        source_length += 2
    else:
        source_length = inputTokenLength

    if outputTokenLength == -1:
        target_length += 2
    else:
        target_length = outputTokenLength

    print('info > Source Token Cnt:', source_token_cnt, ' Source Max Length:', source_length)
    print('info > Target Token Cnt:', target_token_cnt, ' Target Max Length:', target_length)

    # logger.debug("End Check longest token len")

    def _tokenize(batch):
        source_encoding = tokenizer.batch_encode_plus(batch['source'])['input_ids']
        # ' 요약해줘' is Korean for "summarize it"; it is inserted between each source and its target
        example = [source + ' 요약해줘\n' + target + tokenizer.eos_token for source, target in
                   zip(batch['source'], batch['target'])]
        example_encoding = tokenizer.batch_encode_plus(example, padding='longest', truncation=True,
                                                       max_length=source_length + target_length)

        input_encoding, attention_encoding = example_encoding['input_ids'], example_encoding['attention_mask']
        # Mask the source-prefix tokens with -100 so the loss only covers the tokens that follow them
        label_encoding = [[-100] * len(src_ids) + ids[len(src_ids):] for src_ids, ids in
                          zip(source_encoding, input_encoding)]

        return {'input_ids': input_encoding, 'attention_mask': attention_encoding, 'labels': label_encoding}

    # logger.debug("Start Tokenizing map")
    for i in range(len(chunk_source_text)):
        source_text = chunk_source_text[i]
        target_text = chunk_target_text[i]

        train_all = Dataset.from_dict({'source': source_text, 'target': target_text})
        train_all = train_all.train_test_split(test_size=splitRatio)

        train_all_encoded = train_all.map(_tokenize, batched=True, batch_size=None)
        train_all_encoded.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

        if i == 0:
            print('info > DataSample')
            print(train_all['train'][0:3])
            concat_train_all = train_all
            concat_train_all_encoded = train_all_encoded

        else:
            concat_train_all['train'] = concatenate_datasets([concat_train_all['train'], train_all['train']])
            concat_train_all['test'] = concatenate_datasets([concat_train_all['test'], train_all['test']])
            concat_train_all_encoded['train'] = concatenate_datasets(
                [concat_train_all_encoded['train'], train_all_encoded['train']])
            concat_train_all_encoded['test'] = concatenate_datasets(
                [concat_train_all_encoded['test'], train_all_encoded['test']])

    return {'source_token_cnt': source_token_cnt, 'source_length': source_length, 'target_token_cnt': target_token_cnt,
            'target_length': source_length + target_length, 'train_all': concat_train_all,
            'train_all_encoded': concat_train_all_encoded}
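The `_tokenize` helper above builds causal-LM labels by copying the input ids and replacing the first `len(source_ids)` positions with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. The toy sketch below reproduces that masking on plain integer lists (the token ids are made up). Two parts are assumptions not present in the original script: the `min(...)` guard for inputs truncated shorter than the source, and the optional `attention_mask` argument showing how right-padding positions could also be excluded from the loss.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(source_ids: List[int], input_ids: List[int],
                 attention_mask: Optional[List[int]] = None) -> List[int]:
    """Mask the source prefix with IGNORE_INDEX; optionally mask padding too (not in the original)."""
    prefix = min(len(source_ids), len(input_ids))  # guard against truncated inputs
    labels = [IGNORE_INDEX] * prefix + input_ids[prefix:]
    if attention_mask is not None:
        labels = [tok if keep else IGNORE_INDEX for tok, keep in zip(labels, attention_mask)]
    return labels


if __name__ == "__main__":
    source = [1, 501, 502]                        # BOS + two source tokens (made-up ids)
    example = [1, 501, 502, 900, 701, 702, 2, 2]  # source + prompt + target + EOS + one pad (pad == EOS)
    mask = [1, 1, 1, 1, 1, 1, 1, 0]
    print(build_labels(source, example))
    print(build_labels(source, example, mask))
```

Because the script sets `pad_token = eos_token`, the first call keeps the padding position as a trainable EOS label, while the second call masks it.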

def main(model_args, training_args):
    #Step 1. Set Seed
    set_seed(training_args.seed)

    #Step 2. Set Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    #Step 3. Prepare Dataset
    prepare_data = prepare_dataset_causal_lm(model_args.dataset_name, tokenizer,
                                             model_args.max_input_seq_length, model_args.max_output_length, model_args.train_test_split_ratio)

    print('Dataset Prepared')
    print(' - Source Max Length:', prepare_data['source_length'], 'Target Max Length:', prepare_data['target_length'])
    print(' - Train Data Num:', len(prepare_data['train_all_encoded']['train']), 'Eval Data Num:',
          len(prepare_data['train_all_encoded']['test']))

    #Step 4. Prepare Model
    device_map, bnb_config = None, None
    if model_args.use_4bit_quantization or model_args.use_8bit_quantization:
        compute_dtype = getattr(torch, model_args.bnb_4bit_compute_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=model_args.use_4bit_quantization,
            load_in_8bit=model_args.use_8bit_quantization,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=model_args.use_nested_quant,
        )
        device_map = (
            int(os.environ.get("LOCAL_RANK", -1))
            if torch.distributed.is_available() and torch.distributed.is_initialized()
            else "auto"
        )  # {"": 0}

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        quantization_config=bnb_config,
        device_map=device_map,
        torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2" if model_args.use_flash_attn else "eager",
    )

    # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

    # Step 4b. Optionally wrap the model with LoRA adapters (PEFT)
    if model_args.use_peft_lora:
        peft_config = LoraConfig(
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            r=model_args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=model_args.lora_target_modules.split(",")
            if model_args.lora_target_modules != "all-linear"
            else model_args.lora_target_modules,
        )
        model = get_peft_model(model, peft_config)

    # gradient ckpt
    model.config.use_cache = not training_args.gradient_checkpointing

    #Step 5. Training
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {
            "use_reentrant": model_args.use_reentrant
        }

    seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=seq2seq_data_collator,
        train_dataset=prepare_data['train_all_encoded']['train'],
        eval_dataset=prepare_data['train_all_encoded']['test'],
        tokenizer=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)]
    )

    trainer.accelerator.print(f"{trainer.model}")
    if model_args.use_peft_lora:
        # handle PEFT+FSDP case
        trainer.model.print_trainable_parameters()
        if getattr(trainer.accelerator.state, "fsdp_plugin", None):
            from peft.utils.other import fsdp_auto_wrap_policy

            fsdp_plugin = trainer.accelerator.state.fsdp_plugin
            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)

    # Step 6. Train, resuming from a checkpoint if one was given
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

if __name__ == "__main__":
    parser = HfArgumentParser(
        (ModelArguments, TrainingArguments)
    )

    model_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, training_args)
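When 4-bit or 8-bit loading is requested, `main` places each process's quantized copy on its own GPU by turning `LOCAL_RANK` into an integer `device_map`, and falls back to `"auto"` when no process group is initialized. A stdlib-only sketch of that decision; `pick_device_map` is a hypothetical helper, and the `torch.distributed` checks are replaced by a boolean argument so it runs without PyTorch:

```python
import os
from typing import Mapping, Optional, Union


def pick_device_map(dist_initialized: bool,
                    env: Optional[Mapping[str, str]] = None) -> Union[int, str]:
    """Return the local GPU index inside a distributed run, else 'auto'."""
    env = os.environ if env is None else env
    if dist_initialized:
        # torchrun (used by accelerate launch for multi-GPU) exports LOCAL_RANK per worker
        return int(env.get("LOCAL_RANK", -1))
    return "auto"


if __name__ == "__main__":
    print(pick_device_map(True, {"LOCAL_RANK": "3"}))  # 3
    print(pick_device_map(False))                      # auto
```

This branch only matters for the quantized PEFT runs; the 70B full fine-tune in the command below loads the model without `device_map`.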

The launch command is below:

accelerate launch --config_file configs/fsdp_config.yaml \
finetuning_llama_with_hf_trainer.py \
--seed 7 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "data/KLUE-TC.jsonl" \
--max_input_seq_length -1 \
--max_output_length -1 \
--num_train_epochs 3 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--report_to "none" \
--bf16 True \
--learning_rate 2e-5 \
--lr_scheduler_type "cosine" \
--weight_decay 0.01 \
--warmup_ratio 0.1 \
--max_grad_norm 1.0 \
--save_total_limit 3 \
--output_dir "outputs/fsdp/full-ft-llama-70b-klue-tc" \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 1 \
--gradient_checkpointing True \
--load_best_model_at_end True \
--use_reentrant False \
--use_flash_attn True \
--ddp_timeout 5400 \
--optim paged_adamw_32bit
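With this command and `num_processes: 4` from the FSDP config, each optimizer step consumes `per_device_train_batch_size × num_processes × gradient_accumulation_steps` sequences. A small sketch of that arithmetic; the row count passed in is a placeholder, since the issue does not report how many training rows the KLUE-TC file yields after the split:

```python
import math


def global_batch_size(per_device: int, num_processes: int, grad_accum: int) -> int:
    """Sequences consumed per optimizer step across all ranks."""
    return per_device * num_processes * grad_accum


def optimizer_steps(num_train_rows: int, per_device: int, num_processes: int,
                    grad_accum: int, epochs: int) -> int:
    """Approximate total optimizer steps (assumes the last partial batch is kept)."""
    steps_per_epoch = math.ceil(num_train_rows / global_batch_size(per_device, num_processes, grad_accum))
    return steps_per_epoch * epochs


if __name__ == "__main__":
    print(global_batch_size(16, 4, 1))                  # 64 sequences per step
    print(optimizer_steps(10_000, 16, 4, 1, epochs=3))  # 10_000 is a placeholder row count
```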

The FSDP configuration (configs/fsdp_config.yaml):

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
main_process_ip: 172.17.0.5
main_process_port: 6000
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
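In this config `fsdp_sharding_strategy: 1` corresponds to `FULL_SHARD` in PyTorch's `ShardingStrategy` enum, so parameters, gradients and optimizer state are all sharded across the 4 GPUs. The sketch below decodes the integer and gives a back-of-envelope per-GPU size for that sharded training state. The byte counts are assumptions, not measurements: bf16 parameters and gradients (2 bytes each), two fp32 Adam moments (8 bytes) per parameter, a nominal 70e9 parameters, and no activations, buffers or allocator overhead.

```python
# Integer codes as in torch.distributed.fsdp.ShardingStrategy (names as listed by accelerate)
SHARDING_STRATEGIES = {
    1: "FULL_SHARD",
    2: "SHARD_GRAD_OP",
    3: "NO_SHARD",
    4: "HYBRID_SHARD",
    5: "HYBRID_SHARD_ZERO2",
}


def decode_sharding_strategy(value: int) -> str:
    """Map the fsdp_sharding_strategy integer from the YAML to its name."""
    return SHARDING_STRATEGIES[value]


def full_shard_state_gb(num_params: float, world_size: int,
                        param_bytes: int = 2, grad_bytes: int = 2, optim_bytes: int = 8) -> float:
    """Per-GPU size (decimal GB) of params + grads + optimizer state when all three are sharded.

    The default byte counts are assumptions (bf16 params/grads, two fp32 Adam moments);
    activations, buffers and fragmentation are ignored.
    """
    per_param = param_bytes + grad_bytes + optim_bytes
    return num_params * per_param / world_size / 1e9


if __name__ == "__main__":
    print(decode_sharding_strategy(1))              # FULL_SHARD
    print(full_shard_state_gb(70e9, world_size=4))  # 210.0 under the assumed byte counts
```

Treat the result as an order-of-magnitude figure to compare against the 80 GB available on each GPU.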

Expected behavior

H100 should be faster than A100 for full fine-tuning of the 70B model.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.