huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

RuntimeError: The size of tensor a (1024) must match the size of tensor b (1025) at non-singleton dimension 3 #11033

Closed. yananchen1989 closed this issue 3 years ago

yananchen1989 commented 3 years ago

Here I am trying to use gpt2 to generate text from a prompt. I have several datasets; some of them, such as AG_NEWS and POP_NEWS, are made of short sentences, but when I use YAHOO_NEWS, which consists of longer sentences, the error below comes up. Is there anything I should modify in my code? Thanks.

import torch

from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)

class generation():
    def __init__(self, model_name='gpt2',num_return_sequences=1):
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.MODEL_CLASSES = {
                    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
                    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
                    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
                    "xlnet-base-cased": (XLNetLMHeadModel, XLNetTokenizer),
                    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
                    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
                }
        self.length = 100
        self.k = 0
        self.p = 0.9
        self.num_return_sequences = num_return_sequences
        self.model_class, self.tokenizer_class = self.MODEL_CLASSES[self.model_name]
        self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name)
        self.model = self.model_class.from_pretrained(self.model_name)
        self.model.to(self.device)
        if self.model_name == "xlnet-base-cased":
            self.p=0.95
            self.k=60

        self.length = self.adjust_length_to_model(self.length, max_sequence_length=self.model.config.max_position_embeddings)

        if self.model_name == 'ctrl':
            self.temperature = 0.3
            self.repetition_penalty = 1.2
        else:
            self.temperature = 1.0
            self.repetition_penalty = 1.0

    def adjust_length_to_model(self, length, max_sequence_length):
        if length < 0 and max_sequence_length > 0:
            length = max_sequence_length
        elif 0 < max_sequence_length < length:
            length = max_sequence_length  # No generation bigger than model size
        elif length < 0:
            length = 1000  # avoid infinite loop
        return length

    def ctrl_label2prefix(self, label):
        # https://github.com/salesforce/ctrl/blob/master/control_codes.py
        '''
        'Pregnancy Christianity Explain Fitness Saving Ask Ass Joke Questions Thoughts Retail 
        Feminism Writing Atheism Netflix Computing Opinion Alone Funny Gaming Human India Joker Diet 
        Legal Norman Tip Weight Movies Running Science Horror Confession Finance Politics Scary Support 
        Technologies Teenage Event Learned Notion Wikipedia Books Extract Confessions Conspiracy Links 
        Narcissus Relationship Relationships Reviews News Translation multilingual'
        '''
        # The early return below means every label currently maps to 'News';
        # the mapping that follows is kept for reference but is unreachable.
        return 'News'

        if label in ('Sci/Tech', 'tech'):
            return 'Technologies'
        elif label in ('politics',):
            return 'Politics'
        elif label in ('Sports', 'sport'):
            return 'Fitness'
        else:
            return 'News'

    def augment(self, prompt_text):
        if self.model_name == 'ctrl':
            prefix = 'News '
        else:
            prefix = ''
        encoded_prompt = self.tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")

        encoded_prompt = encoded_prompt.to(self.device)

        if encoded_prompt.size()[-1] == 0:
            input_ids = None
        else:
            input_ids = encoded_prompt

        output_sequences = self.model.generate(
            input_ids=input_ids,
            max_length=self.length + len(encoded_prompt[0]),
            temperature=self.temperature,
            top_k=self.k,
            top_p=self.p,
            repetition_penalty=self.repetition_penalty,
            do_sample=True,
            num_return_sequences=self.num_return_sequences,
        )

        # Decode text
        text_generated = self.tokenizer.decode(output_sequences[0][len(encoded_prompt[0]):], clean_up_tokenization_spaces=True)
        return text_generated

# unit test
'''
augmentor = generation('gpt2')

prompt_text = "Microsoft has said it will replace more than 14 million power cables for its Xbox consoles due to safety concerns."
prompt_text = "Versace art portfolio up for sale The art collection of murdered fashion designer Gianni Versace could fetch \
up to £9m ($17m) when it is auctioned in New York and \
London later this year. <eod> </s> <eos>"

augmentor.augment(prompt_text)
'''

ERROR information:

File "baseline_classifier.py", line 45, in run_benchmark ds.df_train['content_aug'] = ds.df_train['content'].map(lambda x: augmentor.augment(x)) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/pandas/core/series.py", line 3382, in map arg, na_action=na_action) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/pandas/core/base.py", line 1218, in _map_values new_values = map_f(values, mapper) File "pandas/_libs/lib.pyx", line 2217, in pandas._libs.lib.map_infer File "baseline_classifier.py", line 45, in ds.df_train['content_aug'] = ds.df_train['content'].map(lambda x: augmentor.augment(x)) File "/workspace/user-workspace/topic_classification_augmentation/aug_generation.py", line 110, in augment num_return_sequences=self.num_return_sequences, File "/workspace/.conda/miniconda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context return func(args, kwargs) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/generation_utils.py", line 1019, in generate model_kwargs, File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/generation_utils.py", line 1486, in sample output_hidden_states=output_hidden_states, File "/workspace/.conda/miniconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in call result = self.forward(input, kwargs) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 917, in forward return_dict=return_dict, File "/workspace/.conda/miniconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in call result = self.forward(*input, *kwargs) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 760, in forward output_attentions=output_attentions, File "/workspace/.conda/miniconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in call result = self.forward(input, kwargs) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 296, in forward output_attentions=output_attentions, File "/workspace/.conda/miniconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in call result = self.forward(*input, **kwargs) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 241, in forward attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) File "/workspace/.conda/miniconda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 176, in _attn w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) RuntimeError: The size of tensor a (1024) must match the size of tensor b (1025) at non-singleton dimension 3

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

MehwishFatimah commented 3 years ago

I am also having the same issue, and I don't have input/output longer than 1024.
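One thing worth checking: the 1024 limit is counted in tokens and covers the prompt plus everything generate is asked to add, so an input that looks well under 1024 can still hit it when max_length is set to prompt length plus a fixed budget, as in the code above. A quick check, sketched with placeholder values (the prompt string and gen_length are only illustrative):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

prompt = "some document text here"   # placeholder
gen_length = 100                     # tokens requested on top of the prompt
limit = 1024                         # gpt2's max_position_embeddings

prompt_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
total_requested = prompt_tokens + gen_length
print(prompt_tokens, total_requested, total_requested > limit)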

cdeeran commented 2 years ago

I'm experiencing the same problem with GPT-J and its 2048-token context size.
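The same arithmetic should apply to GPT-J: its context window is 2048 positions (model.config.n_positions), so prompt tokens plus requested new tokens must stay at or below 2048. A sketch under that assumption, using max_new_tokens (available in recent transformers releases) so the budget is expressed relative to the prompt; the model name, prompt, and numbers here are illustrative, not taken from this thread.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

max_positions = model.config.n_positions   # 2048 for GPT-J
max_new = 200                              # tokens to generate beyond the prompt
prompt_text = "some long document"         # placeholder

inputs = tokenizer(
    prompt_text,
    return_tensors="pt",
    truncation=True,
    max_length=max_positions - max_new,    # leave room for the new tokens
)
output = model.generate(
    **inputs,
    do_sample=True,
    top_p=0.9,
    max_new_tokens=max_new,
)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:]))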