huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Unable to Resume Training from LoRA Checkpoints When Using FSDP #28320

Closed · fabianlim closed this 9 months ago

fabianlim commented 11 months ago

System Info

transformers==4.35.2 accelerate==0.23.0 peft==0.5.0

accelerate.yaml

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: "BertLayer"
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Who can help?

@pacman100 Following the recommendation in https://huggingface.co/docs/trl/v0.7.4/en/sft_trainer#training-adapters, I install a PeftSavingCallback to ensure that adapter.bin is saved. This is needed when using FSDP, since the wrapped model is not a PretrainedModel, in which case only the state_dict will be saved.

The recommendation above works great for saving the checkpoint, but it does not work when resuming from the checkpoint. This is because model_wrapped is neither a PretrainedModel nor a PeftModel, so the if-else conditions in Trainer._load_from_checkpoint fall all the way through to load_sharded_checkpoint. This results in the following error:

Traceback (most recent call last):
  File "/dccstor/flim-ai4it/AI4IT/tofafm/src/scripts/lora_fsdp_bug_demo.py", line 183, in <module>
    main()
  File "/dccstor/flim-ai4it/AI4IT/tofafm/src/scripts/lora_fsdp_bug_demo.py", line 143, in main
    trainer.train(resume_from_checkpoint=True)
  File "/u/flim/miniconda3/envs/tofafm-rewrite2/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/u/flim/miniconda3/envs/tofafm-rewrite2/lib/python3.10/site-packages/transformers/trainer.py", line 1712, in _inner_training_loop
    self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
  File "/u/flim/miniconda3/envs/tofafm-rewrite2/lib/python3.10/site-packages/transformers/trainer.py", line 2132, in _load_from_checkpoint
    load_result = load_sharded_checkpoint(
  File "/u/flim/miniconda3/envs/tofafm-rewrite2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 411, in load_sharded_checkpoint
    raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")

The second issue with the recommendation is that the FSDP optimizer states are not saved by the PeftSavingCallback, so it will not be a clean fix.
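For completeness, a rough, untested sketch of how the full (unsharded) optimizer state could be gathered and saved alongside the adapter with the PyTorch FSDP API; fsdp_model, optimizer and path are placeholder names:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_optimizer_state(fsdp_model, optimizer, path):
    # all-gather the sharded optimizer state; with the default rank0_only=True
    # only rank 0 receives the full state dict, the other ranks get an empty dict
    full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
    if torch.distributed.get_rank() == 0:
        torch.save(full_osd, path)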

I was wondering if you may have any thoughts on this. A possible hacky solution would be to override Trainer._load_from_checkpoint and use FSDP.summon_full_params to unshard the LoRA weights, and then call load_adapter, but it doesn't sound very clean given that it will not resume the FSDP optimizer.
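To make the idea concrete, here is a rough, untested sketch of such an override. It assumes the callback saved the adapter as adapter_model.bin and loads it into the existing adapter with peft's set_peft_model_state_dict rather than load_adapter; the class name is illustrative and the FSDP optimizer state is still not restored:

import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from peft import set_peft_model_state_dict
from transformers import Trainer

class FSDPPeftResumeTrainer(Trainer):
    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model_wrapped
        # adapter weights written by PeftSavingCallback via save_pretrained
        adapter_weights = torch.load(
            os.path.join(resume_from_checkpoint, 'adapter_model.bin'),
            map_location='cpu',
        )
        # unshard the parameters; writeback=True copies the loaded values
        # back into the local FSDP shards when the context exits
        with FSDP.summon_full_params(model, writeback=True):
            set_peft_model_state_dict(self.model, adapter_weights)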

Information

Tasks

Reproduction

  1. Run the below script with the above accelerate.yaml configuration (an example launch command is given after this list). After at least 100 steps, once an adapter.bin checkpoint has been populated, stop.
  2. Rerun the script but now enabling trainer.train(resume_from_checkpoint=True).
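For reference, assuming the script below is saved as lora_fsdp_bug_demo.py (the filename from the traceback above), a launch command along these lines should work:

accelerate launch --config_file accelerate.yaml lora_fsdp_bug_demo.py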

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from transformers import TrainingArguments, Trainer, TrainerCallback
import torch, os

def main(
    model_name: str = 'textattack/bert-base-uncased-SST-2',
):

    # we set the max sequence length here
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, model_max_length=512,
    )

    # load and tokenize a very small dataset
    raw_datasets = load_dataset('glue','sst2')

    # tokenization function
    def _tokenize_function(example, tokenizer):    
        return tokenizer(
            example['sentence'], truncation = True,
        )

    tokenized_datasets = raw_datasets.map(
        _tokenize_function, fn_kwargs = {'tokenizer': tokenizer}, batched=True
    )

    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer, return_tensors='pt'
    )

    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    from peft import LoraConfig, get_peft_model
    model = get_peft_model(
        model, LoraConfig(
            r=8, lora_alpha=16,
            target_modules=['query', 'key', 'value'],
            task_type='SEQ_CLS'
        )
    )

    training_args = TrainingArguments(
        num_train_epochs = 1,
        output_dir = './results',
        per_device_train_batch_size = 8,
        per_device_eval_batch_size = 8,
        learning_rate = 2e-4,
        logging_steps = 50,
        save_strategy = 'steps',
        save_steps = 100,
        evaluation_strategy = 'steps',
        eval_steps = 100,
        save_total_limit = 2,
        metric_for_best_model = 'loss',
        greater_is_better = False,
        max_steps = 1000, # just make the demo quit after 1000 steps
        save_safetensors=False,
    )

    class PeftSavingCallback(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            kwargs["model"].save_pretrained(checkpoint_path)

            if "pytorch_model.bin" in os.listdir(checkpoint_path):
                os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        callbacks=[PeftSavingCallback()], 
    )

    import functools
    from accelerate import DistributedType

    if trainer.accelerator.distributed_type == DistributedType.FSDP:

        from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy, _or_policy

        # inspired by
        # https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py
        def lambda_policy_fn(module):
            if (
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            ):
                return True
            return False

        trainer.accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
        trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = functools.partial(
            _or_policy, policies=[
                functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn),
                trainer.accelerator.state.fsdp_plugin.auto_wrap_policy
            ])

    # checkpoints will be saved every 100 steps; the PeftSavingCallback writes the adapter weights (adapter_model.bin)
    trainer.train()
    # trainer.train(resume_from_checkpoint=True) # activating this will throw the error


if __name__ == '__main__':
    main()

Expected behavior

  1. resume_from_checkpoint=True will resume from the PEFT checkpoint recorded by PeftSavingCallback.
  2. [bonus]: the FSDP optimizer states can be resumed as well (a sketch of the restore step follows this list).
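
For the bonus point, a rough, untested sketch of the restore direction, assuming the full optimizer state dict was saved as in the earlier sketch; fsdp_model, optimizer and path are again placeholder names:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def load_full_optimizer_state(fsdp_model, optimizer, path):
    # only rank 0 needs to hold the full state dict; the argument is
    # ignored on the other ranks
    full_osd = None
    if torch.distributed.get_rank() == 0:
        full_osd = torch.load(path, map_location='cpu')
    # re-shard the full state for the local rank and load it into the optimizer
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, fsdp_model)
    optimizer.load_state_dict(sharded_osd)
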
github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

pacman100 commented 9 months ago

Hello @fabianlim, I think the PR https://github.com/huggingface/transformers/pull/28297 should resolve this.

fabianlim commented 9 months ago

@pacman100 yes I think so too, closing this issue.