huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer + FSDP - Llama 3.2 1B full fine-tuning: "size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([65667584]) from checkpoint, the shape in current model is torch.Size([128256, 2048])" #34080

Open brunopistone opened 1 week ago

brunopistone commented 1 week ago

System Info

transformers==4.45.1
peft==0.13.1
accelerate==0.34.2
bitsandbytes==0.44.0
datasets==2.20.0
evaluate==0.4.1
safetensors>=0.4.3
sagemaker==2.232.2
sentencepiece==0.2.0
scikit-learn==1.5.1
tokenizers>=0.19.1
py7zr

SageMaker Training Job: ml.g5.12xlarge
Image: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.4-gpu-py311

Who can help?

@ArthurZucker @muellerzr @SunMarc

Reproduction

I'm running a full fine-tuning for Llama 3.2 1B with Amazon SageMaker. This is the script:

# imports used by the snippet below
import transformers
from accelerate import Accelerator
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed


def train_fn(
        model_name,
        train_ds,
        test_ds=None,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=1,
        learning_rate=2e-4,
        num_train_epochs=1,
        fsdp="",
        fsdp_config=None,
        gradient_checkpointing=False,
        seed=42,
        token=None
):

    set_seed(seed)

    accelerator = Accelerator()

    if token is not None:
        login(token=token)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Set Tokenizer pad Token
    tokenizer.pad_token = tokenizer.eos_token

    with accelerator.main_process_first():
        # tokenize and chunk dataset
        lm_train_dataset = train_ds.map(
            lambda sample: tokenizer(sample["text"]), remove_columns=list(train_ds.features)
        )

        print(f"Total number of train samples: {len(lm_train_dataset)}")

        if test_ds is not None:
            lm_test_dataset = test_ds.map(
                lambda sample: tokenizer(sample["text"]), remove_columns=list(test_ds.features)
            )

            print(f"Total number of test samples: {len(lm_test_dataset)}")
        else:
            lm_test_dataset = None

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_cache=not gradient_checkpointing,
        cache_dir="/tmp/.cache"
    )

    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    if fsdp != "" and fsdp_config is not None:
        fsdp_configurations = {
            "fsdp": fsdp,
            "fsdp_config": fsdp_config,
            "gradient_checkpointing_kwargs": {
                "use_reentrant": False
            }
        }
    else:
        fsdp_configurations = dict()

    trainer = transformers.Trainer(
        model=model,
        train_dataset=lm_train_dataset,
        eval_dataset=lm_test_dataset if lm_test_dataset is not None else None,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            logging_strategy="steps",
            logging_steps=1,
            log_on_each_node=False,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            bf16=False,
            ddp_find_unused_parameters=False,
            save_strategy="no",
            output_dir="outputs",
            **fsdp_configurations
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    trainer.train()

    if trainer.is_fsdp_enabled:
        print("Set state to FULL_STATE_DICT")
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    print("Saving model")
    trainer.save_model("/opt/ml/model")

    if accelerator.is_main_process:
        print("Saving tokenizer")
        tokenizer.save_pretrained("/opt/ml/model")

model_id = "meta-llama/Llama-3.2-1B-Instruct"

train_fn(
    model_id,
    train_ds=train_dataset,
    test_ds=test_dataset,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    num_train_epochs=3,
    fsdp="full_shard auto_wrap offload",
    fsdp_config={
        'backward_prefetch': 'backward_pre',
        'cpu_ram_efficient_loading': True,
        'offload_params': True,
        'forward_prefetch': False,
        'use_orig_params': True
    }
)

train_dataset:

Dataset({
    features: ['text'],
    num_rows: 140
})

train_dataset[0]["text"] (mock):

<|begin_of_text|><|start_header_id|>user<|end_header_id|>What dimensions of the Llama 3.2 model are available?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Llama 3.2 offers several model dimensions to cater to different use cases and computational requirements: Text-Only Models 1\ 1B parameter model: This is the smallest Llama 3.2 model, designed for lightweight text processing tasks and on-device applications 2\ 3B parameter model: A slightly larger text-only model that still maintains efficiency for edge devices and mobile applications. Multimodal Vision Models: 1\ 11B parameter model: This is the smaller of the two vision-capable models, suitable for efficient deployment and development on consumer-grade GPUs 2\ 90B parameter model: The largest Llama 3.2 model, designed for large-scale applications and advanced image reasoning tasks<|end_of_text|><|eot_id|>

When I try to load the model with the following script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model")
model = AutoModelForCausalLM.from_pretrained("./model", trust_remote_code=True).to("cuda")

I get the following exception:

model = AutoModelForCausalLM.from_pretrained("./model", trust_remote_code=True).to("cuda")

File /opt/conda/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:564, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    562 elif type(config) in cls._model_mapping.keys():
    563     model_class = _get_model_class(config, cls._model_mapping)
--> 564     return model_class.from_pretrained(
    565         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    566     )
    567 raise ValueError(
    568     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    569     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    570 )

File /opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py:4014, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   4004     if dtype_orig is not None:
   4005         torch.set_default_dtype(dtype_orig)
   4007     (
   4008         model,
   4009         missing_keys,
   4010         unexpected_keys,
   4011         mismatched_keys,
   4012         offload_index,
   4013         error_msgs,
-> 4014     ) = cls._load_pretrained_model(
   4015         model,
   4016         state_dict,
   4017         loaded_state_dict_keys,  # XXX: rename?
   4018         resolved_archive_file,
   4019         pretrained_model_name_or_path,
   4020         ignore_mismatched_sizes=ignore_mismatched_sizes,
   4021         sharded_metadata=sharded_metadata,
   4022         _fast_init=_fast_init,
   4023         low_cpu_mem_usage=low_cpu_mem_usage,
   4024         device_map=device_map,
   4025         offload_folder=offload_folder,
   4026         offload_state_dict=offload_state_dict,
   4027         dtype=torch_dtype,
   4028         hf_quantizer=hf_quantizer,
   4029         keep_in_fp32_modules=keep_in_fp32_modules,
   4030         gguf_path=gguf_path,
   4031     )
   4033 # make sure token embedding weights are still tied if needed
   4034 model.tie_weights()

File /opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py:4559, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_modules, gguf_path)
   4555     if "size mismatch" in error_msg:
   4556         error_msg += (
   4557             "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
   4558         )
-> 4559     raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
   4561 if len(unexpected_keys) > 0:
   4562     archs = [] if model.config.architectures is None else model.config.architectures

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
    size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([65667584]) from checkpoint, the shape in current model is torch.Size([128256, 2048]).
    size mismatch for model.norm.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([2048]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

Expected behavior

The script shared above works correctly when I fine-tune with bfloat16 mixed precision, with bitsandbytes quantization, and with LoRA. I suspect there is something wrong in how the model is saved. The expected behavior is that the saved model can be properly loaded and used for inference.
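
For what it's worth, the flattened size in the error looks consistent with an FSDP shard: 65,667,584 × 4 = 262,670,336 = 128,256 × 2,048 (embed_tokens) + 2,048 (model.norm), i.e. one rank's slice of a flat parameter from the 4-GPU g5.12xlarge run. Below is a minimal sketch I used to inspect what was actually written to the output directory (assuming a single model.safetensors file under ./model; the exact file name may differ if the checkpoint was sharded):

from safetensors import safe_open

# Print the name and shape of every tensor in the saved checkpoint.
# A 1-D embed_tokens entry would confirm that an FSDP flat parameter
# was serialized instead of the full, unflattened state dict.
with safe_open("./model/model.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        print(key, tuple(f.get_tensor(key).shape))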

ZihanWangKi commented 1 day ago

I experienced a similar problem when training with FSDP + Trainer. When DDP instead of FSDP is used, the problem disappears.

I fixed it by using the nightly version of PyTorch as of Oct 16, 2024.

My (quite uneducated) guess is that it is related to this issue. However, simply disabling weight tying did not work for me, while upgrading the PyTorch version did.
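
In case it helps others reproduce the experiment, this is roughly what I mean by "disabling weight tying" (a sketch only, using the standard config flag; as noted, it did not fix the problem for me):

from transformers import AutoConfig, AutoModelForCausalLM

# Untie lm_head from embed_tokens before training (experiment only;
# this alone did not resolve the size-mismatch error in my case).
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
config.tie_word_embeddings = False

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    config=config,
)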