huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
TokenClassification Pipeline not aggregating entities correctly #23322

Closed neilkimn closed 1 year ago

neilkimn commented 1 year ago

System Info

I am running transformers==4.27.3, but I believe the issue persists in the latest version, since it is specific to the gather_pre_entities function: https://github.com/huggingface/transformers/blob/v4.27.3/src/transformers/pipelines/token_classification.py#L281.

Who can help?

Tagging contributors who have committed to the TokenClassification Pipeline lately: @luccailliau @Narsil @sgugger

Reproduction

The code below shows the main details of the issue. I want to use the aggregation strategies of the TokenClassification Pipeline. Because I am using the LayoutLM model and tokenizer, the aggregation of subwords falls back to the heuristic implemented in the gather_pre_entities function of TokenClassificationPipeline. This should be fine; however, I am seeing cases where tokens are not properly merged, as shown in the example output below. The original sentence string contains many words, of which the following snippet is of interest: "... I alt DKK inkl. moms 5.975,74 Betalingsbetingelser: KONTANT ...". The model correctly predicts the entity, TOTAL, but the last digit, 4, is split off into its own TOTAL entity prediction.

# Omitted a bunch of boilerplate code including model definition, setting up dataset, etc.

pipe = TokenClassificationPipeline(model=model, tokenizer=tokenizer)

sample_output = model.forward(
    input_ids=sample_input["input_ids"].type(torch.long),
    bbox=sample_input["bbox"].type(torch.int32),
    image=torch.stack(sample_input["image"]),
    attention_mask=sample_input["attention_mask"],
)

sample_scores = sample_output["logits"][0].cpu().detach().numpy()

pre_entities = pipe.gather_pre_entities(
    sentence = " ".join(dataset["test"][0]["words"]),
    input_ids=sample_input["input_ids"][0],
    scores = sample_scores,
    offset_mapping=sample_input["offset_mapping"][0],
    special_tokens_mask=sample_input["special_tokens_mask"][0].cpu().detach().numpy(),
    aggregation_strategy="simple"
) # throws UserWarning: "Tokenizer does not support real words, using fallback heuristic"

grouped_entities = pipe.aggregate(pre_entities, aggregation_strategy="first")

for ent in grouped_entities:
    print(ent)

>>> {'entity_group': 'O', 'score': 11.709729, 'word': [... long sentence ...], 'start': 0, 'end': 11}
>>> [... some other predicted entities ...]
>>> {'entity_group': 'TOTAL', 'score': 8.98903, 'word': '5.975,7', 'start': 0, 'end': 7}
>>> {'entity_group': 'TOTAL', 'score': 6.8310637, 'word': '4', 'start': 7, 'end': 8}
>>> {'entity_group': 'O', 'score': 11.587039, 'word': [... long sentence ...], 'start': 0, 'end': 23}

Expected behavior

Diving into the gather_pre_entities function, I see that the heuristic uses the is_subword boolean to determine how subwords should be aggregated into a combined word with a corresponding merged entity. Specifically, the heuristic uses the rule is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]. If I comment out the second part of the conditional, the entity is merged correctly, i.e. {'entity_group': 'TOTAL', 'score': 8.98903, 'word': '5.975,74', 'start': 0, 'end': 8}.

else:
    # This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
    if aggregation_strategy in {
        AggregationStrategy.FIRST,
        AggregationStrategy.AVERAGE,
        AggregationStrategy.MAX,
    }:
        warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
        is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]

Since the start_ind of a subword is relative to the original word that the subword is part of, why does the heuristic index into the entire sentence string? These indices, which come from the offset_mapping, will always be relative to the word and will mostly range from 0 to about 10, depending on the word length. I don't fully understand why the presence of " " should decide whether a token is a subword, but I am certain that this must be a bug. Even if the start and end indices from offset_mapping were relative to the entire sentence, how could you then determine when a new word is starting?
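The effect described in this post can be reproduced without a model. The following is a minimal sketch, assuming a made-up sentence (not the original document text) and hypothetical token offsets that are consistent with the output above ('5.975,7' ending at 7 and '4' starting at 7). The function is_subword_fallback is a hypothetical wrapper around the expression quoted from the fallback heuristic; it is not part of the library.

```python
def is_subword_fallback(sentence: str, start_ind: int) -> bool:
    # Same boolean expression as the quoted fallback heuristic.
    return start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]


# Made-up sentence for illustration only.
sentence = "Faktura I alt DKK inkl. moms 5.975,74"

# Hypothetical start offsets of the four tokens that make up "5.975,74",
# expressed relative to the word itself.
word_relative_starts = [0, 2, 4, 7]

# The same tokens with offsets relative to the full sentence.
base = sentence.index("5.975,74")
sentence_relative_starts = [base + s for s in word_relative_starts]

word_relative_flags = [is_subword_fallback(sentence, s) for s in word_relative_starts]
sentence_relative_flags = [is_subword_fallback(sentence, s) for s in sentence_relative_starts]

print(word_relative_flags)      # the last token is not flagged as a subword
print(sentence_relative_flags)  # every token after the first is flagged
```

With word-relative offsets, the heuristic ends up inspecting unrelated characters near the start of the sentence, so a space that happens to sit at index 7 breaks the word apart. With sentence-relative offsets, the space in front of a token is exactly what marks the start of a new word.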

luccailliau commented 1 year ago

Hello @neilkimn,

I haven't looked at the code in detail, but I think the error comes not from the pipeline itself but from a wrong prediction of the model. I guess you are using the IOB format for your labels, and maybe the last digit was predicted as B-TOTAL rather than I-TOTAL, which ends up creating a new entity for only one digit.

This behavior is common, which is why there are different aggregation strategies. Changing your aggregation strategy from simple to first to calculate pre_entities should solve your problem.
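To illustrate why a B- tag splits an entity, here is a simplified sketch of IOB grouping. It is not the pipeline's actual implementation: split_tag and group_by_iob are hypothetical helpers, the tokens are invented, and words are joined by plain string concatenation rather than through the tokenizer.

```python
def split_tag(label: str):
    # Split "B-TOTAL" into ("B", "TOTAL"); treat untagged labels as "I".
    if label.startswith("B-"):
        return "B", label[2:]
    if label.startswith("I-"):
        return "I", label[2:]
    return "I", label


def group_by_iob(tokens):
    groups = []
    for tok in tokens:
        bi, tag = split_tag(tok["entity"])
        # Extend the current group only if the tag matches and this token
        # does not explicitly begin a new entity.
        if groups and bi != "B" and groups[-1]["entity_group"] == tag:
            groups[-1]["word"] += tok["word"]
            groups[-1]["end"] = tok["end"]
        else:
            groups.append(
                {"entity_group": tag, "word": tok["word"],
                 "start": tok["start"], "end": tok["end"]}
            )
    return groups


# Invented predictions: the second token is tagged I-TOTAL in one case
# and B-TOTAL in the other.
well_formed = [
    {"entity": "B-TOTAL", "word": "5.975,7", "start": 0, "end": 7},
    {"entity": "I-TOTAL", "word": "4", "start": 7, "end": 8},
]
repeated_b = [dict(well_formed[0]), dict(well_formed[1], entity="B-TOTAL")]

merged = group_by_iob(well_formed)
split = group_by_iob(repeated_b)
print(merged)
print(split)
```

Under this rule, a second B-TOTAL always opens a new group, even when it directly follows another TOTAL token.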

neilkimn commented 1 year ago

Hi @luccailliau, thanks for the swift reply. You're right that the predictions for TOTAL don't follow the correct IOB format. Here's the output when using no aggregation strategy:

{'entity': 'B-TOTAL', 'score': 0.99910754, 'index': 280, 'word': '▁5.', 'start': 0, 'end': 2}
{'entity': 'B-TOTAL', 'score': 0.9981998, 'index': 281, 'word': '97', 'start': 2, 'end': 4}
{'entity': 'B-TOTAL', 'score': 0.9978011, 'index': 282, 'word': '5,7', 'start': 4, 'end': 7}
{'entity': 'B-TOTAL', 'score': 0.9913623, 'index': 283, 'word': '4', 'start': 7, 'end': 8}

Applying aggregation strategies yields:

# simple
{'entity_group': 'TOTAL', 'score': 0.99910754, 'word': '5.', 'start': 0, 'end': 2}
{'entity_group': 'TOTAL', 'score': 0.9981998, 'word': '97', 'start': 2, 'end': 4}
{'entity_group': 'TOTAL', 'score': 0.9978011, 'word': '5,7', 'start': 4, 'end': 7}
{'entity_group': 'TOTAL', 'score': 0.9913623, 'word': '4', 'start': 7, 'end': 8}

# first
{'entity_group': 'TOTAL', 'score': 0.99910754, 'word': '5.975,7', 'start': 0, 'end': 7}
{'entity_group': 'TOTAL', 'score': 0.9913623, 'word': '4', 'start': 7, 'end': 8}

# average
{'entity_group': 'TOTAL', 'score': 0.99836946, 'word': '5.975,7', 'start': 0, 'end': 7}
{'entity_group': 'TOTAL', 'score': 0.9913623, 'word': '4', 'start': 7, 'end': 8}

# max
{'entity_group': 'TOTAL', 'score': 0.99910754, 'word': '5.975,7', 'start': 0, 'end': 7}
{'entity_group': 'TOTAL', 'score': 0.9913623, 'word': '4', 'start': 7, 'end': 8}

I am still confident that the issue is due to the heuristic taking the start_ind of the subword from offset_mapping and then indexing into sentence with it. Backtracking through the call stack, I could verify that the sentence variable contained the full input sentence, and it is only coincidental that " " not in sentence[start_ind - 1 : start_ind + 1] yields False, ultimately setting is_subword = False for the word '4', even though it is a subword.

luccailliau commented 1 year ago

@neilkimn,

You're using sentence = " ".join(dataset["test"][0]["words"]) to generate a sentence from a list of words (or subwords). This is not a problem in itself, but the original offset_mapping passed as offset_mapping=sample_input["offset_mapping"][0] won't match the sentence created with " ".join(). I am pretty sure that something like this is happening:

[image]

I think the easiest solution for your problem is a loop that merges entities if entities[i]["end"] == entities[i+1]["start"]. Alternatively (not a beautiful solution), tokenize the initial sentence to generate tokens, create a new sentence with " ".join(tokens), and finally tokenize this new sentence so that offset_mapping is aligned with sentence.
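The merging loop suggested here could look roughly like the sketch below. It is only an illustration under some assumptions: merge_adjacent_entities is a hypothetical helper, the extra check that both entities share the same entity_group is an addition to the suggested condition, and keeping the lower of the two scores is an arbitrary choice.

```python
def merge_adjacent_entities(entities):
    merged = []
    for ent in entities:
        prev = merged[-1] if merged else None
        if (
            prev is not None
            and prev["entity_group"] == ent["entity_group"]  # assumption: same label only
            and prev["end"] == ent["start"]                   # condition suggested above
        ):
            prev["word"] += ent["word"]
            prev["end"] = ent["end"]
            # Assumption: keep the more pessimistic score of the two parts.
            prev["score"] = min(prev["score"], ent["score"])
        else:
            merged.append(dict(ent))  # copy so the input list is left untouched
    return merged


# The two TOTAL entities reported for the "first" strategy.
entities = [
    {"entity_group": "TOTAL", "score": 0.99910754, "word": "5.975,7", "start": 0, "end": 7},
    {"entity_group": "TOTAL", "score": 0.9913623, "word": "4", "start": 7, "end": 8},
]

result = merge_adjacent_entities(entities)
print(result)
```

This post-processing hides the symptom rather than fixing the offset mismatch, so it is probably best treated as a workaround.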

neilkimn commented 1 year ago

Thanks for clarifying @luccailliau, that explains why the offset_mapping is different for my example. I guess the issue stems from how the LayoutXLMProcessor calls LayoutXLMTokenizerFast which I am using. Supplying the processor with both the tokens joined together and their original split representation resolves it.
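For readers who cannot change how the processor is called, another way to align the two is to shift word-relative offsets into the coordinates of the joined sentence. This is a sketch of that idea, not the fix used in this thread: to_sentence_offsets is a hypothetical helper, and it assumes each token's word index is known (for example from a fast tokenizer's word_ids()), that words were joined with a single space, and that special tokens carry None as their word index.

```python
def to_sentence_offsets(words, token_word_ids, token_offsets):
    # Start position of every word inside " ".join(words).
    word_starts = []
    pos = 0
    for word in words:
        word_starts.append(pos)
        pos += len(word) + 1  # +1 for the joining space

    shifted = []
    for word_id, (start, end) in zip(token_word_ids, token_offsets):
        if word_id is None:
            shifted.append((0, 0))  # special tokens keep an empty span
        else:
            base = word_starts[word_id]
            shifted.append((base + start, base + end))
    return shifted


# Invented example built around the snippet from the issue.
words = ["moms", "5.975,74", "KONTANT"]
sentence = " ".join(words)
token_word_ids = [None, 0, 1, 1, 1, 1, 2, None]
token_offsets = [(0, 0), (0, 4), (0, 2), (2, 4), (4, 7), (7, 8), (0, 7), (0, 0)]

sentence_offsets = to_sentence_offsets(words, token_word_ids, token_offsets)
print([sentence[s:e] for s, e in sentence_offsets])
```

Once the offsets point into the joined sentence, the space check in the fallback heuristic looks at the characters that actually surround each token.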