WanzhengZhu / GRUEN

GRUEN for Evaluating Linguistic Quality of Generated Text (EMNLP 2020 Findings)
https://arxiv.org/pdf/2010.02498.pdf
MIT License

batched inference for grammatical score #4

Open · Jack000 opened 2 years ago

Jack000 commented 2 years ago

I noticed that the lm_score code processes a single sentence at a time. This is pretty slow if you're processing a large amount of data. I wrote a batched version, though it's a bit ugly. It increases processing speed by roughly 8x on a single RTX 3090.

import math

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_lm_score(sentences, batch_tokens=42000):
    # sentences: list of documents, each a list of sentence strings.
    # batch_tokens: approximate number of characters per batch (string length is
    # used as a cheap proxy for token count, see below).

    def score_batch(batch, tokenizer, model):
        # Tokenize the whole batch at once with padding so every sentence in
        # the batch is scored with a single forward pass.
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
        batch_scores = []

        with torch.no_grad():
            # Use the input ids as MLM labels and ignore padding positions
            # (-100 is the default ignore_index of F.cross_entropy).
            labels = inputs["input_ids"].clone()
            labels[inputs["input_ids"] == tokenizer.pad_token_id] = -100
            out = model(input_ids=inputs["input_ids"], labels=labels,
                        attention_mask=inputs["attention_mask"],
                        token_type_ids=inputs["token_type_ids"])
            logits = out['logits']

            # The loss returned by the model is averaged over the whole batch,
            # so recompute cross-entropy per sentence and convert it to a
            # pseudo-perplexity.
            for j in range(labels.shape[0]):
                loss = F.cross_entropy(logits[j].view(-1, tokenizer.vocab_size), labels[j].view(-1))
                batch_scores.append(math.exp(loss.item()))

        return batch_scores

    model_name = 'bert-base-cased'
    model = BertForMaskedLM.from_pretrained(model_name).to(device)
    model.eval()
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    lm_score = []

    # Flatten all documents into a single list of sentences and sort by length
    # (descending) so that similarly sized sentences end up padded together.
    # Tokenizing everything up front would be too slow, so string length is
    # used as an approximation of token count.
    sentences_flat = []
    for sent in sentences:
        for s in sent:
            sentences_flat.append((s, len(s)))

    sentences_flat.sort(key=lambda x: x[1], reverse=True)

    # Greedily pack the sorted sentences into batches of roughly batch_tokens
    # characters each.
    batches = []

    current_batch_count = 0
    current_batch = []
    for sent in sentences_flat:
        current_batch.append(sent[0])
        current_batch_count += sent[1]
        if current_batch_count > batch_tokens:
            batches.append(current_batch)
            current_batch_count = 0
            current_batch = []

    if len(current_batch) > 0:
        batches.append(current_batch)

    # Score each batch and record every sentence's perplexity, keyed by its text.
    score_dict = {}

    for batch in tqdm(batches):
        batch_score = score_batch(batch, tokenizer, model)
        for j, sent in enumerate(batch):
            score_dict[sent] = batch_score[j]

    # Average per-sentence perplexities within each document. Empty documents
    # get 0.0; a large fallback value penalizes any sentence that somehow
    # missed scoring.
    for sentence in sentences:
        if len(sentence) == 0:
            lm_score.append(0.0)
            continue
        score_i = 0.0
        for x in sentence:
            if x in score_dict:
                score_i += score_dict[x]
            else:
                score_i += 10000
        score_i /= len(sentence)
        lm_score.append(score_i)
    return lm_score
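
For reference, a quick usage sketch of the batched function (the example documents below are made up; the nested-list shape is just what the loop over sentences expects):

docs = [
    ["The cat sat on the mat.", "It was warm."],
    ["Colorless green ideas sleep furiously."],
]
scores = get_lm_score(docs, batch_tokens=42000)
# scores[i] is the average per-sentence pseudo-perplexity of document i
# (lower means the LM finds the sentences more fluent).
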
WanzhengZhu commented 2 years ago

Thank you, Jack! Could you open a pull request?

Jack000 commented 2 years ago

Ah, there are two things I'm not sure about:

Hugging Face seems to have changed the API for the model.forward call: the code above works with transformers 4.20 (the latest release) but not with the version pinned in this repo (3.3.1). The code would have to be adapted if you want to keep the current transformers version.
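
A minimal compatibility sketch (an untested assumption about the older API: in transformers 3.3.1 the forward call returns a plain tuple by default rather than a dict-like output, so the logits would be the second element when labels are passed). Inside score_batch, the lines around the forward call would become:

out = model(input_ids=inputs["input_ids"], labels=labels,
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"])
# 3.x returns a (loss, prediction_scores, ...) tuple; 4.x returns a dict-like
# output, so handle both instead of indexing with 'logits'.
logits = out[1] if isinstance(out, tuple) else out.logits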

The batched code requires a new parameter for either a batch size or a number of tokens per batch. This parameter needs to be set according to how much VRAM is available, and I'm not sure how you'd like to expose this option in your code.
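
One hypothetical way to expose it (the wrapper name and helpers below are illustrative, not a claim about the repo's actual entry points): thread the budget through as a keyword argument with a conservative default, so users with more VRAM can raise it.

def get_gruen(candidates, batch_tokens=20000):  # hypothetical wrapper signature
    processed_candidates = preprocess_candidates(candidates)
    lm_score = get_lm_score(processed_candidates, batch_tokens=batch_tokens)
    # ... remaining sub-scores and aggregation stay unchanged ...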

WanzhengZhu commented 2 years ago

Ahhh, I see. Thanks for pointing that out. I will check it out.