huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Issues attempting to implement P-TuningV2 with huggingface's BART #20905

Closed: maxrousseau closed this issue 1 year ago

maxrousseau commented 1 year ago

@patrickvonplaten Hello, I am trying to implement P-TuningV2 with BART using Hugging Face's transformers v4.25.1 (following the official P-TuningV2 repo). However, when I try to train the model, I get the following error:

/usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions)
    238         if attention_mask is not None:
    239             if attention_mask.size() != (bsz, 1, tgt_len, src_len):
--> 240                 raise ValueError(
    241                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
    242                 )

ValueError: Attention mask should be of size (4, 1, 648, 648), but is torch.Size([4, 1, 652, 652])
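
The sizes in this message fit a prefix mask being prepended to the encoder mask: with `pre_seq_len = 4` (set in `__init__` in the code below), a 648-token input gives a 652-column mask. The plain-Python sketch below only replays that shape arithmetic; the helper `mask_shapes` is hypothetical, and it assumes the encoder itself receives no prefix keys (`BartEncoder.forward` in v4.25.1 takes no `past_key_values`).

```python
def mask_shapes(bsz: int, seq_len: int, pre_seq_len: int):
    """Return (expected, actual) 4-D shapes for BART encoder self-attention masks.

    expected: what the check in BartAttention.forward compares against,
              (bsz, 1, tgt_len, src_len); both lengths equal the encoder input
              length because no cached prefix keys reach the encoder.
    actual:   the expanded mask after pre_seq_len ones are concatenated in front
              of the 2-D attention_mask (torch.cat(..., dim=1) in forward()).
    """
    expected = (bsz, 1, seq_len, seq_len)
    mask_len = pre_seq_len + seq_len
    actual = (bsz, 1, mask_len, mask_len)
    return expected, actual


expected, actual = mask_shapes(bsz=4, seq_len=648, pre_seq_len=4)
print(expected, actual)  # (4, 1, 648, 648) (4, 1, 652, 652)
```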

Any ideas where the issue is coming from or how to resolve it? I am a little unfamiliar with the codebase, so any help would be greatly appreciated.

Thanks,

Here's the code I'm using to run the model:

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import BartConfig, BartModel, BartPretrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.utils import logging

# used by the use_cache warning in forward()
logger = logging.get_logger(__name__)

def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

class PrefixEncoder(torch.nn.Module):
    r"""
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(
                    config.prefix_hidden_size,
                    config.num_hidden_layers * 2 * config.hidden_size,
                ),
            )
        else:
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size
            )

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

class PrefixBartForConditionalGeneration(BartPretrainedModel):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [
        r"final_logits_bias",
        r"lm_head.weight",
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, config: BartConfig):
        # MAX - testing the config default values from (https://github.com/THUDM/P-tuning-v2/blob/main/arguments.py)
        config.pre_seq_len = 4
        config.hidden_dropout_prob = 0.1
        config.prefix_hidden_size = 512
        config.prefix_projection = False

        super().__init__(config)

        # MAX :: get the layer, embedding and heads to generate the prefix
        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = (
            config.hidden_size // config.num_attention_heads
        )  # MAX - here we change the embed dims..

        self.model = BartModel(config)
        self.register_buffer(
            "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
        )
        self.lm_head = nn.Linear(
            config.d_model, self.model.shared.num_embeddings, bias=False
        )

        # MAX :: add the prefix encoder/tokens and dropout for the prefixes
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.prefix_encoder = PrefixEncoder(config)
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()

        # MAX :: freeze the model parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Initialize weights and apply final processing
        self.post_init()

    # MAX :: modify and adapt for bart
    def get_prompt(self, batch_size):
        prefix_tokens = (
            self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
        )
        past_key_values = self.prefix_encoder(prefix_tokens)
        bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(
            bsz, seqlen, self.n_layer * 2, self.n_head, self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # MAX-NOTE :: run the prefix layer
        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(
            self.model.device
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        print("encoder mask: {}".format(attention_mask.size()))
        # BUG: attention_mask is changed, but not the size of the hidden_states and the key_states (past_key_value[0])?

        if labels is not None:
            if use_cache:
                logger.warning(
                    "The `use_cache` argument is changed to `False` since `labels` is provided."
                )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,  # MAX-NOTE :: unlike bert this did not need to be added here?
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past[:2]
                )
                + layer_past[2:],
            )
        return reordered_past
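
Replaying the reshaping in `get_prompt` on a random tensor shows what ends up in `past_key_values`. This is a sketch assuming bart-base-like sizes (6 layers, 12 heads, `d_model = 768`) and the `pre_seq_len = 4` set in `__init__`; `prompt_layout` is a hypothetical standalone copy of the view/permute/split steps, without the embedding and dropout.

```python
import torch


def prompt_layout(bsz: int, pre_seq_len: int, n_layer: int, n_head: int, d_model: int):
    """Replay the view/permute/split steps of get_prompt on a random tensor."""
    head_dim = d_model // n_head  # self.n_embd in the issue's code
    # PrefixEncoder output: (bsz, pre_seq_len, n_layer * 2 * d_model)
    prefix = torch.randn(bsz, pre_seq_len, n_layer * 2 * d_model)
    prefix = prefix.view(bsz, pre_seq_len, n_layer * 2, n_head, head_dim)
    # -> (n_layer * 2, bsz, n_head, pre_seq_len, head_dim)
    prefix = prefix.permute([2, 0, 3, 1, 4])
    # split(2) along dim 0: one tensor per layer, (2, bsz, n_head, pre_seq_len, head_dim)
    return prefix.split(2)


layers = prompt_layout(bsz=4, pre_seq_len=4, n_layer=6, n_head=12, d_model=768)
print(len(layers), tuple(layers[0].shape))  # 6 (2, 4, 12, 4, 64)
```

Each per-layer entry stacks one key and one value of shape `(bsz, n_head, pre_seq_len, head_dim)`. Inside `BartModel` these only reach the decoder, which reads the cached length from `past_key_values[0][0].shape[2]`, here 4.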

sgugger commented 1 year ago

cc @ArthurZucker

maxrousseau commented 1 year ago

Hello @patrickvonplaten @ArthurZucker,

I wrote a simple test case that reproduces the error with the model I am implementing, using a few examples from SQuAD.

1. Loading the dataset

from datasets import Dataset

def formatToMI(dataset):
    """take a squad-like qa dataset and transform into MLM format"""
    masked_strings = []
    full_strings = []
    qa_strings = []
    answer_strings = []

    for i in range(len(dataset["question"])):
        question = dataset["question"][i]
        answer = dataset["answers"][i]["text"][0]
        context = dataset["context"][i]

        masked_strings.append(
            "Question: {} Answer: <mask>. Context: {}".format(question, context)
        )
        full_strings.append(
            "Question: {} Answer: {}. Context: {}".format(question, answer, context)
        )
        qa_strings.append("Question: {} Answer: {}.".format(question, answer))
        answer_strings.append(answer)

    return {
        "masked_strings": masked_strings,
        "full_strings": full_strings,
        "qa_strings": qa_strings,
        "answer_strings": answer_strings,
        "id": dataset["id"],
    }

def loadSquadMI(n=None):
    """load the first n SQuAD training examples and format them with formatToMI"""
    from datasets import load_dataset
    raw_datasets = load_dataset("squad")

    if n is not None:
        squad_subset = formatToMI(raw_datasets["train"][:n])
        return squad_subset
    else:
        return 0

samples = loadSquadMI(n=100)
tiny_squad = Dataset.from_dict(samples)
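
To make the templates concrete, here is what one record looks like after formatting. `format_one` is a hypothetical condensed version of the loop body in `formatToMI`, and the question, answer and context strings are made up rather than taken from SQuAD.

```python
def format_one(question: str, answer: str, context: str) -> dict:
    """Condensed version of the per-example templates used in formatToMI."""
    return {
        "masked_strings": "Question: {} Answer: <mask>. Context: {}".format(question, context),
        "full_strings": "Question: {} Answer: {}. Context: {}".format(question, answer, context),
        "qa_strings": "Question: {} Answer: {}.".format(question, answer),
        "answer_strings": answer,
    }


ex = format_one("Who wrote Hamlet?", "Shakespeare", "Hamlet is a tragedy by Shakespeare.")
print(ex["masked_strings"])  # Question: Who wrote Hamlet? Answer: <mask>. Context: Hamlet is a tragedy by Shakespeare.
print(ex["qa_strings"])      # Question: Who wrote Hamlet? Answer: Shakespeare.
```

In the preprocessing step, `masked_strings` become the encoder inputs and `qa_strings` the labels.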

2. Creating the dataloader

from transformers import AutoTokenizer, BartForConditionalGeneration, DataCollatorForSeq2Seq
import torch
from torch.utils.data import DataLoader

# initialize BART and PrefixBART for MI
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
examples = tiny_squad
prefixbart_model = PrefixBartForConditionalGeneration.from_pretrained("facebook/bart-base")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=prefixbart_model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8,
)

# preprocessing
def training_preprocessing(examples):
    """batched examples with the string fields produced by formatToMI"""
    padding = "max_length"
    model_inputs = tokenizer(
        examples["masked_strings"],
        max_length=384,
        padding=padding,
        truncation=False,
    )
    labels = tokenizer(
        text_target=examples["qa_strings"],
        max_length=128,
        padding=padding,
        truncation=True,
    )
    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label]
            for label in labels["input_ids"]
        ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
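
Because the labels are padded to max_length, pad ids are swapped for -100 so that `CrossEntropyLoss` (whose default `ignore_index` is -100) skips those positions. A plain-Python sketch of that replacement, assuming `pad_token_id = 1` as used by the facebook/bart-base tokenizer; `mask_label_padding` is a hypothetical name and the token ids are made up.

```python
def mask_label_padding(label_ids, pad_token_id: int, ignore_index: int = -100):
    """Replace pad tokens in each label sequence with ignore_index."""
    return [[tok if tok != pad_token_id else ignore_index for tok in seq] for seq in label_ids]


labels = [[0, 713, 16, 2, 1, 1], [0, 42, 2, 1, 1, 1]]
print(mask_label_padding(labels, pad_token_id=1))
# [[0, 713, 16, 2, -100, -100], [0, 42, 2, -100, -100, -100]]
```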

proc_train_dataset = examples.map(
    training_preprocessing,
    batched=True,
    remove_columns=examples.column_names,
)

train_tensor = proc_train_dataset
train_tensor.set_format("torch")

train_dataloader = DataLoader(
    train_tensor,
    shuffle=True,
    collate_fn=data_collator,
    batch_size=4,
    num_workers=0,
)
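
With `pad_to_multiple_of=8`, `DataCollatorForSeq2Seq` pads each batch to its longest sequence rounded up to a multiple of 8. A quick plain-Python check (the helper `padded_length` is hypothetical) shows that 384 and 648, the encoder lengths in the two errors, are already multiples of 8, while the mask lengths 388 and 652 are not; that is consistent with the 4 extra columns coming from the prefix rather than from the collator.

```python
def padded_length(longest: int, multiple: int = 8) -> int:
    """Round a batch's longest sequence length up to the next multiple."""
    return -(-longest // multiple) * multiple


for n in (384, 648, 388, 652):
    print(n, n % 8 == 0, padded_length(n))
# 384 True 384
# 648 True 648
# 388 False 392
# 652 False 656
```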

3. Test: a single forward pass

With BART: success

bart_model.train()
batch = next(iter(train_dataloader))
outputs = bart_model(**batch)
loss = outputs.loss
print(loss)

Output: tensor(0.8271, grad_fn=<NllLossBackward0>)
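
The batch already carries `decoder_input_ids`: because the collator was built with `model=prefixbart_model`, `DataCollatorForSeq2Seq` adds them through the model's `prepare_decoder_input_ids_from_labels`, which calls `shift_tokens_right`. A standalone copy of that helper on a toy label row, assuming `pad_token_id=1` and `decoder_start_token_id=2` from the bart-base config (the label ids are made up):

```python
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    """Same logic as the helper in the issue: prepend the start token, drop the last label."""
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 marks ignored label positions; the decoder needs real token ids there
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


labels = torch.tensor([[0, 713, 16, 2, -100, -100]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2).tolist())
# [[2, 0, 713, 16, 2, 1]]
```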

With PrefixBART: failure (same error as above)

prefixbart_model.train()
batch = next(iter(train_dataloader))
outputs = prefixbart_model(**batch)
loss = outputs.loss
print(loss)

Output

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-26-ebc93e8e099a> in <module>
      3 prefixbart_model.train()
      4 batch = next(iter(train_dataloader))
----> 5 outputs = prefixbart_model(**batch)
      6 loss = outputs.loss
      7 print(loss)

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-5-71e56dfc61a6> in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    211                 )
    212 
--> 213         outputs = self.model(
    214             input_ids,
    215             attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1231 
   1232         if encoder_outputs is None:
-> 1233             encoder_outputs = self.encoder(
   1234                 input_ids=input_ids,
   1235                 attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    848                     )
    849                 else:
--> 850                     layer_outputs = encoder_layer(
    851                         hidden_states,
    852                         attention_mask,

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, hidden_states, attention_mask, layer_head_mask, output_attentions)
    323         """
    324         residual = hidden_states
--> 325         hidden_states, attn_weights, _ = self.self_attn(
    326             hidden_states=hidden_states,
    327             attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions)
    238         if attention_mask is not None:
    239             if attention_mask.size() != (bsz, 1, tgt_len, src_len):
--> 240                 raise ValueError(
    241                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
    242                 )

ValueError: Attention mask should be of size (4, 1, 384, 384), but is torch.Size([4, 1, 388, 388])
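
The frames above place the failure in BartEncoder, whose forward signature has no past_key_values argument; inside BartModel the prefix cache is only handed to the decoder, and the encoder mask is reused for the decoder's cross-attention. One possible adjustment, sketched under that reading and not confirmed anywhere in this thread: keep attention_mask at the encoder length and, only if a decoder_attention_mask is supplied, prepend pre_seq_len ones to it, because BART's decoder builds its self-attention mask over cached plus current positions. `split_prefix_masks` is a hypothetical helper name.

```python
from typing import Optional, Tuple

import torch


def split_prefix_masks(
    attention_mask: torch.Tensor,
    decoder_attention_mask: Optional[torch.Tensor],
    pre_seq_len: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical helper: keep prefix positions out of the encoder mask.

    The encoder and the decoder's cross-attention only attend over the real
    input tokens, so attention_mask keeps its (bsz, src_len) shape. Decoder
    self-attention also covers the pre_seq_len cached prefix positions, so a
    decoder padding mask, when given, is widened to pre_seq_len + tgt_len.
    """
    if decoder_attention_mask is not None:
        bsz = decoder_attention_mask.shape[0]
        prefix = decoder_attention_mask.new_ones((bsz, pre_seq_len))
        decoder_attention_mask = torch.cat((prefix, decoder_attention_mask), dim=1)
    return attention_mask, decoder_attention_mask


enc_mask = torch.ones(4, 384, dtype=torch.long)
dec_mask = torch.ones(4, 32, dtype=torch.long)
enc_out, dec_out = split_prefix_masks(enc_mask, dec_mask, pre_seq_len=4)
print(tuple(enc_out.shape), tuple(dec_out.shape))  # (4, 384) (4, 36)
```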

maxrousseau commented 1 year ago

Hello again @patrickvonplaten @ArthurZucker,

I just found out about adapter-transformers, which implements prefix-tuning (the method P-TuningV2 is based on) for BART. Maybe this issue can be closed?

ArthurZucker commented 1 year ago

Hey! Cool that you found something that works for you! The issue might just have come from a config parameter defining the hidden_size.
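
For reference on the hidden_size point: BartConfig has no native hidden_size field. It maps hidden_size to d_model and num_attention_heads to encoder_attention_heads, and sets num_hidden_layers from encoder_layers. The arithmetic below, assuming facebook/bart-base values (d_model 768, 12 heads, 6 layers), checks that the PrefixEncoder width matches what get_prompt reshapes into; `prefix_dims` is a hypothetical helper.

```python
def prefix_dims(d_model: int, n_head: int, n_layer: int) -> dict:
    """Sizes used by PrefixEncoder and get_prompt for a given BART config."""
    assert d_model % n_head == 0, "d_model must be divisible by the number of heads"
    head_dim = d_model // n_head
    width = n_layer * 2 * d_model  # PrefixEncoder output features per prefix token
    # get_prompt views this width as (n_layer * 2, n_head, head_dim)
    assert width == (n_layer * 2) * n_head * head_dim
    return {"head_dim": head_dim, "prefix_width": width}


print(prefix_dims(d_model=768, n_head=12, n_layer=6))  # {'head_dim': 64, 'prefix_width': 9216}
```

For bart-base the encoder and decoder share these sizes, so building a decoder-side prefix from the encoder-derived values is harmless there; a config with different decoder_layers or decoder_attention_heads would need the decoder values instead.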

maxrousseau commented 1 year ago

Hello, thank you for replying. I will try out the modified config and see if it resolves the issue.

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.