huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

How to correctly use Prefix Tuning? #869

Closed: Vincent-Li-9701 closed this issue 11 months ago

Vincent-Li-9701 commented 1 year ago

System Info

peft 0.5.0 transformers 4.32.0

Who can help?

No response

Reproduction

from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model, prepare_model_for_int8_training

model = AutoModelForSeq2SeqLM.from_pretrained('bigscience/T0pp', load_in_8bit=True)
model = prepare_model_for_int8_training(model)
config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=100,
    token_dim=model.config.hidden_size,
    num_transformer_submodules=1,
    num_attention_heads=model.config.num_heads,
    num_layers=model.config.num_layers,
    encoder_hidden_size=1792,
)
model = get_peft_model(model, config)

Expected behavior

I'm assuming num_layers, num_attention_heads, and token_dim need to match the base model. In the sample, num_transformer_submodules is 1, but an encoder-decoder model has two transformer stacks, right? Should this be 2?

When I run the code above, I get:

File "/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 551, in forward
position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
RuntimeError: The size of tensor a (3) must match the size of tensor b (103) at non-singleton dimension 3

When I print the shapes of position_bias and mask, mask has 100 more tokens than position_bias. It seems that the decoder side is also taking in the prefix embeddings.

Vincent-Li-9701 commented 1 year ago

@pacman100 Would you mind taking a look?

LesterGong commented 1 year ago

Have you solved this problem?

Vincent-Li-9701 commented 1 year ago

@WhoopeeHg No, unfortunately. I decided to skip prefix tuning, since I found it less effective than P-Tuning or LoRA on my dataset.

Vincent-Li-9701 commented 1 year ago

Based on my discussion with others, the problem seems to surface when the model is loaded with 8-bit quantization.

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

vikram71198 commented 7 months ago

@pacman100 @BenjaminBossan this issue occurs even without 8-bit quantization and basically renders HF's integration of Prefix Tuning useless until it is fixed.

BenjaminBossan commented 7 months ago

@vikram71198 Could you please give more details? Ideally some code to reproduce the error and the PEFT version you're using.

vikram71198 commented 7 months ago

My environment:

transformers==4.38.1, peft==0.8.2

This is the code snippet I'm using, which has been adapted from here:

import os
import random
import numpy as np
import pandas as pd
import torch
import datasets
from torch.utils.data import Dataset, DataLoader
import transformers
import peft
import trl
import json
from pprint import pprint
import flash_attn
import accelerate
from transformers import BitsAndBytesConfig
from tqdm import tqdm
import mlflow

max_length = 500
lr = 2e-4
num_epochs = 5
batch_size = 1
num_virtual_tokens = 30
random_seed = 42

model_name = "teknium/OpenHermes-2.5-Mistral-7B"
text_column = "Transcript"
label_column = "RFC"

def get_prompt(transcript: str) -> str:
    prompt = """Transcript:

{transcript}

---

I want you to act as a transcript analysis expert. I have provided you with a transcript between agent & customer above and your goal is to summarize the reason why the customer calls up the agent. If there is no discernible reason, output "No reason identified".

Answer:"""

    return prompt.format(transcript = transcript)

def get_mistral_prompt(transcript: str, system_message : str = "") -> str:
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": get_prompt(transcript)}
    ]
    return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer.use_default_system_prompt = False

def preprocess_function(examples):
    batch_size = len(examples[text_column])
    # inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
    inputs = [get_mistral_prompt(x) for x in examples[text_column]]
    targets = [str(x) for x in examples[label_column]]
    model_inputs = tokenizer(inputs)
    labels = tokenizer(targets)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i] + [tokenizer.pad_token_id]
        # print(i, sample_input_ids, label_input_ids)
        model_inputs["input_ids"][i] = sample_input_ids + label_input_ids
        labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids
        model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])
    # print(model_inputs)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i]

        #padding if length of this example is smaller than max_seq_length
        model_inputs["input_ids"][i] = [tokenizer.pad_token_id] * (
            max_length - len(sample_input_ids)
        ) + sample_input_ids
        model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[
            "attention_mask"
        ][i]
        labels["input_ids"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids

        #if sequence length of this example exceeds max_seq_length, we're performing truncation here
        model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
        labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

transcripts = ["""
Agent: Hello may I know why you're calling today?
Customer: Yeah, I'm calling to cancel my insurance policy
Agent: May I ask you why you chose to do so?
Customer: Yeah, I'm just not interested in the product you have to offer anymore
Agent: That is totally understandable""",
"""
Agent: Hello
Customer: Hello
Agent: Have a nice day
Customer: Thanks"""]

rfcs = ["Customer called to cancel their insurance policy.", "No reason identified."]

import pandas as pd
df = pd.DataFrame({"Transcript": transcripts, "RFC": rfcs})

from datasets import Dataset, DatasetDict
rfc_dataset = DatasetDict()
rfc_dataset["train"] = Dataset.from_pandas(df)

formatted_dataset = rfc_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=rfc_dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

formatted_dataset = formatted_dataset.shuffle(seed = random_seed)

train_dataset = formatted_dataset["train"]

from transformers import default_data_collator

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)

from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType, PeftType
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=num_virtual_tokens, prefix_projection = False)

model = get_peft_model(model, peft_config)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

run_name = "prefix-tuning-v1"

with mlflow.start_run(run_name = run_name):
    for epoch in tqdm(range(num_epochs), total = num_epochs):

        model.train()
        total_loss = 0

        for step, batch in enumerate(tqdm(train_dataloader)):
            batch = {k: v.to(torch.device("cuda")) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        print(f"{epoch=}: {train_ppl=} {train_epoch_loss=}")

        # log_param() takes no step argument; per-epoch values belong in log_metric()
        mlflow.log_metric("epoch", epoch + 1, step = epoch)
        mlflow.log_metric("train_loss", train_epoch_loss.item(), step = epoch)
        mlflow.log_metric("train perplexity", train_ppl.item(), step = epoch)

    mlflow.end_run()

And this is the exact error message I'm seeing:

RuntimeError: The size of tensor a (530) must match the size of tensor b (500) at non-singleton dimension 3

The difference between 530 and 500 is exactly num_virtual_tokens (30 in this case), and this happens every time.
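Both this error and the T5 error in the original report (103 vs 3, with 100 virtual tokens) differ by exactly the prefix length. A small helper like the one below can make that check explicit when reading tracebacks. It is hypothetical (not part of PEFT or transformers), and its second rule, flagging a difference equal to the number of real tokens, is an extra assumption added for illustration:

```python
from typing import Optional


def classify_mismatch(a: int, b: int, num_virtual_tokens: int,
                      query_len: Optional[int] = None) -> str:
    """Hypothetical debugging helper (not part of PEFT or transformers).

    `a` and `b` are the two sizes from an error such as
    "The size of tensor a (530) must match the size of tensor b (500)".
    """
    diff = abs(a - b)
    if diff == 0:
        return "match"
    if diff == num_virtual_tokens:
        # One side counts the virtual prefix tokens, the other does not.
        return "prefix missing on one side"
    if query_len is not None and diff == query_len:
        # One side counts the current (non-prefix) tokens twice.
        return "current tokens counted twice"
    return "unexplained"


print(classify_mismatch(530, 500, num_virtual_tokens=30))  # prefix missing on one side
print(classify_mismatch(103, 3, num_virtual_tokens=100))   # prefix missing on one side
```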

There was some investigation done by @Vincent-Li-9701 here, but the bug is still unresolved.

@BenjaminBossan please let me know if you need anything else.

BenjaminBossan commented 6 months ago

Thanks for providing this reproducer. I could condense the code to the following:

import torch
from transformers import MistralConfig, MistralForCausalLM
from peft import PrefixTuningConfig, get_peft_model

# using small mistral for testing, real mistral would also work
model_config = MistralConfig(
    vocab_size=32000,
    hidden_size=512,
    max_position_embeddings=32768,
    num_attention_heads=16,
    num_hidden_layers=8,
    num_key_value_heads=4,
)
model = MistralForCausalLM(model_config)

config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
model = get_peft_model(model, config)
model.config.use_cache = False

input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]])
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]])
outputs = model(input_ids=input_ids, attention_mask=attention_mask)

which gives the same error:

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
  File "...site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "...site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vinh/work/forks/peft/src/peft/peft_model.py", line 1126, in forward
    return self.base_model(
  File "...site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "...site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "...site-packages/transformers/models/mistral/modeling_mistral.py", line 1157, in forward
    outputs = self.model(
  File "...site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "...site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "...site-packages/transformers/models/mistral/modeling_mistral.py", line 1004, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  File "...site-packages/transformers/modeling_attn_mask_utils.py", line 398, in _prepare_4d_causal_attention_mask_for_sdpa
    expanded_4d_mask = attn_mask_converter.to_4d(
  File "...site-packages/transformers/modeling_attn_mask_utils.py", line 137, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (33) must match the size of tensor b (3) at non-singleton dimension 3

Unfortunately, even after some digging, I haven't been able to figure out how to fix this issue. I've asked around; let's see if someone has a solution.

vikram71198 commented 6 months ago

One (important) difference between the implementation of Prefix Tuning and the other PEFT prompt-learning techniques is evident in the forward() method of PeftModelForCausalLM here:

if peft_config.peft_type == PeftType.PREFIX_TUNING:
    past_key_values = self.get_prompt(batch_size)
    return self.base_model(
        input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
    )
else:
    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    # concat prompt labels
    if labels is not None:
        prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
        kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
    prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
    prompts = prompts.to(inputs_embeds.dtype)
    inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
    return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

It seems like past_key_values is passed as a separate argument for Prefix Tuning, while it's prepended to the input otherwise.

Could this be causing the issue?

I picked up on this from #870.
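To make the difference concrete, here is a sketch of the (query length, key length) each attention layer would see under the two code paths. It assumes that prefix tuning contributes the virtual tokens only as cached keys/values, while the other prompt-learning methods concatenate them with inputs_embeds; the function name and the "prefix"/"prepend" labels are hypothetical:

```python
from typing import Tuple


def attention_lengths(method: str, seq_len: int, num_virtual_tokens: int) -> Tuple[int, int]:
    """Sketch: (query_len, key_len) per attention layer for a single forward pass."""
    if method == "prefix":
        # Virtual tokens exist only as past_key_values: no queries, but extra keys.
        return seq_len, num_virtual_tokens + seq_len
    if method == "prepend":
        # Virtual tokens are ordinary positions in inputs_embeds: queries and keys both grow.
        total = num_virtual_tokens + seq_len
        return total, total
    raise ValueError(f"unknown method: {method}")


print(attention_lengths("prefix", seq_len=500, num_virtual_tokens=30))   # (500, 530)
print(attention_lengths("prepend", seq_len=500, num_virtual_tokens=30))  # (530, 530)
```

Only in the prefix case do the query and key lengths differ, so mask-building code that derives the key length from the input alone, without accounting for the cached prefix, would be off by num_virtual_tokens, which would match the 530 vs 500 error reported earlier.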

BenjaminBossan commented 6 months ago

It seems like past_key_values is passed as a separate argument for Prefix Tuning, while it's prepended to the input otherwise.

It could be related, but I don't know enough about these techniques to be sure. Whatever the cause is, it has to depend on the model architecture, as some models work and others don't. I modified the above snippet to check multiple models like so:

import torch
from transformers import MistralConfig, MistralForCausalLM, AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model

def get_model(name):
    if name == "mistral":
        model_config = MistralConfig(
            vocab_size=32000,
            hidden_size=512,
            max_position_embeddings=32768,
            num_attention_heads=16,
            num_hidden_layers=8,
            num_key_value_heads=4,
        )
        return MistralForCausalLM(model_config)

    return AutoModelForCausalLM.from_pretrained(name)

for name in ("gpt2", "facebook/opt-125m", "bigscience/bloomz-560m",  "HuggingFaceH4/tiny-random-LlamaForCausalLM", "mistral"):
    config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
    model = get_model(name)
    model = get_peft_model(model, config)
    model.config.use_cache = False

    input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]])
    attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]])
    try:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        print(f"PASS: model {name} passed")
    except Exception as e:
        print(f"FAIL: model {name} failed with {e}")

and I got:

PASS: model gpt2 passed
PASS: model facebook/opt-125m passed
PASS: model bigscience/bloomz-560m passed
FAIL: model HuggingFaceH4/tiny-random-LlamaForCausalLM failed with 'tuple' object has no attribute 'update'
FAIL: model mistral failed with The size of tensor a (33) must match the size of tensor b (3) at non-singleton dimension 3

with transformers 4.38.1.

Regarding the Llama error, that could be related to the KV cache, since I get a different error with older transformers versions: with 4.36 and 4.37, I got the same error as for Mistral. I'm not sure whether the Mistral error is also related to that.

Pinging @younesbelkada @pacman100 for help.

vikram71198 commented 6 months ago

I slightly modified your script to:

import torch
from transformers import MistralConfig, MistralForCausalLM, AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model

def get_model(name):
    if name == "mistral":
        model_config = MistralConfig(
            vocab_size=32000,
            hidden_size=512,
            max_position_embeddings=32768,
            num_attention_heads=16,
            num_hidden_layers=8,
            num_key_value_heads=4,
        )
        return MistralForCausalLM(model_config)

    return AutoModelForCausalLM.from_pretrained(name)

for name in ("gpt2", "facebook/opt-125m", "bigscience/bloomz-560m",  "HuggingFaceH4/tiny-random-LlamaForCausalLM", "mistral", "teknium/OpenHermes-2.5-Mistral-7B", "meta-llama/Llama-2-7b-chat-hf"):
    config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
    model = get_model(name)
    model = get_peft_model(model, config)
    model = model.to(torch.device("cuda"))
    model.config.use_cache = False

    input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(torch.device("cuda"))
    attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(torch.device("cuda"))

    try:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        print(f"PASS: model {name} passed")
    except Exception as e:
        print(f"FAIL: model {name} failed with {e}")
    finally:
        torch.cuda.empty_cache()
        model.to(torch.device("cpu"))
        del model

With the script above, this is what I see with transformers==4.34.1:

PASS: model gpt2 passed
PASS: model facebook/opt-125m passed
PASS: model bigscience/bloomz-560m passed
PASS: model HuggingFaceH4/tiny-random-LlamaForCausalLM passed
FAIL: model mistral failed with Sizes of tensors must match except in dimension 2. Expected size 16 but got size 4 for tensor number 1 in the list.
FAIL: model teknium/OpenHermes-2.5-Mistral-7B failed with Sizes of tensors must match except in dimension 2. Expected size 32 but got size 8 for tensor number 1 in the list.
PASS: model meta-llama/Llama-2-7b-chat-hf passed

So, it seems like LlamaForCausalLM is working with 4.34.1.
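The two Mistral failures look like a head-count mismatch rather than a sequence-length one: 16 vs 4 (the small test config) and 32 vs 8 (Mistral-7B) are exactly num_attention_heads vs num_key_value_heads. Mistral uses grouped-query attention, whereas Llama-2-7B, which passes, uses as many key/value heads as query heads. A sketch of the per-layer prefix key/value shape, under the assumption that the prefix must match the model's key/value heads rather than its query heads (the helper name is hypothetical):

```python
from typing import Optional, Tuple


def prefix_kv_shape(batch_size: int, num_virtual_tokens: int, hidden_size: int,
                    num_attention_heads: int,
                    num_key_value_heads: Optional[int] = None) -> Tuple[int, int, int, int]:
    """Sketch: expected shape of one layer's prefix key (or value) tensor."""
    head_dim = hidden_size // num_attention_heads
    # With grouped-query attention, keys/values have fewer heads than queries.
    kv_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
    return (batch_size, kv_heads, num_virtual_tokens, head_dim)


# Small Mistral config used in this thread: 16 query heads, 4 key/value heads.
print(prefix_kv_shape(2, 30, hidden_size=512, num_attention_heads=16, num_key_value_heads=4))
# (2, 4, 30, 32)
```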

But clearly something is wrong. I hope we can find and fix this, because it locks the community out of using this fine-tuning method on many models that are more relevant now.

vikram71198 commented 6 months ago

So, I downgraded to transformers == 4.34.1 and Prefix Tuning seemed to run bug free with meta-llama/Llama-2-7b-chat-hf.

But after fine-tuning, I mostly saw gibberish outputs from the fine-tuned model. I'm about 95% sure there's nothing wrong with my implementation, and I'm starting to think that this bug in Prefix Tuning also creeps in silently in cases where we don't get the aforementioned RuntimeError.

So it seems to me that Prefix Tuning currently doesn't work at all.

Hoping someone has a fix for this.

@younesbelkada @pacman100 @BenjaminBossan

vikram71198 commented 5 months ago

Any updates on this?

@BenjaminBossan @younesbelkada @pacman100

GGG-c commented 2 months ago

same

BenjaminBossan commented 2 months ago

Good and bad news. With the latest transformers (4.43.2) and PEFT version (0.12.0), the tensor size mismatch error is no longer occurring. However, there is a new error:

'tuple' object has no attribute 'get_seq_length'

This is because the past_key_values provided by prefix tuning is a tuple of tensors, but transformers assumes it's either a Cache instance or None. I have asked internally for clarification about the expected type and will report back if I get an answer.
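The difference between the two formats can be illustrated without torch. In the sketch below, ToyTensor and ToyCache are hypothetical stand-ins (one only carries a .shape, the other only has get_seq_length()). Real Cache classes such as transformers' DynamicCache do provide get_seq_length(), while the legacy format is a tuple with one (key, value) pair per layer, keys shaped (batch, heads, seq_len, head_dim):

```python
from typing import Any, Tuple


class ToyTensor:
    """Hypothetical stand-in for a tensor: only carries a shape."""
    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


class ToyCache:
    """Hypothetical stand-in for a Cache object."""
    def __init__(self, seq_length: int) -> None:
        self._seq_length = seq_length

    def get_seq_length(self) -> int:
        return self._seq_length


def cached_seq_length(past_key_values: Any) -> int:
    """Sketch: number of cached positions for either cache format."""
    if past_key_values is None:
        return 0
    if hasattr(past_key_values, "get_seq_length"):
        # Cache-style object: ask it directly.
        return past_key_values.get_seq_length()
    # Legacy tuple: layer 0, key tensor, sequence dimension (-2).
    return past_key_values[0][0].shape[-2]


key = value = ToyTensor((4, 8, 10, 128))
legacy = ((key, value), (key, value))   # two layers, 10 virtual tokens each
print(cached_seq_length(legacy))        # 10
print(cached_seq_length(ToyCache(10)))  # 10
```

Calling legacy.get_seq_length() directly instead raises AttributeError, which is the same failure as the 'tuple' object error above.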

jrrw10 commented 1 month ago

This is because past_key_values provided by prefix tuning is a tuple of tensors but transformers assumes it's either a Cache instance or None

I tried converting past_key_values to a DynamicCache in PeftModel.get_prompt(), like so:

from transformers import DynamicCache
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

Calling trainer.train() on a Mistral model now results in a shape mismatch during attention computation:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    471         is_causal = True if causal_mask is None and q_len > 1 else False
    472 
--> 473         attn_output = torch.nn.functional.scaled_dot_product_attention(
    474             query_states,
    475             key_states,

RuntimeError: The expanded size of the tensor (2078) must match the existing size (1054) at non-singleton dimension 3.  Target sizes: [4, 32, 1024, 2078].  Tensor sizes: [4, 1, 1024, 1054]

Any idea why this is? @BenjaminBossan

BenjaminBossan commented 1 month ago

@jrrw10 What version of transformers and PEFT are you using? Do you still get the same error when upgrading to the latest versions from main for the two packages?

jrrw10 commented 1 month ago

@BenjaminBossan Yes, I get this error with the latest version of main for both PEFT and transformers.

transformers==4.44.0.dev0 peft==0.12.1.dev0

BenjaminBossan commented 1 month ago

Okay, strange. I tried with transformers commit 51ab25e2932da15511ced35bcbdfa92d25c4794c and PEFT commit 269aba5303216984d03a8be2e6ef270a9130550f. Following your suggestion, I added

            from transformers import DynamicCache
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

before this line. Then I ran the example shown above and I get:

PASS: model gpt2 passed
PASS: model facebook/opt-125m passed
PASS: model bigscience/bloomz-560m passed
/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:546: UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values.
  warnings.warn(
PASS: model HuggingFaceH4/tiny-random-LlamaForCausalLM passed
PASS: model llamafactory/tiny-random-Llama-3 passed
PASS: model mistral passed
PASS: model peft-internal-testing/tiny-dummy-qwen2 passed

(Note that model.config.use_cache = False is unnecessary in the snippet.)

jrrw10 commented 1 month ago

@BenjaminBossan Thank you for testing this out. I get the same output with the above testing for a forward pass. While the forward pass tests are successful, calling trainer.train() is what is resulting in the shape mismatch for me.

Debugging prints in transformers/models/mistral/modeling_mistral.py, in MistralSdpaAttention.forward(), reveal the following shapes of query_states, key_states, value_states, and causal_mask:

...
Before attention computation - Query States: torch.Size([4, 32, 1024, 128]), Key States: torch.Size([4, 32, 1054, 128]), Value States: torch.Size([4, 32, 1054, 128]), mask: torch.Size([4, 1, 1024, 1054])
Before attention computation - Query States: torch.Size([4, 32, 1024, 128]), Key States: torch.Size([4, 32, 2078, 128]), Value States: torch.Size([4, 32, 2078, 128]), mask: torch.Size([4, 1, 1024, 1054])

Which then hits the mismatch:

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    475         is_causal = True if causal_mask is None and q_len > 1 else False
    476 
--> 477         attn_output = torch.nn.functional.scaled_dot_product_attention(
    478             query_states,
    479             key_states,
    480            dropout_p=self.attention_dropout if self.training else 0.0,
    481            is_causal=is_causal,
    482 )

RuntimeError: The size of tensor a (2078) must match the size of tensor b (1054) at non-singleton dimension 3

It seems the mismatch occurs because the key length (2078) does not match the key length in the mask (1054).

Do you have any insights on why this discrepancy might be occurring during trainer.train() and how to address this? Thanks again for the help!

jrrw10 commented 1 month ago

It seems the mismatch occurs because the key length (2078) does not match the key length in the mask (1054).

I was experimenting in the MistralSdpaAttention class forward() method after discovering the attention mask was the wrong size.

I changed causal_mask to:

causal_mask = attention_mask
if attention_mask is not None:
    # Ensure the mask matches key length
    if causal_mask.shape[-1] != key_states.shape[-2]:
        if causal_mask.shape[-1] < key_states.shape[-2]:
            padding = key_states.shape[-2] - causal_mask.shape[-1]
            causal_mask = torch.nn.functional.pad(causal_mask, (0, padding))
        else:
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]

since it was incorrectly sized before. This allowed training to start; however, the loss barely moves and the validation loss does not change at all:

Step  Training Loss  Validation Loss
10    9.897600       9.929580
20    9.884000       9.929580
30    9.878300       9.929580
40    9.895400       9.929580
50    9.874400       9.929580

Similarly, if the model is loaded with attn_implementation="flash_attention_2", i.e.:

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    attn_implementation="flash_attention_2",
    device_map='cuda',
    torch_dtype=torch.bfloat16,
)

I did not have to edit anything in the corresponding attention class, and training started, but with an almost identical constant loss (using a typical learning rate and other hyperparameters). Any help or update is much appreciated.
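One detail about the padding workaround above: torch.nn.functional.pad pads with 0 by default, and in an additive float attention mask 0 means "no penalty", so the appended key positions are not masked at all, not even causally. A pure-Python sketch of that effect, with lists standing in for tensors (the helper names are hypothetical, and -inf stands in for the large negative value such as torch.finfo(dtype).min that real masks use):

```python
from typing import List

MASKED = float("-inf")  # real masks use a large negative value such as torch.finfo(dtype).min


def causal_additive_mask(q_len: int, k_len: int) -> List[List[float]]:
    """Sketch: query i may see every key up to position (k_len - q_len + i)."""
    offset = k_len - q_len  # e.g. number of cached prefix positions
    return [[0.0 if j <= offset + i else MASKED for j in range(k_len)] for i in range(q_len)]


def pad_keys_with_zeros(mask: List[List[float]], new_k_len: int) -> List[List[float]]:
    """Mimics F.pad(mask, (0, padding)) with the default fill value of 0."""
    return [row + [0.0] * (new_k_len - len(row)) for row in mask]


mask = causal_additive_mask(q_len=2, k_len=3)  # 1 prefix key + 2 real tokens
padded = pad_keys_with_zeros(mask, new_k_len=5)
print(mask[0])    # [0.0, 0.0, -inf]
print(padded[0])  # [0.0, 0.0, -inf, 0.0, 0.0]
```

Query 0 is still blocked from key 2 but can attend to both appended keys, which is consistent with the later observation in this thread that resizing the causal mask is not the correct solution.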

BenjaminBossan commented 1 month ago

Thanks @jrrw10 for this continued investigation. I tried training mistral and did not run into the errors you reported, only the fix with the DynamicCache was required. The loss also improved nicely. Here is the code that I used:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from peft import LoraConfig, PrefixTuningConfig, get_peft_model

model_id = "mistralai/Mistral-7B-v0.1"
# "teknium/OpenHermes-2.5-Mistral-7B" works as well
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=0,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
model = get_peft_model(model, config)

def process(samples):
    tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
    return tokenized

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(process, batched=True)

trainer = Trainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        num_train_epochs=1,
        per_device_train_batch_size=4,
        bf16=True,
        learning_rate=3e-4,
        logging_steps=10,
        output_dir="/tmp/peft/869",
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

This is the loss that I get:

Step  Training loss
10 8.9031
20 8.4677
30 8.4124
40 8.231
50 7.9302
60 7.996
70 8.1211
80 7.7546
90 8.0796
100 7.5609
110 7.67
120 7.3477
130 7.5419
140 7.4619
150 7.1993
160 7.4286
170 7.0095
180 7.1371
190 7.2012
200 6.9664
210 6.9731
220 7.0729
230 7.1143
240 7.1078
250 6.9686
260 6.8469
270 6.7001
280 6.7319
290 7.0498
300 6.8004
310 6.755
320 6.8177
330 6.4536
340 6.7542
350 6.6656
360 6.7039
370 6.8224
380 6.7475
390 6.6785
400 6.6081
410 6.6096
420 6.6462
430 6.9221
440 6.7322
450 6.6007
460 6.8302
470 6.6958
480 6.758
490 6.6558
500 6.4734
510 6.3721
520 6.9516
530 6.4559
540 6.3372
550 6.4975
560 6.484
570 6.4277
580 6.5164
590 6.6368
600 6.2812
610 6.4681
620 6.5265

Any idea what the difference could be?

jrrw10 commented 1 month ago

@BenjaminBossan I've found the cause of the issue(s) I was experiencing.

The difference between our implementations is that I am using gradient checkpointing like so:

model = AutoModelForCausalLM.from_pretrained(
   model_id,
   device_map=0,
   torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()

This should be reproducible for you in your example above if you add that line. Here's the error:

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    475         is_causal = True if causal_mask is None and q_len > 1 else False
    476 
--> 477         attn_output = torch.nn.functional.scaled_dot_product_attention(
    478             query_states,
    479             key_states,

RuntimeError: The size of tensor a (72) must match the size of tensor b (41) at non-singleton dimension 3

This is exactly the error I was experiencing before. However, when I manually fixed the attention mask size mismatch as described earlier, I also ran into the constant-loss problem shown above. I'm having trouble reproducing that loss behavior now; maybe it was fixed in a recent commit, but I'm not sure.

In conclusion - I got my implementation to work by adding the DynamicCache fix, and gradient checkpointing currently breaks Prefix Tuning (attention mask size mismatch). Thank you for the help sir!

BenjaminBossan commented 1 month ago

Thanks for digging deeper. Indeed, with gradient checkpointing, there is an issue. IIUC, adjusting the size of the causal mask is not the correct solution though, which could explain the bad losses you see. I printed the layer index, as well as the shapes of key_states and value_states once before and once after they get updated from the cache in this line:

https://github.com/huggingface/transformers/blob/e0d82534cc95b582ab072c1bbc060852ba7f9d51/src/transformers/models/mistral/modeling_mistral.py#L457

0 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
1 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
2 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
3 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
4 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
5 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
6 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
7 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
8 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
9 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
10 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
11 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
12 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
13 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
14 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
15 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
16 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
17 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
18 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
19 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
20 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
21 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
22 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
23 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
24 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
25 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
26 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
27 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
28 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
29 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
30 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
31 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
31 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 72, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 72, 128])

The shapes are all correct except for the last line: the sequence length is 31 before the cache is applied and 41 after, which is expected since we add 10 virtual tokens.

Now let's check the last line. It's layer 31 again, i.e. it looks like we're in the gradient checkpointing recomputation phase. Here the length after the update is 72, which is wrong. This is 31+41, i.e. the updated cache is appended on top of the states, which IMO is not the correct way of handling this. It should be the same as the first time we visited this layer, i.e. 31+10 = 41. I wonder if gradient checkpointing + cache is broken in general or if we're using it incorrectly.
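The arithmetic above can be reproduced with a shape-only toy model. This is a sketch that assumes, as suspected in this comment, that the checkpointed recomputation reuses the same cache object, and that update() appends new states along the sequence axis the way transformers' DynamicCache.update does. ToyAppendCache and seq_lengths_seen are hypothetical names, not library APIs; no real tensors are involved.

```python
from typing import Dict, List, Tuple

# (batch, num_heads, seq_len, head_dim), as in the printed shapes above
Shape = Tuple[int, int, int, int]


def cat_seq(a: Shape, b: Shape) -> Shape:
    """Shape of concatenating two key/value tensors along the sequence axis (dim -2)."""
    if (a[0], a[1], a[3]) != (b[0], b[1], b[3]):
        raise ValueError(f"cannot concatenate {a} and {b} along the sequence axis")
    return (a[0], a[1], a[2] + b[2], a[3])


class ToyAppendCache:
    """Shape-only stand-in for a cache whose update() appends to what is already
    stored (assumed to mirror DynamicCache.update; hypothetical class)."""

    def __init__(self, prefix: Dict[int, Shape]):
        # layer_idx -> cached key shape; initially just the virtual (prefix) tokens
        self.layers = dict(prefix)

    def update(self, layer_idx: int, states: Shape) -> Shape:
        self.layers[layer_idx] = cat_seq(self.layers[layer_idx], states)
        return self.layers[layer_idx]


def seq_lengths_seen(gradient_checkpointing: bool, layer: int = 31, n_virtual: int = 10) -> List[int]:
    """Key lengths the attention in `layer` sees after each cache update (hypothetical helper)."""
    states: Shape = (4, 8, 31, 128)  # 31 real tokens
    cache = ToyAppendCache({layer: (4, 8, n_virtual, 128)})
    lengths = [cache.update(layer, states)[2]]  # normal forward: 10 + 31 = 41
    if gradient_checkpointing:
        # Backward re-runs the layer's forward; assumed to reuse the same cache,
        # which already holds 41 positions after the first pass.
        lengths.append(cache.update(layer, states)[2])  # 41 + 31 = 72
    return lengths


print(seq_lengths_seen(False))  # [41]
print(seq_lengths_seen(True))   # [41, 72]
```

Without recomputation each layer is visited once and the prefix is counted once (41); with recomputation the first pass's states are appended a second time (72), which matches the last printed line.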

BenjaminBossan commented 1 month ago

@jrrw10 FYI: https://github.com/huggingface/peft/issues/1962#issuecomment-2273776263. So the workaround with creating a cache object to hold the past_key_values is a bit of a dead end.

jrrw10 commented 1 month ago

@BenjaminBossan Thanks for the detailed explanation of this.

Just to confirm: Prefix Tuning does currently train successfully with the DynamicCache workaround and gradient checkpointing disabled. This is sufficient for me, and hopefully for other users, until the past_key_values handling is rewritten. Hopefully this discussion helped!

BenjaminBossan commented 1 month ago

Great that it works for you. I'd just be cautious with this as it's not using transformers as intended. This can be risky because:

hammoudhasan commented 1 month ago

@jrrw10 would you mind sharing the final script here? I have an issue where the training loss goes down but performance is poor when I run test inference. I suspect it's because I don't apply the chat template during training and inference. Do you optimize the prefix with the chat template applied, or without it?

@BenjaminBossan do you know how your provided code could be modified to use chat models where templates need to be applied?

BenjaminBossan commented 1 month ago

do you know how your provided code could be modified to use chat models where templates need to be applied?

Did you try calling apply_chat_template in the process function? Also check out the docs on chat templates.
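Applying the template in a preprocessing function might look like the minimal sketch below. It assumes a dataset with hypothetical "prompt" and "response" fields and a Hugging Face tokenizer that ships a chat template; build_chat_example is a hypothetical name, not the process function from the earlier script, and masking the prompt tokens with -100 is one common choice rather than something confirmed in this thread.

```python
from typing import Any, Dict, List

IGNORE_INDEX = -100  # label value skipped by the default cross-entropy loss in transformers models


def build_chat_example(example: Dict[str, str], tokenizer: Any, max_length: int = 512) -> Dict[str, List[int]]:
    # Hypothetical dataset fields: "prompt" (user turn) and "response" (assistant turn).
    prompt_msgs = [{"role": "user", "content": example["prompt"]}]
    full_msgs = prompt_msgs + [{"role": "assistant", "content": example["response"]}]

    # Render both with the model's own chat template; add_generation_prompt=True
    # appends the assistant header, i.e. what the model will see at inference time.
    prompt_text = tokenizer.apply_chat_template(prompt_msgs, tokenize=False, add_generation_prompt=True)
    full_text = tokenizer.apply_chat_template(full_msgs, tokenize=False)

    # The rendered template usually already contains BOS etc., so don't add them twice.
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    input_ids = tokenizer(full_text, add_special_tokens=False)["input_ids"][:max_length]

    # Mask the prompt part so the loss is only computed on the response tokens.
    n_prompt = min(len(prompt_ids), len(input_ids))
    labels = [IGNORE_INDEX] * n_prompt + input_ids[n_prompt:]
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids), "labels": labels}
```

With a real model you would pass something like AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as the tokenizer and map the function over the dataset (padding is left to a data collator). At inference, building the prompt with the same apply_chat_template(..., add_generation_prompt=True) call keeps the input format identical to training. Tokenizing the prompt and the full text separately assumes the prompt tokens are a prefix of the full tokens, which usually but not always holds at the boundary.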

hammoudhasan commented 1 month ago

@BenjaminBossan I got it sorted out (: thank you for the reply. It seems what was holding me back was a very weird error I got with Llama models (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0): calling model(**batch), where batch contains the input_ids, labels and attention_mask values, led to a weird dimension error.

huggingface llama RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 4 for tensor number 1 in the list.

Several people have reported a similar error, with no solution yet :-(

BenjaminBossan commented 1 month ago

It seems what was holding me back was a very weird error I got with Llama models (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0): calling model(**batch), where batch contains the input_ids, labels and attention_mask values, led to a weird dimension error.

You mean independent of PEFT?

hammoudhasan commented 1 month ago

@BenjaminBossan yep! It seems to be something related to the modeling files.