huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Inference with FSDP during training affects checkpoints #34530

Open pandrei7 opened 1 month ago

pandrei7 commented 1 month ago

System Info

Output from transformers-cli env:

 - `transformers` version: 4.45.2
 - Platform: Linux-6.1.0-21-cloud-amd64-x86_64-with-glibc2.36
 - Python version: 3.12.5
 - Huggingface_hub version: 0.25.0
 - Safetensors version: 0.4.5
 - Accelerate version: 1.0.1
 - Accelerate config:    not found
 - PyTorch version (GPU?): 2.5.0+cu124 (True)
 - Tensorflow version (GPU?): 2.17.0 (False)
 - Flax version (CPU?/GPU?/TPU?): not installed (NA)
 - Jax version: not installed
 - JaxLib version: not installed
 - Using distributed or parallel set-up in script?: just using the Trainer, and running with accelerate
 - Using GPU in script?: I'm running on GPUs
 - GPU type: NVIDIA H100 80GB HBM3

Relevant environment and library versions:

 Linux Debian 6.1.90-1
 CUDA version: 12.4

 accelerate==1.0.1
 datasets==3.0.1
 torch==2.4.1
 torchaudio==2.4.1
 torchvision==0.19.1
 transformers==4.45.2

Who can help?

No response

Reproduction

Hello! I'm running into an issue with the checkpoints saved when training an LLM with FSDP and the default Hugging Face Trainer, if I also run inference during training. I provide code at the end of this post for clarity.

I also asked this on the forum before coming here, but I haven't found a solution yet.

What I'm trying to achieve

I want to write a callback that monitors model outputs on a validation set throughout training. This requires running inference with model.generate(). Since I'm also using FSDP, I need to summon all weights on a single device, as described in this GitHub issue.

My issue

The callback I provide below seems to work fine for evaluation, but it affects the checkpoints that get saved. Specifically, when unsharding the final checkpoint and trying to replicate the results I see from my training script, I get different, much worse results from the checkpoint.

To test this, I trained an LLM to memorize a simple phrase: "Two times 10 equals 20.". At the end of training, my callback reports the completions I expect, meaning the model trained well. However, if I load the checkpoint from disk and feed it the same prompts, I get this:

 # With callback
 # Outputs from the training script, after training.
 "Two"                 -> "times 10 equals 20."
 "Two times"           -> "10 equals 20."
 "Two times 10"        -> "equals 20."
 "Two times 10 equals" -> "20."
 # Outputs from the checkpoint loaded from disk.
 "Two"                 -> "               "
 "Two times"           -> "equals               "
 "Two times 10"        -> "               "
 "Two times 10 equals" -> "               "

This does not happen if I don't run the callback during training. If I remove it, the checkpoint produces the expected outputs:

 # Without callback
 # Outputs from the checkpoint loaded from disk.
 "Two"                 -> "times 10 equals 20."
 "Two times"           -> "10 equals 20."
 "Two times 10"        -> "equals 20."
 "Two times 10 equals" -> "20."

To make extra sure, I also tried this experiment with DDP instead of FSDP (I removed the summon instruction). The DDP checkpoint is correct regardless of using my callback or not.

 # With DDP
 # Outputs from the training script, after training.
 "Two"                 -> "times 10 equals 20."
 "Two times"           -> "10 equals 20."
 "Two times 10"        -> "equals 20."
 "Two times 10 equals" -> "20."
 # Outputs from the checkpoint loaded from disk.
 "Two"                 -> "times 10 equals 20."
 "Two times"           -> "10 equals 20."
 "Two times 10"        -> "equals 20."
 "Two times 10 equals" -> "20."

I believe this points to summon_full_params being the problem. Do you think this could be a problem with the library, or maybe with my implementation? Any ideas or advice? Thank you!

Minimal example

main.py

```python
from typing import cast

import accelerate
import datasets
import torch
import transformers
from torch.distributed import fsdp


class ValidCallback(transformers.TrainerCallback):
    def __init__(self, tokenizer: transformers.PreTrainedTokenizerBase, dataset: datasets.Dataset) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.dataset = dataset

    def on_epoch_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = cast(transformers.PreTrainedModel, kwargs["model"])
        with torch.no_grad():
            self.run(model)

    @torch.no_grad()
    def run(self, model: transformers.PreTrainedModel) -> None:
        model.eval()

        for batch in self.dataset.iter(batch_size=7):
            encoding = self.tokenizer(batch["text"], return_tensors="pt", padding=True).to(model.device)

            with fsdp.FullyShardedDataParallel.summon_full_params(model):
                outputs = model.generate(
                    inputs=encoding.input_ids,
                    attention_mask=encoding.attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=16,
                    do_sample=False,
                )

            predictions = self.tokenizer.batch_decode(
                outputs[:, encoding.input_ids.shape[1] :],  # Skip the returned prompt.
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            if accelerate.PartialState().is_main_process:
                print(predictions)


def main() -> None:
    # Load model and tokenizer.
    checkpoint = "mistralai/Mistral-7B-v0.3"
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.padding_side = "left"
    if not tokenizer.pad_token:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
    model.resize_token_embeddings(len(tokenizer))

    # Load and prepare a toy dataset.
    def tokenize_function(examples):
        tokenized = tokenizer(examples["text"], max_length=32, padding="max_length", truncation=True)
        tokenized["labels"] = cast(list, tokenized["input_ids"]).copy()
        return tokenized

    train_dataset = datasets.Dataset.from_dict({"text": ["Two times 10 equals 20."] * 100})
    valid_dataset = datasets.Dataset.from_dict(
        {"text": ["Two", "Two times", "Two times 10", "Two times 10 equals", "Two times 10 equals 20."]}
    )
    train_dataset = train_dataset.map(
        tokenize_function, batched=True, remove_columns=list(train_dataset.features)
    )

    # Train.
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_dataset,
        args=transformers.TrainingArguments(
            output_dir="./output-minimal",
            save_strategy="steps",
            save_steps=1_000_000,
            overwrite_output_dir=True,
            remove_unused_columns=False,
            optim="adamw_torch_fused",
            bf16=True,
            learning_rate=1e-2,
            num_train_epochs=100,
            per_device_train_batch_size=1,
            ddp_timeout=9999999,
            report_to=[],
        ),
        callbacks=[
            ValidCallback(tokenizer, valid_dataset),
        ],
    )
    trainer.train()


if __name__ == "__main__":
    main()
```

fsdp.yaml

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

I run my code on Slurm, using this command:

 srun bash -c "accelerate launch \
     --config_file fsdp.yaml \
     --main_process_ip $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) \
     --main_process_port 6000 \
     --machine_rank \$SLURM_PROCID \
     main.py"

Expected behavior

I would expect the checkpoint saved on disk to produce the same outputs as those shown by the script after training.
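Put another way, the expected behavior is a simple regression check: every prompt's completion from the reloaded checkpoint should match the completion printed at the end of training. A minimal sketch of that check, assuming the outputs have already been collected into dictionaries; the helper name mismatched_prompts is hypothetical, and the sample data is the "With callback" output reported above.

```python
def mismatched_prompts(after_training: dict[str, str], from_checkpoint: dict[str, str]) -> list[str]:
    """Return the prompts whose reloaded-checkpoint completion differs from the in-training one.

    Completions are compared after stripping surrounding whitespace; a prompt missing
    from `from_checkpoint` counts as a mismatch.
    """
    return [
        prompt
        for prompt, expected in after_training.items()
        if from_checkpoint.get(prompt, "").strip() != expected.strip()
    ]


# Outputs reported in the "With callback" experiment.
after_training = {
    "Two": "times 10 equals 20.",
    "Two times": "10 equals 20.",
    "Two times 10": "equals 20.",
    "Two times 10 equals": "20.",
}
from_checkpoint = {
    "Two": "               ",
    "Two times": "equals               ",
    "Two times 10": "               ",
    "Two times 10 equals": "               ",
}

print(mismatched_prompts(after_training, from_checkpoint))
# ['Two', 'Two times', 'Two times 10', 'Two times 10 equals']
```

With the DDP or no-callback outputs, the same check returns an empty list.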

Rocketknight1 commented 1 month ago

cc @muellerzr @SunMarc !

SunMarc commented 3 weeks ago

Thanks for the nice report! This is indeed very strange behavior. Could you check whether you get the same model at the end with and without the callback? At first glance, with fsdp.FullyShardedDataParallel.summon_full_params(model) looks like the potential culprit. Could you try calling it on its own in on_epoch_end?

alexandru-dinu commented 3 weeks ago

Hey @SunMarc! Just a note re:

Could you try to see if you get the same model at the end with/without the callback.

I am following this issue and also replied on the Hugging Face forum. TL;DR: when unsharding the model, only the *.safetensors files differ between runs with and without the callback, so we don't get the same model.

pandrei7 commented 3 weeks ago

Hi @SunMarc! Thanks a lot for looking into this!

I confirm that I get different model weights depending on whether I use the callback or not: diff reports differences in all three safetensors files.
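For anyone reproducing this comparison, a sketch of the same byte-level check in Python, assuming both unsharded checkpoints sit in separate directories with matching shard file names. The function names and the directory paths in the usage line are hypothetical. Like diff, this only tells you which files differ, not which tensors changed or by how much.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so multi-GB shards are never read into memory at once."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def differing_shards(dir_a: str, dir_b: str, pattern: str = "*.safetensors") -> list[str]:
    """Return names of files matching `pattern` that differ, or exist in only one directory."""
    files_a = {p.name: p for p in Path(dir_a).glob(pattern)}
    files_b = {p.name: p for p in Path(dir_b).glob(pattern)}
    differing = []
    for name in sorted(files_a.keys() | files_b.keys()):
        if name not in files_a or name not in files_b:
            differing.append(name)
        elif sha256_of(files_a[name]) != sha256_of(files_b[name]):
            differing.append(name)
    return differing


# Hypothetical usage with two unsharded checkpoints:
# print(differing_shards("./merged-with-callback", "./merged-without-callback"))
```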

I tried to run the generation in on_epoch_end without calling summon_full_params, but I get this error when I reach model.generate:

RuntimeError: 'weight' must be 2-D

I assume this behaviour is expected, based on this comment. I hope this is what you were asking, but do tell if I got it wrong.

I looked a bit more into PyTorch's documentation for summon_full_params, and tried setting writeback=False, just to make sure. But it has no effect: predictions after training look fine, but the checkpoint is wrong.
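A toy, pure-Python simulation may help explain both observations. It is a deliberately simplified model of FULL_SHARD, not FSDP's implementation: the real FlatParameter packs many tensors per wrapped module and gathers with collectives. Each rank keeps only a padded 1-D slice of the flattened weight, which is consistent with the 'weight' must be 2-D error outside summon_full_params. Gathering builds a temporary full copy whose edits are copied back only when writeback is true. Because generate() under torch.no_grad() does not modify parameters, writeback should make no difference for this callback, which matches what was observed. The helper names shard and summon_full are hypothetical.

```python
import math
from contextlib import contextmanager


def shard(weight: list[list[float]], world_size: int) -> list[list[float]]:
    """Flatten a 2-D weight, pad it to a multiple of world_size, and split it into equal 1-D shards."""
    flat = [x for row in weight for x in row]
    padded_len = math.ceil(len(flat) / world_size) * world_size
    flat += [0.0] * (padded_len - len(flat))
    size = padded_len // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]


@contextmanager
def summon_full(shards: list[list[float]], shape: tuple[int, int], writeback: bool = True):
    """Yield a full 2-D copy of the weight; copy edits back into the shards only if writeback=True."""
    rows, cols = shape
    flat = [x for s in shards for x in s][: rows * cols]  # "all-gather", then drop the padding
    full = [flat[i * cols:(i + 1) * cols] for i in range(rows)]
    yield full
    if writeback:
        for local, new in zip(shards, shard(full, len(shards))):
            local[:] = new  # update each rank's 1-D shard in place


weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # a 2 x 3 "embedding" matrix
shards = shard(weight, world_size=4)
print(shards)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 0.0]]

with summon_full(shards, (2, 3), writeback=False) as full:
    full[0][0] = 99.0  # discarded on exit
print(shards[0])  # [1.0, 2.0]

with summon_full(shards, (2, 3)) as full:
    full[0][0] = 99.0  # written back on exit
print(shards[0])  # [99.0, 2.0]
```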

github-actions[bot] commented 18 hours ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

SunMarc commented 10 hours ago

cc @XuehaiPan if you have some time !

XuehaiPan commented 10 hours ago
    def on_epoch_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = cast(transformers.PreTrainedModel, kwargs["model"])
        with torch.no_grad():
            self.run(model)

    @torch.no_grad()
    def run(self, model: transformers.PreTrainedModel) -> None:
        model.eval()

        for batch in self.dataset.iter(batch_size=7):
            encoding = self.tokenizer(batch["text"], return_tensors="pt", padding=True).to(model.device)

            with fsdp.FullyShardedDataParallel.summon_full_params(model):
                outputs = model.generate(
                    inputs=encoding.input_ids,
                    attention_mask=encoding.attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=16,
                    do_sample=False,
                )

            predictions = self.tokenizer.batch_decode(
                outputs[:, encoding.input_ids.shape[1] :],  # Skip the returned prompt.
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            if accelerate.PartialState().is_main_process:
                print(predictions)

Hi @pandrei7, have you tried removing the model.eval() statement, or switching the model back to training mode after validation?
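The second suggestion, restoring the previous mode after validation, is a common PyTorch idiom and can be written as a small context manager. This is a sketch, not a confirmed fix for this issue; the name eval_mode is hypothetical. It relies only on what torch.nn.Module provides: a training flag and a train(mode) method.

```python
from contextlib import contextmanager


@contextmanager
def eval_mode(model):
    """Temporarily put `model` in eval mode and restore its previous mode on exit.

    The previous mode is restored even if the body raises.
    """
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)
```

Inside run(), the bare model.eval() call would be replaced by wrapping the generation loop in with eval_mode(model):. Note that train(was_training) applies the saved mode to every submodule, so this assumes the whole model was in a single mode before validation.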