huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
133.84k stars 26.76k forks source link

fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP is not working with the Trainer #34113

Open eljandoubi opened 1 week ago

eljandoubi commented 1 week ago

System Info

image

acc_cfg.yml:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: true
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: NO_PREFETCH
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_process_ip: 0.0.0.0
main_process_port: 0
main_training_function: main
mixed_precision: bf16
num_machines: 3
num_processes: 24
rdzv_backend: etcd-v2
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

Who can help?

No response

Information

Tasks

Reproduction

`accelerate launch --config_file acc_cfg.yml train.py $TRAINING_ARGS`

Here `train.py` is any training script that trains with `transformers.Trainer`, and `$TRAINING_ARGS` are the `TrainingArguments` plus some paths to the data.

fdsp_trans

Expected behavior

Train a PaliGemma model with FSDP and have `PaliGemmaMultiModalProjector` wrapped.

LysandreJik commented 1 week ago

cc @muellerzr and @SunMarc

muellerzr commented 1 week ago

Hi! Can you please show the entire script you are running?

eljandoubi commented 1 week ago
"""Training script for Vision2Seq-like models.
"""

import logging
from os import getenv
from typing import Union

from torch import cuda
from datasets import load_dataset, DatasetDict
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, set_seed, \
    AutoProcessor, AutoModelForVision2Seq, HfArgumentParser, EarlyStoppingCallback, \
    TrainingArguments, Trainer, BitsAndBytesConfig
from accelerate import PartialState#, logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.optimizers import create_loraplus_optimizer

from utils.shared_dataclass import SharedArugments
import utils.args_parsers as parsers
import utils.data_processing as process
import utils.model_preparation as prep
from utils.post_process_args import post_process, get_sorted_key
from utils.eval_metrics import EvalMetrics

# Select the compute device once at import time.
if cuda.is_available():
    main_device = "cuda"
else:
    main_device = "cpu"

# Accelerate process state; forced onto CPU when no CUDA device is present.
main_state = PartialState(cpu=(main_device == "cpu"))

# Index of this process in the launch; it is prefixed to every log record
# so that interleaved multi-process output can be told apart.
idx = main_state.process_index

logger = logging.getLogger(__name__)

FORMAT = f"process {idx} %(levelname)s %(asctime)s,%(msecs)d %(message)s"

