flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/
Other

[Bug]: SequenceTagger._all_scores_for_token() function returns incorrect prediction distribution for tokens #3448

Open MdMotahar opened 5 months ago

Describe the bug

The SequenceTagger class has an _all_scores_for_token() method that takes a batch of sentences, the softmax scores from the tagger, and the length of each sentence in the batch. It builds the probability distribution over all class labels for each token and returns these distributions grouped by sentence. The grouping step is incorrect: the per-sentence slices are computed from wrong offsets, so from the third sentence of a batch onward, tokens receive distributions that belong to other tokens. Below, I demonstrate this on a batch from the English OntoNotes corpus.
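
For context, the scores the method receives are flat: one row per token for the whole batch, in the order produced by flattening the sentences with `[token for sentence in sentences for token in sentence]`. The minimal sketch below uses plain lists of strings as stand-ins for flair Sentence and Token objects (an assumption for illustration) and a hypothetical helper, `flat_index_map`, to show how each flat row maps back to a (sentence, position) pair.

```python
from typing import List, Tuple


def flat_index_map(sentences: List[List[str]]) -> List[Tuple[int, int]]:
    """For each flat row index, return (sentence index, position within sentence).

    `sentences` is a hypothetical stand-in for a batch of flair Sentences:
    each inner list holds the token texts of one sentence.
    """
    mapping = []
    for sent_idx, sentence in enumerate(sentences):
        for pos, _token in enumerate(sentence):
            mapping.append((sent_idx, pos))
    return mapping


batch = [["George", "Washington", "went"], ["to"], ["Washington", "."]]
# Flatten the same way the tagger does before zipping tokens with score rows
tokens = [token for sentence in batch for token in sentence]
mapping = flat_index_map(batch)
print(len(tokens), mapping[3], mapping[4])  # 6 (1, 0) (2, 0)
```

The rows of sentence i therefore start at the sum of the lengths of sentences 0 to i-1, which is the offset the splitting loop has to keep track of.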

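import flair
import flair.datasets
import torch
import torch.nn.functional as F
from typing import List, cast
from tqdm import tqdm
from flair.data import Sentence, Label
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger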

# 1. get the corpus
corpus = flair.datasets.ONTONOTES()

# 2. what label do we want to predict?
label_type = 'ner'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

label_name = 'ner'
sentences = corpus.train
mini_batch_size=10
verbose = True

with torch.no_grad():
    Sentence.set_context_for_sentences(cast(List[Sentence], sentences))

    # filter empty sentences
    sentences = [sentence for sentence in sentences if len(sentence) > 0]

    # reverse sort all sequences by their length
    reordered_sentences = sorted(sentences, key=len, reverse=True)

    dataloader = DataLoader(
        dataset=FlairDatapointDataset(reordered_sentences),
        batch_size=mini_batch_size,
    )
    # progress bar for verbosity
    if verbose:
        dataloader = tqdm(dataloader, desc="Batch inference")

    overall_loss = torch.zeros(1, device=flair.device)
    label_count = 0
    for batch in dataloader:
        # skip empty batches
        if not batch:
            continue

        # get features from forward propagation
        sentence_tensor, lengths = tagger._prepare_tensors(batch)
        features = tagger.forward(sentence_tensor, lengths)

        # remove previously predicted labels of this type
        for sentence in batch:
            sentence.remove_labels(label_name)
        break  # inspect only the first batch
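
The sort-then-batch step in this snippet can be illustrated without flair. The sketch below assumes the loader yields sentences in their given order without shuffling, and uses made-up lengths; `sorted_batches` is a hypothetical helper, not a flair function.

```python
from typing import List, Sequence


def sorted_batches(lengths: Sequence[int], batch_size: int) -> List[List[int]]:
    """Sort sentence lengths longest-first, then cut them into consecutive batches."""
    ordered = sorted(lengths, reverse=True)
    return [ordered[i : i + batch_size] for i in range(0, len(ordered), batch_size)]


# Hypothetical sentence lengths, not taken from OntoNotes
lengths = [12, 3, 40, 7, 25, 1, 18]
batches = sorted_batches(lengths, batch_size=3)
print(batches)  # [[40, 25, 18], [12, 7, 3], [1]]
```

Under that assumption, the batch inspected here holds the longest training sentences, with lengths decreasing from Sentence 0 to Sentence 9.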

I then compute the per-token probability distributions for this batch with _all_scores_for_token(). For easier inspection, the method is copied below as a standalone function (using the global tagger in place of self), with print statements added.

def _all_scores_for_token(sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
    """Returns all scores for each tag in tag dictionary."""
    scores = scores.numpy()
    # flatten the batch into one token list, in the same order as the rows of `scores`
    tokens = [token for sentence in sentences for token in sentence]
    print('Number of tokens in batch:', len(tokens))
    # uses the global `tagger` in place of `self`
    prob_all_tags = [
        [
            Label(token, tagger.label_dictionary.get_item_for_index(score_id), score)
            for score_id, score in enumerate(score_dist)
        ]
        for score_dist, token in zip(scores, tokens)
    ]

    print('Length of prob_all_tags:', len(prob_all_tags))

    prob_tags_per_sentence = []
    previous = 0
    for i, length in enumerate(lengths):
        print(f'Length range of Sentence {i}: {previous} to {previous + length}')
        prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
        previous = length  # bug: overwrites the offset; should be previous += length

    return prob_tags_per_sentence

softmax_batch = F.softmax(features, dim=1).cpu()
lengths = [len(sentence) for sentence in batch]
all_tags = _all_scores_for_token(batch, softmax_batch, lengths)

Output:
Number of tokens in batch: 1761
Length of prob_all_tags: 1761
Length range of Sentence 0: 0 to 210
Length range of Sentence 1: 210 to 415
Length range of Sentence 2: 205 to 394
Length range of Sentence 3: 189 to 377
Length range of Sentence 4: 188 to 361
Length range of Sentence 5: 173 to 341
Length range of Sentence 6: 168 to 335
Length range of Sentence 7: 167 to 324
Length range of Sentence 8: 157 to 313
Length range of Sentence 9: 156 to 304
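
The ranges above can be replayed without running the model. The sentence lengths below are read off the output (end minus start of each range), and `buggy_ranges` is a hypothetical helper that repeats the loop's update rule, `previous = length`.

```python
from typing import List, Tuple

# Sentence lengths of the first batch, recovered from the output above
LENGTHS = [210, 205, 189, 188, 173, 168, 167, 157, 156, 148]


def buggy_ranges(lengths: List[int]) -> List[Tuple[int, int]]:
    """Replay the splitting loop with the `previous = length` update."""
    ranges = []
    previous = 0
    for length in lengths:
        ranges.append((previous, previous + length))
        previous = length  # overwrites the running offset instead of advancing it
    return ranges


ranges = buggy_ranges(LENGTHS)
# Collect every flat row index that ends up in some sentence's slice
covered = set()
for start, end in ranges:
    covered.update(range(start, end))
print(ranges[2], max(end for _, end in ranges), sum(LENGTHS) - len(covered))
# (205, 394) 415 1346
```

Every slice stays below row 415, so 1346 of the 1761 token distributions are never returned to any sentence.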

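The batch contains 1761 tokens in total, and prob_all_tags holds one probability distribution per token. However, when these are split per sentence into prob_tags_per_sentence, the ranges are wrong from Sentence 2 onward, as the output above shows. The loop sets previous = length, which overwrites the running offset with the length of the previous sentence instead of advancing it. Sentence 2 therefore starts at 205 (the length of Sentence 1) instead of 415, every later slice falls within the tokens of Sentences 0 and 1, and the distributions for tokens 415 to 1760 are never returned. The loop should accumulate the offset instead: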

for i, length in enumerate(lengths):
    print(f'Length range of Sentence {i}: {previous} to {previous + length}')
    prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
    previous += length  # advance the offset instead of overwriting it (was: previous = length)
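
An equivalent, less error-prone way to write the fix is to precompute each sentence's start offset as a running sum with itertools.accumulate. The sketch below works on plain Python lists; `split_by_lengths` is a hypothetical name. For tensors, torch.split(tensor, lengths) performs the same split along the first dimension.

```python
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_by_lengths(items: Sequence[T], lengths: Sequence[int]) -> List[List[T]]:
    """Split a flat sequence into consecutive chunks of the given lengths."""
    if sum(lengths) != len(items):
        raise ValueError("lengths must add up to the number of items")
    # Start offset of chunk i is the sum of lengths[0..i-1]
    starts = [0, *accumulate(lengths)][:-1]
    return [list(items[s : s + n]) for s, n in zip(starts, lengths)]


flat = ["t0", "t1", "t2", "t3", "t4", "t5"]
print(split_by_lengths(flat, [3, 1, 2]))  # [['t0', 't1', 't2'], ['t3'], ['t4', 't5']]
```

Raising on a length mismatch also catches the case where the number of score rows and the number of tokens disagree, which zip() in the original method would silently truncate.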

To Reproduce

import flair
import flair.datasets
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.data import Sentence, Label
from flair.datasets import DataLoader, FlairDatapointDataset
from tqdm import tqdm
from typing import List, cast
import torch
import torch.nn.functional as F

def _all_scores_for_token(sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
    """Returns all scores for each tag in tag dictionary."""
    scores = scores.numpy()
    tokens = [token for sentence in sentences for token in sentence]
    print('Number of tokens in batch:', len(tokens))
    prob_all_tags = [
        [
            Label(token, tagger.label_dictionary.get_item_for_index(score_id), score)
            for score_id, score in enumerate(score_dist)
        ]
        for score_dist, token in zip(scores, tokens)
    ]

    print('Length of prob_all_tags:', len(prob_all_tags))

    prob_tags_per_sentence = []
    previous = 0
    for i, length in enumerate(lengths):
        print(f'Length range of Sentence {i}: {previous} to {previous + length}')
        prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
        previous = length

    return prob_tags_per_sentence

# 1. get the corpus
corpus = flair.datasets.ONTONOTES()

# 2. what label do we want to predict?
label_type = 'ner'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

label_name = 'ner'
sentences = corpus.train
mini_batch_size=10
verbose = True

with torch.no_grad():
    Sentence.set_context_for_sentences(cast(List[Sentence], sentences))

    # filter empty sentences
    sentences = [sentence for sentence in sentences if len(sentence) > 0]

    # reverse sort all sequences by their length
    reordered_sentences = sorted(sentences, key=len, reverse=True)

    dataloader = DataLoader(
        dataset=FlairDatapointDataset(reordered_sentences),
        batch_size=mini_batch_size,
    )
    # progress bar for verbosity
    if verbose:
        dataloader = tqdm(dataloader, desc="Batch inference")

    overall_loss = torch.zeros(1, device=flair.device)
    label_count = 0
    for batch in dataloader:
        # skip empty batches
        if not batch:
            continue

        # get features from forward propagation
        sentence_tensor, lengths = tagger._prepare_tensors(batch)
        features = tagger.forward(sentence_tensor, lengths)

        # remove previously predicted labels of this type
        for sentence in batch:
            sentence.remove_labels(label_name)
        break  # inspect only the first batch

softmax_batch = F.softmax(features, dim=1).cpu()
lengths = [len(sentence) for sentence in batch]
all_tags = _all_scores_for_token(batch, softmax_batch, lengths)
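
In both scripts, F.softmax(features, dim=1) normalizes each row of features over the tag dimension. This assumes, as the method's zip(scores, tokens) implies, that the non-CRF forward pass returns one flat row per token rather than a padded (batch, sequence, tags) tensor. A stdlib sketch of the same row-wise normalization, on made-up logits:

```python
import math
from typing import List


def softmax_rows(features: List[List[float]]) -> List[List[float]]:
    """Row-wise softmax: each row is one token, each column one tag."""
    probs = []
    for row in features:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


# Hypothetical logits for 3 tokens over 4 tags
features = [[2.0, 0.5, -1.0, 0.0], [0.0, 0.0, 0.0, 0.0], [5.0, 1.0, 1.0, 1.0]]
probs = softmax_rows(features)
print([round(sum(row), 6) for row in probs])  # [1.0, 1.0, 1.0]
```

Because each row is normalized independently, the softmax itself is unaffected by the bug; only the later regrouping into sentences is wrong.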

Expected behavior

With the offset accumulated correctly in _all_scores_for_token(), the output should be:

Number of tokens in batch: 1761
Length of prob_all_tags: 1761
Length range of Sentence 0: 0 to 210
Length range of Sentence 1: 210 to 415
Length range of Sentence 2: 415 to 604
Length range of Sentence 3: 604 to 792
Length range of Sentence 4: 792 to 965
Length range of Sentence 5: 965 to 1133
Length range of Sentence 6: 1133 to 1300
Length range of Sentence 7: 1300 to 1457
Length range of Sentence 8: 1457 to 1613
Length range of Sentence 9: 1613 to 1761
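
To check the expected behavior on real flair objects rather than on printed ranges, one can verify that every Label returned for position j of sentence i points at that very token. The sketch below assumes, as in flair 0.13, that a Label exposes its token through the `data_point` attribute; `assert_scores_aligned` is a hypothetical helper.

```python
from typing import Any, Sequence


def assert_scores_aligned(
    sentences: Sequence[Sequence[Any]],
    all_tags: Sequence[Sequence[Sequence[Any]]],
) -> None:
    """Raise AssertionError unless all_tags[i][j] only holds Labels for token j of sentence i.

    Assumes each Label exposes its token via `.data_point` (as flair's Label does).
    """
    if len(all_tags) != len(sentences):
        raise AssertionError("expected one entry per sentence")
    for i, (sentence, token_dists) in enumerate(zip(sentences, all_tags)):
        tokens = list(sentence)  # iterating a flair Sentence yields its Tokens
        if len(token_dists) != len(tokens):
            raise AssertionError(f"sentence {i}: {len(token_dists)} distributions for {len(tokens)} tokens")
        for j, (token, labels) in enumerate(zip(tokens, token_dists)):
            for label in labels:
                if label.data_point is not token:
                    raise AssertionError(f"sentence {i}, token {j}: label belongs to another token")
```

Calling assert_scores_aligned(batch, all_tags) at the end of the reproduction script should fail at Sentence 2 with the current loop, since its slice starts with the last tokens of Sentence 0, and pass once previous += length is used.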

Environment

Versions:

Flair: 0.13.1
PyTorch: 2.2.1+cu121
Transformers: 4.40.1
GPU: True