zihangdai / xlnet

XLNet: Generalized Autoregressive Pretraining for Language Understanding

XLNet consistently underperformed BERT on language modeling #208

Open XuhuiZhou opened 5 years ago

XuhuiZhou commented 5 years ago

Hi, I am using XLNet as a language model with code provided by Hugging Face's pytorch-transformers. However, XLNet consistently underperformed BERT in our experiments. Considering its more advanced design, we are curious how that could happen. For example, we tested the models' ability to do coreference resolution on the Winograd Schema Challenge dataset. An example from the dataset would be:

The trophy doesn't fit into the brown suitcase because the trophy is too large.
The trophy doesn't fit into the brown suitcase because the suitcase is too large.

We then let the model choose the correct one by computing the perplexity of each sentence. In the end, we got the following results:

BERT-Large: 62% acc
XLNet-Base: 54.4% acc
XLNet-Large: 63.6% acc

From my point of view, XLNet-Base should be compared to BERT-Large since they have a similar parameter size. Furthermore, we have run experiments on other test sets, such as SWAG, and observed the same phenomenon. Any thoughts on this problem would be appreciated :)
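For context, the evaluation code below reads wsc.txt and splits each line on the "\001" control character into the gold label ('A' or 'B') and the two candidate sentences. A minimal sketch of that assumed line format (inferred from the code, since the file itself isn't included here):

# Assumed wsc.txt line format, inferred from the split("\001") in the code below
example_line = "A\001The trophy doesn't fit into the brown suitcase because the trophy is too large.\001The trophy doesn't fit into the brown suitcase because the suitcase is too large."
label, sentence_1, sentence_2 = example_line.split("\001")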

Code:

import torch
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer

# OPTIONAL: activate the logger for more information on what's happening
import logging
logging.basicConfig(level=logging.INFO)

def xlnet_score(text, model, tokenizer):
    """Return the average masked-LM loss over the tokens of `text`."""
    tokenized_text = tokenizer.tokenize(text)
    total_loss = 0.0
    for masked_index in range(len(tokenized_text)):
        # Mask one token at a time and ask XLNetLMHeadModel to predict it back
        masked_word = tokenized_text[masked_index]
        if masked_word != "<sep>":
            tokenized_text[masked_index] = '<mask>'
            input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokenized_text)).unsqueeze(0)
            # Id of the original token, used as the label for the loss
            index = torch.tensor(tokenizer.convert_tokens_to_ids(masked_word))

            perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
            perm_mask[:, :, masked_index] = 1.0  # no token may attend to the masked position
            target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # shape [1, 1, seq_length]: predict a single token
            target_mapping[0, 0, masked_index] = 1.0  # the single prediction target is the masked position

            input_ids = input_ids.to('cuda')
            perm_mask = perm_mask.to('cuda')
            target_mapping = target_mapping.to('cuda')
            index = index.to('cuda')

            with torch.no_grad():
                outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=index)
            # With `labels` given, outputs[0] is the cross-entropy loss for the masked token
            loss = outputs[0]
            total_loss += loss.item()
            tokenized_text[masked_index] = masked_word
    return total_loss / len(tokenized_text)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
model.to('cuda')
model.eval()

with open("wsc.txt", "r", encoding= 'utf-8') as f:
    file = f.readlines()

num = len(file)
count = 0
curr = 0
label_list = ['A','B']
for line in file:
    label, sentence_1, sentence_2 = line.split("\001")
    loss_1 = xlnet_score(sentence_1, model=model, tokenizer=tokenizer)
    loss_2 = xlnet_score(sentence_2, model=model, tokenizer=tokenizer)
    # The lower average loss corresponds to the more probable sentence
    index = [loss_1, loss_2].index(min(loss_1, loss_2))
    print(label, label_list[index])
    if label == label_list[index]:
        count += 1
    curr += 1
    print(count, curr, count / curr)
print(count / num)
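For comparison, here is a minimal sketch of the analogous BERT scorer using BertForMaskedLM from pytorch-transformers. This is not the exact code behind the BERT-Large number above, and the 'bert-large-cased' checkpoint name is an assumption:

from pytorch_transformers import BertForMaskedLM, BertTokenizer

def bert_score(text, model, tokenizer):
    """Average masked-LM loss over the tokens of `text`, analogous to xlnet_score."""
    tokens = tokenizer.tokenize(text)
    total_loss = 0.0
    for i, word in enumerate(tokens):
        masked = ['[CLS]'] + tokens[:i] + ['[MASK]'] + tokens[i + 1:] + ['[SEP]']
        input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(masked)]).to('cuda')
        labels = torch.full_like(input_ids, -1)  # positions labelled -1 are ignored by the loss
        labels[0, i + 1] = tokenizer.convert_tokens_to_ids([word])[0]  # +1 offset for [CLS]
        with torch.no_grad():
            loss = model(input_ids, masked_lm_labels=labels)[0]
        total_loss += loss.item()
    return total_loss / len(tokens)

bert_tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
bert_model = BertForMaskedLM.from_pretrained('bert-large-cased').to('cuda')
bert_model.eval()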
kimiyoung commented 5 years ago

XLNet-Large has the same number of parameters as BERT-Large, while XLNet-Base has the same number of parameters as BERT-Base. I haven't looked at your code, though.
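(For reference, a quick way to sanity-check the parameter counts with pytorch-transformers; checkpoint names assumed:)

from pytorch_transformers import BertModel, XLNetModel

def num_params(model):
    return sum(p.numel() for p in model.parameters())

for name, cls in [('bert-base-cased', BertModel), ('bert-large-cased', BertModel),
                  ('xlnet-base-cased', XLNetModel), ('xlnet-large-cased', XLNetModel)]:
    print(name, '{:,} parameters'.format(num_params(cls.from_pretrained(name))))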

jihun-hong commented 5 years ago

There may be problems with Hugging Face's implementation in pytorch-transformers. I saw similar issues in their repository, so maybe you could refer to those.