huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer with resume_from_checkpoint does not work with multiple Peft Adapters #30478

Closed: claralp closed this issue 6 months ago.

claralp commented 6 months ago

System Info

Since it is possible to have multiple PEFT adapters in the same model, it should also be possible to resume training such a model from a checkpoint with transformers.Trainer.train(resume_from_checkpoint=True|"path"), i.e. passing either True or a checkpoint path.
This is needed, for example, to resume DPO/KTO training runs that load two adapters, as in the "load the adapter twice" setup.
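
For context, resume_from_checkpoint=True makes the Trainer pick the most recent checkpoint-<step> folder in output_dir (the library helper for this is transformers.trainer_utils.get_last_checkpoint). That lookup is not where this issue fails; the error comes afterwards, when the chosen folder is checked for weight files. Below is a simplified standalone sketch of the lookup, not the library code; find_last_checkpoint is a hypothetical name.

```python
import os
import re

# Trainer names checkpoint folders "checkpoint-<global_step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the path of the checkpoint folder with the highest step, or None.

    Simplified sketch of what resume_from_checkpoint=True resolves to.
    """
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare steps numerically so checkpoint-1000 wins over checkpoint-500.
    return max(candidates)[1]
```

Comparing the parsed step as an integer matters: a plain string sort would rank checkpoint-500 above checkpoint-1000.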

The failure occurs because, when multiple adapters exist, their weights are saved in per-adapter subdirectories of the checkpoint rather than directly in it, so no adapter_model.bin or adapter_model.safetensors file can be found at the top level of the checkpoint.
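
A minimal sketch of the two layouts, assuming each named adapter is stored in a subdirectory named after it (e.g. train/ and reference/) using the file names quoted above. has_top_level_adapter mimics, in simplified form, the kind of top-level lookup that fails here; find_adapter_subdirs is a hypothetical helper that looks one level deeper instead. Neither function is Transformers or PEFT code.

```python
import os

# Adapter weight file names mentioned in this issue.
ADAPTER_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def has_top_level_adapter(checkpoint_dir):
    """True if adapter weights sit directly in checkpoint_dir (single-adapter layout)."""
    return any(
        os.path.isfile(os.path.join(checkpoint_dir, f)) for f in ADAPTER_FILES
    )


def find_adapter_subdirs(checkpoint_dir):
    """Return the sorted names of subdirectories that contain adapter weights."""
    found = []
    for name in os.listdir(checkpoint_dir):
        sub = os.path.join(checkpoint_dir, name)
        if os.path.isdir(sub) and has_top_level_adapter(sub):
            found.append(name)
    return sorted(found)
```

For a checkpoint containing only train/adapter_model.safetensors and reference/adapter_model.safetensors, has_top_level_adapter returns False while find_adapter_subdirs returns ["reference", "train"].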

Is anyone working on this, or should I come up with a solution?

Who can help?

@ArthurZucker @younesbelkada @muellerzr

Reproduction

Based on the TRL documentation: https://huggingface.co/docs/trl/dpo_trainer#using-option-3---load-the-adapter-twice

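import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer

# Load the base model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1",
    # 4-bit loading is configured via quantization_config alone; also passing
    # load_in_4bit=True here raises a ValueError in recent transformers versions.
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)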
model.config.use_cache = False

# Load the adapter.
model = PeftModel.from_pretrained(
    model,
    "/path/to/peft",
    is_trainable=True,
    adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer, without a ref_model param.
dpo_trainer = DPOTrainer(
    model,
    ...
    model_adapter_name="train",
    ref_adapter_name="reference",
)

After a checkpoint is saved, interrupt the training and resume with: dpo_trainer.train(resume_from_checkpoint=True)

But this throws the following error: ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

Expected behavior

Training should resume from the saved checkpoint.
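
One possible direction, sketched under assumptions: scan the checkpoint for per-adapter subdirectories and reload each one into the existing PeftModel with load_adapter, marking only the adapter being trained as trainable. reload_adapters is a hypothetical helper, not an existing Transformers API. It assumes the subdirectory names match the adapter names ("train" and "reference" in the reproduction above) and that, for an adapter name already present on the model, PEFT's load_adapter(model_id, adapter_name, is_trainable=...) reloads that adapter's weights. The model is passed in as a parameter, so the logic can be exercised with a stand-in object.

```python
import os

# Adapter weight file names mentioned in this issue.
ADAPTER_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def reload_adapters(model, checkpoint_dir, trainable_adapter):
    """Hypothetical helper: reload every adapter saved in a subdirectory.

    Assumes each subdirectory is named after its adapter and that `model`
    exposes load_adapter(model_id, adapter_name=..., is_trainable=...),
    as peft.PeftModel does. Returns the names of the reloaded adapters.
    """
    loaded = []
    for name in sorted(os.listdir(checkpoint_dir)):
        sub = os.path.join(checkpoint_dir, name)
        if not os.path.isdir(sub):
            continue  # plain files such as trainer_state.json
        if not any(os.path.isfile(os.path.join(sub, f)) for f in ADAPTER_FILES):
            continue  # subdirectory without adapter weights
        # Only the adapter being optimized gets trainable parameters.
        model.load_adapter(
            sub, adapter_name=name, is_trainable=(name == trainable_adapter)
        )
        loaded.append(name)
    if not loaded:
        raise ValueError(f"No adapter subdirectories found in {checkpoint_dir}")
    return loaded
```

Keeping the reference adapter frozen (is_trainable=False) matches its role in the DPO setup: it only provides reference log-probabilities and should not receive gradient updates.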

claralp commented 6 months ago

This may also be of interest to @lewtun and @kashif.