logging.basicConfig(
    format=FORMAT,
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

def main():
    """Train a Vision2Seq-like model with a ``transformers`` trainer.

    Behaviour is driven by environment variables and parsed CLI arguments:

    * ``TRAIN_TYPE``: ``"SEQ2SEQ"`` selects ``Seq2SeqTrainingArguments`` /
      ``Seq2SeqTrainer``. Any other value, or unset, falls back to
      ``TrainingArguments`` / ``Trainer`` and is recorded as ``"CAUSAL"``.
    * ``USE_LORA == "True"``: also parse a LoRA configuration and wrap the
      model with ``get_peft_model``.
    * ``USE_BITS == "True"``: also parse a ``BitsAndBytesConfig``, pass it as
      ``quantization_config`` when loading the model, and run
      ``prepare_model_for_kbit_training``.

    When ``shared_args.do_preprocessing`` is true, the function only builds
    and saves the processor and the processed dataset. Otherwise it loads
    them back, builds the model and trainer, trains, saves the model, and
    runs a final evaluation.

    Raises:
        ValueError: on the training path, if LoRA or BitsAndBytes is enabled
            but the corresponding configuration was not found among the
            parsed arguments (previously this surfaced as a ``NameError``).
    """

    logger.info("Start train script using \n%s", main_state)

    load_datasets_data_class = parsers.method_to_dataclass(load_dataset,
                                                           "load_dataset")

    early_stop_data_class = parsers.method_to_dataclass(EarlyStoppingCallback,
                                                        "EarlyStopping")
    train_type = getenv('TRAIN_TYPE')

    if train_type == "SEQ2SEQ":
        logger.info("Training type is SEQ2SEQ")
        train_args_class = Seq2SeqTrainingArguments
        trainer_class = Seq2SeqTrainer
    else:
        # Log the raw value (possibly None) before normalising it.
        logger.info("Training type is %s", train_type)
        train_type = "CAUSAL"
        train_args_class = TrainingArguments
        trainer_class = Trainer

    # Order matters: the first four parsed objects are unpacked positionally
    # below; optional configs are appended after them.
    dataclass_instances = [
        load_datasets_data_class,
        early_stop_data_class,
        SharedArugments,
        train_args_class
    ]

    if getenv("USE_LORA") == "True":
        logger.info("LORA is used.")
        use_lora = True
        lora_cfg = parsers.method_to_dataclass(LoraConfig,
                                               "LoraConfig",
                                               ["revision"])
        dataclass_instances.append(lora_cfg)
    else:
        logger.info("LORA is NOT used.")
        use_lora = False

    if getenv("USE_BITS") == "True":
        logger.info("BitsAndBytes is used.")
        use_bits = True
        dataclass_instances.append(BitsAndBytesConfig)
    else:
        logger.info("BitsAndBytes is NOT used.")
        use_bits = False

    logger.info("Parse argument")

    tuple_dataclass = parsers.parse_args(
        parser=HfArgumentParser(dataclass_instances)
        )

    classes = post_process(tuple_dataclass,
                           logger)

    load_datasets_args, early_stop_args, shared_args,\
        training_args, *other_args = tuple_dataclass

    set_seed(training_args.seed)

    # Keep the four mandatory argument objects, then append the optional
    # LoRA / BitsAndBytes configs in their usable form.
    selected_dataclasses = list(tuple_dataclass[:4])

    # Bind both up front so a missing config is reported clearly instead of
    # raising NameError later.
    lora_args = None
    bits_args = None

    for other_arg in other_args:
        if use_lora and\
            other_arg.__class__.__name__ == "LoraConfig":
            # The parsed object is presumably the dataclass generated by
            # method_to_dataclass; convert it into a peft LoraConfig.
            lora_args = parsers.convert_dataclasses(
                other_arg, LoraConfig)
            selected_dataclasses.append(lora_args)

        if use_bits and\
            isinstance(other_arg, BitsAndBytesConfig):
            bits_args = other_arg
            selected_dataclasses.append(bits_args)

    tuple_dataclass = tuple(selected_dataclasses)

    del selected_dataclasses, other_args

    training_args: Union[TrainingArguments, Seq2SeqTrainingArguments]
    shared_args: SharedArugments

    for data_cls in tuple_dataclass:
        logger.info("%s", data_cls)

    logger.info("Save parsed arguments")

    parsers.save_config(data_classes=tuple_dataclass,
                        file_name=f"{training_args.output_dir}/script_config.json")

    if shared_args.do_preprocessing:
        # --- Preprocessing path: build processor + dataset and save them. ---

        dataset = load_dataset(**parsers.module_asdict(
            data_class=load_datasets_args))

        logger.info("Prepare data for processor")

        prepared_dataset = dataset.map(process.prepare_documents_for_processor,
                                       fn_kwargs={"new_special_tokens_path":
                                                  shared_args.new_special_tokens_path,
                                                  "bos_token": shared_args.bos_token,
                                                  "eos_token": shared_args.eos_token,
                                                  },
                                       **process.get_map_arguments(dataset,
                                                                   batched=False,
                                                                   num_proc=
                                                                   shared_args.map_num_proc,
                                                                   )
                                       )

        logger.info("Load processor from %s",
                    shared_args.pretrained_model_name_or_path)

        processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path=shared_args.pretrained_model_name_or_path,
            clean_up_tokenization_spaces=True
        )

        logger.info("Initial Vocab size %s", len(processor.tokenizer))

        logger.info("Align processor configuration with data")

        prep.prepare_processor(processor=processor,
                               new_special_tokens_or_path=shared_args.new_special_tokens_path,
                               height=shared_args.height,
                               width=shared_args.width,
                               )

        logger.info("Final Vocab size %s", len(processor.tokenizer))

        logger.info("Save processor to %s", shared_args.processor_save)

        processor.save_pretrained(shared_args.processor_save)

        logger.info("Transform image data and tokenize text data")

        processed_dataset = prepared_dataset.map(process.transform_and_tokenize,
                                                 fn_kwargs={
                                                     "processor": processor,
                                                     },
                                                 **process.get_map_arguments(
                                                     prepared_dataset,
                                                     batch_size=shared_args.map_batch_size,
                                                     writer_batch_size=shared_args.writer_batch_size)
                                                 )
        logger.info("Save Dataset to %s", shared_args.data_save)
        processed_dataset.save_to_disk(shared_args.data_save)

    else:
        # --- Training path: load artifacts, train, save, evaluate. ---

        # Fail fast with a clear message rather than a NameError further down.
        if use_lora and lora_args is None:
            raise ValueError(
                "USE_LORA is 'True' but no LoraConfig arguments were parsed.")
        if use_bits and bits_args is None:
            raise ValueError(
                "USE_BITS is 'True' but no BitsAndBytesConfig arguments were parsed.")

        logger.info("Load processor from %s", shared_args.processor_save)

        processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path=shared_args.processor_save,
            clean_up_tokenization_spaces=True
        )

        logger.info("Load model from %s",
                    shared_args.pretrained_model_name_or_path)

        model_kwgs = {
            "pretrained_model_name_or_path":
            shared_args.pretrained_model_name_or_path,
            "attn_implementation": "flash_attention_2",
            "cache_dir": shared_args.cache_model,
            "trust_remote_code": True,
            # bits_args is None unless USE_BITS is enabled.
            "quantization_config": bits_args,
        }

        try:
            model = AutoModelForVision2Seq.from_pretrained(
                **model_kwgs
                )

        except ValueError as valerr:
            # Loading with flash_attention_2 failed; retry without the
            # explicit attn_implementation so the library default is used.
            logger.error(valerr)
            model_kwgs.pop("attn_implementation")
            model = AutoModelForVision2Seq.from_pretrained(
                **model_kwgs
            )

        logger.info("Align model configuration with processor and data")
        prep.prepare_model(model=model, processor=processor,
                           train_type=train_type,
                           max_length=shared_args.max_length,
                           bos_token=shared_args.bos_token,
                           eos_token=shared_args.eos_token,)

        if use_bits:
            logger.info("Prepare model for kbit training.")
            model = prepare_model_for_kbit_training(model)

        if use_lora:
            logger.info("Get Peft Model.")
            model = get_peft_model(model=model,
                                   peft_config=lora_args)
            model.print_trainable_parameters()

        logger.info("Configure metrics")
        metrics = EvalMetrics(
            device=main_device,
            tokenizer=processor.tokenizer,
            cls=classes,
            log_dir=training_args.logging_dir,
            new_special_tokens_path=shared_args.new_special_tokens_path,
            batched=training_args.batch_eval_metrics
        )

        logger.info("Set up callbacks")
        early_stop = EarlyStoppingCallback(**parsers.module_asdict(
            data_class=early_stop_args))

        logger.info("Load Dataset from %s", shared_args.data_save)

        processed_dataset = DatasetDict.load_from_disk(shared_args.data_save)

        # NOTE(review): split order comes from get_sorted_key. Index 0 is used
        # for training, 2 for in-training evaluation and 1 for the final
        # evaluation; confirm which split each index refers to.
        keys = get_sorted_key(dataset=processed_dataset)

        logger.info("Set dataset format to torch tensor")

        processed_dataset.set_format("pt")

        logger.info("Create Trainer")

        trainer = trainer_class(
            model=model,
            args=training_args,
            train_dataset=processed_dataset[keys[0]],
            eval_dataset=processed_dataset[keys[2]],
            compute_metrics=metrics.compute_metrics,
            preprocess_logits_for_metrics=EvalMetrics.preprocess_logits_for_metrics,
            callbacks=[early_stop]
        )

        logger.info("Distributed type: %s",
                    trainer.accelerator.distributed_type)

        logger.info("The mixed precision is %s",
                    trainer.accelerator.mixed_precision)

        if trainer.is_fsdp_enabled and training_args.auto_find_batch_size:
            # Reported in transformers issue #34113: after an OOM the
            # batch-size search re-runs the inner training loop with a model
            # that accelerator.prepare has already FSDP-wrapped, and the
            # auto-wrap step then fails.
            logger.warning(
                "auto_find_batch_size=True with FSDP may fail on retry because "
                "the model is already FSDP-wrapped (transformers#34113).")

        logger.info("Start training")

        out = trainer.train()

        logger.info("train output: %s", out)

        logger.info("Save model to %s", shared_args.model_save)

        if trainer.is_fsdp_enabled:
            # Switch from the configured state-dict type to FULL_STATE_DICT
            # before saving, presumably to write an unsharded checkpoint.
            trainer.accelerator.state.\
                fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

        trainer.save_model(output_dir=shared_args.model_save)

        logger.info("Start evaluation")

        eval_res = trainer.evaluate(eval_dataset=processed_dataset[keys[1]])
        logger.info("The evaluation result is %s", eval_res)

if __name__ == "__main__":
    main()
eljandoubi commented 1 week ago

The same problem occurs when using PEFT LoRA with `NO_WRAP`. @muellerzr @SunMarc

image
eljandoubi commented 3 days ago

@muellerzr When `auto_find_batch_size=True`, the `find_executable_batch_size` function runs `_inner_training_loop` multiple times. `self.accelerator.prepare` wraps `self.model` before the CUDA out-of-memory error is raised. It does this by applying the auto-wrap policy, which `self._wrap_model` cannot detect. As a result, the next iteration receives an already-wrapped `self.model`, which causes the FSDP auto-wrapping error.