huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Cannot use prefix tuning on quantized Codellama #2035

Closed MabelQi closed 4 weeks ago

MabelQi commented 2 months ago

System Info

I'm trying to apply PEFT to quantized LLMs. Prompt tuning, LoRA, and IA3 all work. However, when I use prefix tuning on 8-bit codellama-7b-hf, it reports the following error: (error screenshot attached as an image)

Who can help?

@BenjaminBossan @sayakpaul @tmm1

Reproduction

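# Imports needed by this excerpt. `args`, `logger`, `tokenizer`, `num_labels`,
# `use_cuda`, `model_name`, and the dataloaders are defined elsewhere in the script.
import os
import time

import torch
from peft import PeftType, PrefixTuningConfig, get_peft_model
from torch.optim import AdamW  # assumption: the original may have imported AdamW from transformers
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

# Set peft config
peft_type = PeftType.PREFIX_TUNING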

peft_config = PrefixTuningConfig(
    task_type="SEQ_CLS",
    num_virtual_tokens=args.num_virtual_tokens
)

# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name_or_path, 
    num_labels=num_labels,
    load_in_4bit=True,
    device_map="auto"
)

model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
logger.info(f"Prefix Tuning-Trainable parameters: {model.get_nb_trainable_parameters()}")

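# Check each substring separately; a bare non-empty string literal is always truthy.
if any(name in args.model_name_or_path for name in ("deepseekcoder", "starcoder")):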
    model.config.pad_token_id = tokenizer.pad_token_id
    model.resize_token_embeddings(len(tokenizer))

# Instantiate optimizer
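if args.optimizer.lower() == "adamw":
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
else:
    # Fail early instead of hitting a NameError on `optimizer` below.
    raise ValueError(f"Unsupported optimizer: {args.optimizer}")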

# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(0.06 * len(train_dataloader) * args.num_epochs),  # warm up over the first 6% of steps
    num_training_steps=(len(train_dataloader) * args.num_epochs)
)

total_steps = 0
best_validation_loss = float("inf")
peak_memory = 0
if use_cuda:
    model.cuda()

# Training
start_time = time.time()
for epoch in range(args.num_epochs):
    model.train()
    train_loss = 0.0

    progress_bar_train = tqdm(
        total=len(train_dataloader), 
        desc=f"Training epoch {epoch + 1}",
        position=0,
        mininterval=1,
        leave=True
    )

    for step, batch in enumerate(train_dataloader):
        total_steps += 1
        batch = {k: v.cuda() for k, v in batch.items()} if use_cuda else batch
        outputs = model(**batch)
        loss = outputs.loss
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        if step % 5 == 0:
            progress_bar_train.set_postfix({"loss": loss.item()})
            progress_bar_train.update(5)

        current_memory = torch.cuda.max_memory_allocated()
        if current_memory > peak_memory:
            peak_memory = current_memory

    progress_bar_train.close()

    avg_train_loss = train_loss / len(train_dataloader)
    logger.info(f"Epoch {epoch + 1} - Training loss: {avg_train_loss}")
    print(f"Epoch {epoch + 1} - Training loss: {avg_train_loss}")

    # Validation
    model.eval()
    total_validation_loss = 0.0

    progress_bar_valid = tqdm(
        total=len(valid_dataloader),
        desc=f"Validation epoch {epoch + 1}",
        position=0,
        mininterval=1,
        leave=True
    )

    for step, batch in enumerate(valid_dataloader):
        batch = {k: v.cuda() for k, v in batch.items()} if use_cuda else batch
        with torch.no_grad():
            outputs = model(**batch)
            loss = outputs.loss
            total_validation_loss += loss.item()

        if step % 5 == 0:
            progress_bar_valid.update(5)
    progress_bar_valid.close()

    avg_validation_loss = total_validation_loss / len(valid_dataloader)
    if avg_validation_loss < best_validation_loss:
        best_validation_loss = avg_validation_loss
        best_model_path = os.path.join(args.output_dir, model_name, f"prefix_tuning_seed_{args.seed}", "best_model")
        os.makedirs(best_model_path, exist_ok=True)
        model.save_pretrained(best_model_path)

    logger.info(f"Epoch {epoch + 1} - Validation loss: {avg_validation_loss}")
    print(f"Epoch {epoch + 1} - Validation loss: {avg_validation_loss}")
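
The model-name check in a script like this is easy to get wrong: a condition of the form `"a" or "b" in name` is always true, because the non-empty string literal `"a"` is truthy on its own. The following is a minimal sketch of a per-substring check. The helper name `matches_any` and the example model ids are illustrative assumptions, not part of the original script, and the case-insensitive comparison is a design choice the original does not make.

```python
def matches_any(model_name: str, keywords: tuple[str, ...]) -> bool:
    """Return True if any keyword occurs in model_name (case-insensitive)."""
    lowered = model_name.lower()
    return any(keyword in lowered for keyword in keywords)


# Hypothetical model ids, used only for illustration.
KEYWORDS = ("deepseekcoder", "starcoder")
print(matches_any("bigcode/starcoderbase-1b", KEYWORDS))
print(matches_any("codellama/CodeLlama-7b-hf", KEYWORDS))

# The pitfall: `or` returns the first truthy operand, here the string literal,
# so the whole expression is truthy regardless of the model name.
print(bool("deepseekcoder" or "starcoder" in "codellama/CodeLlama-7b-hf"))
```

With a check like this, the padding-token and embedding-resize steps would run only for the named model families rather than for every model.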

Expected behavior

I want to fine-tune 8-bit codellama-7b-hf with prefix tuning.

llCurious commented 2 months ago

Same issue. Any progress here?

BenjaminBossan commented 2 months ago

Thanks for reporting. Yes, this is a known issue caused by the kv-cache changes made to some model architectures in recent transformers versions, which affect prefix tuning. We have a long discussion in #869, which also mentions some workarounds.

If this is an option for you, you could also try older transformers versions (e.g. 4.36.0 or older should work).
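
To make the version suggestion concrete, here is a minimal sketch of a guard that a training script could run before applying prefix tuning. The threshold is taken directly from the comment above ("4.36.0 or older should work") and is not independently verified. The helper names `parse_version` and `predates_cache_change` are hypothetical, and the parser only handles the leading numeric part of a version string.

```python
def parse_version(version: str) -> tuple[int, ...]:
    """Parse the leading numeric 'X.Y.Z' part of a version string."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at suffixes such as 'dev0'
        parts.append(int(digits))
    return tuple(parts)


# Threshold quoted from the thread, not independently verified.
LAST_REPORTED_WORKING = (4, 36, 0)


def predates_cache_change(version: str) -> bool:
    """True if `version` is at or below the version reported to work."""
    return parse_version(version) <= LAST_REPORTED_WORKING


print(predates_cache_change("4.36.0"))
print(predates_cache_change("4.44.2"))
```

In a real script, the argument would be `transformers.__version__`, and an older release can be pinned with `pip install "transformers<=4.36.0"`.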

At the moment, I'm still figuring out how we can best make these recent transformers changes compatible with prefix-tuning, but unfortunately it's not an easy thing to fix.

llCurious commented 2 months ago

Thanks for your quick reply, @BenjaminBossan. The workaround indeed works in my case. However, I found that the loss differs widely between prefix-tuning and p-tuning on the same model and dataset.

For example, on Qwen2-1.5B and alpaca-cleaned, prefix-tuning yields ~10, while p-tuning yields ~1. Do you have any ideas on this phenomenon?

BenjaminBossan commented 2 months ago

> For example, on Qwen2-1.5B and alpaca-cleaned, prefix-tuning yields ~10, while p-tuning yields ~1. Do you have any ideas on this phenomenon?

Sorry, I don't have a lot of practical experience with these prompt tuning methods; maybe others can give some advice. Since the difference is so large, I would not exclude the possibility that there is a bug. Do you see the training loss decrease? Did you try varying the hyper-parameters?

It could be worth trying older transformers versions instead of the workaround. If you see much better scores there, it is very likely that there is a bug in the workaround.
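
One way to act on these questions is to compare the logged losses against a reference point. A model that spreads probability uniformly over a vocabulary of size V has a per-token cross-entropy of ln(V), so a loss close to that value suggests the model is doing little better than guessing. The sketch below is only illustrative: the vocabulary size of 151,936 is an assumption for Qwen2-1.5B (verify with `model.config.vocab_size`), the loss values are made up, and the helper names are hypothetical.

```python
import math


def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy (in nats) of a uniform guess over vocab_size tokens."""
    return math.log(vocab_size)


def loss_is_decreasing(losses: list[float], window: int = 3) -> bool:
    """Compare the mean of the first and last `window` logged losses."""
    if len(losses) < 2 * window:
        raise ValueError("need at least 2 * window loss values")
    head = sum(losses[:window]) / window
    tail = sum(losses[-window:]) / window
    return tail < head


# Assumed vocabulary size; verify with model.config.vocab_size.
ASSUMED_VOCAB_SIZE = 151_936
baseline = uniform_baseline_loss(ASSUMED_VOCAB_SIZE)
print(f"uniform-guess loss: {baseline:.2f}")

# Made-up loss values, for illustration only.
example_losses = [10.4, 10.3, 10.1, 10.2, 10.0, 9.9]
print(loss_is_decreasing(example_losses))
```

If the prefix-tuning loss stays near the uniform baseline while p-tuning falls well below it, that would be consistent with the prefix not being applied correctly, though it does not by itself prove a bug.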

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.