VecherVhatuX opened this issue 5 days ago
I am encountering a similar issue when training a SentenceTransformer model (`bert-base-uncased`) using PEFT with `PromptTuningConfig`. During the forward pass, the model raises a `RuntimeError` due to a mismatch in tensor dimensions between the token embeddings and the attention mask, specifically during the pooling operation.
Interestingly, switching to LoRA or prefix tuning works without any issues. I'd appreciate any assistance in resolving this problem. @sayakpaul
Thanks for reporting this. I can confirm that I get this error. I also tried `bert-base-uncased` without sentence-transformers, and that worked for me. Here is the amended script:
```python
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit

use_sentence_transformer = True  # setting this to False makes the script pass
device = 0
model_name = "bert-base-uncased"

# Apply PEFT with PromptTuningConfig
peft_config = PromptTuningConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    num_virtual_tokens=10,
)

if use_sentence_transformer:
    model = SentenceTransformer(model_name)
    model._modules["0"].auto_model = get_peft_model(
        model._modules["0"].auto_model, peft_config
    )
else:
    model = AutoModel.from_pretrained(model_name)
    model = get_peft_model(model, peft_config)
model = model.to(device)

# Switch to training mode
model.train()

# Generate random input tensors
batch_size = 64
seq_len = 53
random_input_ids = torch.randint(0, 30522, (batch_size, seq_len)).to(device)
random_attention_mask = torch.ones(batch_size, seq_len).to(device)

# Perform forward pass
if use_sentence_transformer:
    outputs = model({"input_ids": random_input_ids, "attention_mask": random_attention_mask})
else:
    outputs = model(input_ids=random_input_ids, attention_mask=random_attention_mask)

for key, val in outputs.items():
    print(key, val.shape)
```
For `use_sentence_transformer = True`, I get the same error:

```
RuntimeError: The expanded size of the tensor (63) must match the existing size (53) at non-singleton dimension 1. Target sizes: [64, 63, 768]. Tensor sizes: [64, 53, 1]
```

For `use_sentence_transformer = False`, the script passes and prints:

```
last_hidden_state torch.Size([64, 63, 768])
pooler_output torch.Size([64, 768])
```

Note that the sequence dimension is 63 = 53 input tokens + 10 virtual tokens.
Prompt tuning works by prepending `prompt_embeddings` (shape `[1, 10, 768]` in this case, as we have 10 virtual tokens) to each sample. To compensate for this, the attention mask is extended accordingly.
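For illustration, here is a paraphrased sketch of what PEFT does internally during the forward pass, not the exact source; the tensor names and shapes are taken from this example:

```python
import torch

# Paraphrased sketch of PEFT's prompt-tuning forward logic: the learned
# virtual tokens are prepended to the input embeddings, and the attention
# mask is extended with ones to match.
batch_size, seq_len, num_virtual_tokens, hidden = 64, 53, 10, 768
inputs_embeds = torch.randn(batch_size, seq_len, hidden)
attention_mask = torch.ones(batch_size, seq_len)
prompts = torch.randn(batch_size, num_virtual_tokens, hidden)  # learned prompt embeddings

prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)  # -> [64, 63]
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)                  # -> [64, 63, 768]
```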
By stepping through the debugger, I could confirm that this part works correctly. The issue is, however, with the `Pooling` part of the sentence transformer model. This part does not "know" about the prompt tuning, since we only transformed the `Transformer` part. Ideally, the `Transformer` part could pass its extended attention mask on to the `Pooling` module to stay consistent, but as is, this is not done, hence the mismatch. As I have no experience with sentence transformers, I hope that @tomaarsen can help here.
Worst case, we would probably have to wrap the whole model and call both parts separately, ensuring that the attention mask is extended for the pooling step to account for `num_virtual_tokens`; a sketch of that idea follows.
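A minimal, untested sketch of that wrapper idea (the helper name is hypothetical; `num_virtual_tokens=10` matches the config above, and the module indices `"0"`/`"1"` assume the default SentenceTransformer layout of a `Transformer` module followed by a `Pooling` module):

```python
import torch

def forward_with_extended_mask(model, features, num_virtual_tokens=10):
    # Run the PEFT-wrapped Transformer module first; this produces
    # "token_embeddings" with seq_len + num_virtual_tokens positions.
    features = model._modules["0"](features)
    # The attention mask still has the original seq_len, so prepend ones
    # for the virtual tokens before the Pooling module expands it.
    mask = features["attention_mask"]
    prefix = torch.ones(
        mask.shape[0], num_virtual_tokens, device=mask.device, dtype=mask.dtype
    )
    features["attention_mask"] = torch.cat([prefix, mask], dim=1)
    # Now pooling sees a mask whose length matches the token embeddings.
    return model._modules["1"](features)
```

With this in place, `forward_with_extended_mask(model, {"input_ids": random_input_ids, "attention_mask": random_attention_mask})` should yield a `sentence_embedding` of shape `[64, 768]`, though I have not verified this end to end.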
System Info
- `sentence-transformers`: latest version
- `torch`: 2.0+
- `peft`: latest version

Who can help?
@sayakpaul @Be
Information
Tasks
An officially supported task in the `examples` folder

Reproduction
When attempting to train a BERT-based `SentenceTransformer` model (`bert-base-uncased`) using PEFT with `PromptTuningConfig`, an error occurs during the forward pass. Specifically, there is a mismatch in the tensor expansion size between the token embeddings and the attention mask during the pooling operation.

Steps to Reproduce:
1. Load a `SentenceTransformer` model with `bert-base-uncased`.
2. Apply PEFT with `PromptTuningConfig`.
3. Create random input tensors of shape `(64, 53)` (batch size 64, sequence length 53).
4. Perform a forward pass; the model raises a `RuntimeError` due to a mismatch in tensor dimensions.

Reproduction Code:
Error Trace:
Moreover, if I change `PromptTuningConfig` to LoRA or prefix tuning, it works correctly.
Expected behavior
The model should correctly process input sequences of varying lengths without a size mismatch between the token embeddings and the attention mask.