NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

Inference with nielsr/lilt-xlm-roberta-base #236

Open piegu opened 1 year ago

piegu commented 1 year ago

Hi @NielsRogge,

In your notebook Fine_tune_LiLT_on_a_custom_dataset,_in_any_language.ipynb, you don't include code for running inference on a document image.

Could you share it, please? I tried to copy the code from "Document AI: LiLT a better language agnostic LayoutLM model", but using XLM-RoBERTa instead of RoBERTa seems to complicate creating the processor.

NielsRogge commented 1 year ago

Hi,

To perform inference with LiLT, you don't need a processor, as the model only gets text and corresponding bounding boxes as input. We only need a tokenizer.

Inference can be performed as follows:

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import load_dataset

def normalize_bbox(bbox, width, height):
    # scale absolute (x0, y0, x1, y1) pixel coordinates to the 0-1000 range
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]

tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained("nielsr/lilt-xlm-roberta-base")

# load dataset
dataset = load_dataset("nielsr/funsd-iob-original", split="train")

# load an image on which we'd like to predict,
# along with the words + coordinates recognized by an OCR engine
example = dataset[0]
image = example["image"]
words = example["words"]
boxes = example["original_bboxes"]

# prepare for the model
width, height = image.size

# give every sub-word token the (normalized) box of the word it belongs to
token_boxes = []
for word, box in zip(words, boxes):
    box = normalize_bbox(box, width, height)
    n_word_tokens = len(tokenizer.tokenize(word))
    token_boxes.extend([box] * n_word_tokens)

encoding = tokenizer(" ".join(words), truncation=True, max_length=512, return_tensors="pt")
sequence_length = encoding.input_ids.shape[1]

# add boxes for the special tokens (<s> and </s>), truncating the word boxes
# so that the number of boxes matches the number of input ids
cls_box = sep_box = [0, 0, 0, 0]
bbox = [cls_box] + token_boxes[: sequence_length - 2] + [sep_box]

# the model expects a tensor of shape (batch_size, seq_len, 4)
encoding["bbox"] = torch.tensor([bbox])

Next you can do a forward pass and decode the predictions:

import torch

# forward pass
with torch.no_grad():
    outputs = model(**encoding)

predicted_class_indices = outputs.logits.argmax(-1)[0].tolist()
# turn into actual class names
predicted_classes = [model.config.id2label[label] for label in predicted_class_indices]

Note: I'm loading a model with a randomly initialized classification head above, hence predictions will be random. You need to load a model which has a fine-tuned classification head.

See also my LayoutLMv1 notebooks regarding turning the token-level predictions into word-level predictions and visualization.
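A minimal sketch of one common way to do the token-to-word conversion: keep the prediction of each word's first sub-word token. It assumes the per-word token counts from the `tokenizer.tokenize(word)` loop line up with the tokenized sequence, which is the same assumption the box-alignment loop makes. The helper name `word_level_predictions` is hypothetical, not something taken from the notebooks.

```python
def word_level_predictions(token_labels, n_tokens_per_word):
    """Map token-level labels to one label per word (first sub-word wins).

    token_labels: one label per input id, including <s> at the start and </s> at the end.
    n_tokens_per_word: for each word, how many tokens tokenizer.tokenize(word) produced.
    """
    word_labels = []
    position = 1  # position 0 is the <s> special token
    end = len(token_labels) - 1  # position of the closing </s> special token
    for n_tokens in n_tokens_per_word:
        if position >= end:
            break  # this word (and every later one) was cut off by truncation
        if n_tokens == 0:
            word_labels.append(None)  # the word produced no tokens at all
            continue
        word_labels.append(token_labels[position])  # label of the first sub-word
        position += n_tokens
    return word_labels
```

For example, `word_level_predictions(predicted_classes, [len(tokenizer.tokenize(w)) for w in words])` gives one label per entry of `words`, or fewer if the text was truncated at 512 tokens.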

piegu commented 1 year ago

Hi @NielsRogge,

Thank you, but your code assumes I already have the corresponding bounding boxes as input, and I don't.

I only have the image of a document, and I want to run inference with the LiLT model I fine-tuned from your nielsr/lilt-xlm-roberta-base.

How to do that?

NielsRogge commented 1 year ago

LiLT, like the LayoutLM models, relies on an OCR engine of your choice. You'll first need to run OCR on the image to get a list of words and their corresponding boxes.
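As one illustration (the thread doesn't prescribe a particular engine), Tesseract via the `pytesseract` package can provide both: `pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)` returns parallel lists that include `text`, `left`, `top`, `width` and `height`. The hypothetical helper below turns that dictionary into a `words` list and absolute `(x0, y0, x1, y1)` pixel boxes; the boxes still need to be passed through `normalize_bbox` afterwards.

```python
def words_and_boxes_from_tesseract(data):
    """Convert pytesseract.image_to_data(..., output_type=Output.DICT) output
    into a list of words and a list of absolute (x0, y0, x1, y1) pixel boxes."""
    words, boxes = [], []
    for text, left, top, width, height in zip(
        data["text"], data["left"], data["top"], data["width"], data["height"]
    ):
        if not text.strip():
            continue  # skip rows without text (page/block/line entries, blanks)
        words.append(text)
        boxes.append([left, top, left + width, top + height])
    return words, boxes
```

Usage would be `data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)` followed by `words, boxes = words_and_boxes_from_tesseract(data)`; this requires the Tesseract binary to be installed alongside `pytesseract`.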

eschaffn commented 1 year ago

Hi there! I'm having trouble running the code above with some modifications: I'm using the output of PaddleOCR for the words and bboxes. Here's the code:

import torch

from transformers import AutoTokenizer, AutoModelForTokenClassification
from paddleocr import PaddleOCR,draw_ocr
from PIL import Image

def normalize_bbox(bbox, width, height):
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]

def convert_to_4_value_format(coordinates):
    if not coordinates or len(coordinates) < 4:
        raise ValueError("Input list must contain four coordinate pairs.")

    x_coords = [coord[0] for coord in coordinates]
    y_coords = [coord[1] for coord in coordinates]

    x_min = min(x_coords)
    y_min = min(y_coords)
    x_max = max(x_coords)
    y_max = max(y_coords)

    return [x_min, y_min, x_max, y_max]

ocr = PaddleOCR(lang='ru')
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained("nielsr/lilt-xlm-roberta-base")

img_path = '/home/ubuntu/9548997E-D883-43A0-9BF2-028C7903BDAF_mw800_s.jpg'

result = ocr.ocr(img_path, cls=True)
# print(len(result))
# print(result[0])
# print(result[0][0])

boxes = []
words = []
image = Image.open(img_path)
for r in result[0]:
    boxes.append(convert_to_4_value_format(r[0]))
    words.append(r[1][0])

width, height = image.size

bbox = [] 
for word, box in zip(words, boxes):
    box = normalize_bbox(box, width, height)
    n_word_tokens = len(tokenizer.tokenize(word))
    bbox.extend([box] * n_word_tokens)

# add special tokens
cls_box = sep_box = [0, 0, 0, 0]
bbox = [cls_box] + bbox + [sep_box]

encoding = tokenizer(" ".join(words), truncation=True, max_length=512)
sequence_length = len(encoding.input_ids)
# truncate boxes based on length of input ids
bbox = bbox[:sequence_length] 

encoding["bbox"] = bbox

print(encoding)

# forward pass
with torch.no_grad():
   outputs = model(**encoding)

predicted_class_indices = outputs.logits.argmax(-1)[0].tolist()
# turn into actual class names
predicted_classes = [model.config.id2label[label] for label in predicted_class_indices]

And here's the error I get:

Traceback (most recent call last):
  File "/home/ubuntu/digital_scope/test.py", line 71, in <module>
    outputs = model(**encoding)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 1017, in forward
    outputs = self.lilt(
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 767, in forward
    input_shape = input_ids.size()
AttributeError: 'list' object has no attribute 'size'

I know this thread is pretty old, but thanks in advance for any assistance and let me know if you need more information!

NielsRogge commented 1 year ago

Hi,

The input_ids need to be PyTorch tensors. However, in your case they are still lists of integers. You can fix this by adding return_tensors="pt" to the tokenizer call:

encoding = tokenizer(" ".join(words), truncation=True, max_length=512, return_tensors="pt")

eschaffn commented 1 year ago

I've actually already done this:

Traceback (most recent call last):
  File "/home/ubuntu/digital_scope/test.py", line 72, in <module>
    outputs = model(**encoding)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 1017, in forward
    outputs = self.lilt(
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 808, in forward
    layout_embedding_output = self.layout_embeddings(bbox=bbox, position_ids=position_ids)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 162, in forward
    left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
TypeError: list indices must be integers or slices, not tuple

Casting the bbox into a tensor gives another error:

Traceback (most recent call last):
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 162, in forward
    left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
IndexError: too many indices for tensor of dimension 2

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/digital_scope/test.py", line 72, in <module>
    outputs = model(**encoding)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 1017, in forward
    outputs = self.lilt(
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 808, in forward
    layout_embedding_output = self.layout_embeddings(bbox=bbox, position_ids=position_ids)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/digital_scope/lib/python3.8/site-packages/transformers/models/lilt/modeling_lilt.py", line 167, in forward
    raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
IndexError: The `bbox` coordinate values should be within 0-1000 range.

Maybe my bounding boxes are formatted incorrectly? I formatted them in the 4-value format. Here's an example of the encoding object I pass to the model:

{'input_ids': tensor([[     0,   2377,    808,  ...,    141,      2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]]), 'bbox': tensor([[  0,   0,   0,   0],
        [373,  93, 460, 125],
        [373,  93, 460, 125],
        ...,
        [951, 986, 997, 998],
        [951, 986, 997, 998],
        [  0,   0,   0,   0]])}
NielsRogge commented 1 year ago

Your boxes need to be a tensor of shape (batch_size, seq_len, 4), and they need to be normalized by the size of the image:

def normalize_bbox(bbox, width, height):
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]

In your case, it looks like the boxes are still a 2D tensor instead of 3D.
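The tracebacks above suggest the "0-1000 range" message can be misleading: the layout embeddings re-raise an IndexError from the box lookup with that message, and in the chained traceback the underlying error is "too many indices for tensor of dimension 2". A small pre-flight check, sketched below with a hypothetical `find_bbox_problem` helper that works on plain nested lists (for example `encoding["bbox"].tolist()`), can tell a missing batch dimension apart from genuinely out-of-range values before the forward pass.

```python
def find_bbox_problem(bbox):
    """Return a description of the first problem found in `bbox`, or None.

    `bbox` is a nested Python list, e.g. encoding["bbox"].tolist().
    Expected shape: (batch_size, seq_len, 4), with every value in 0..1000.
    """
    if not bbox or not isinstance(bbox[0], list):
        return "bbox must be a nested list of shape (batch_size, seq_len, 4)"
    if bbox[0] and not isinstance(bbox[0][0], list):
        # a list of 4-value boxes without the outer batch level
        return "bbox looks 2D (seq_len, 4): add a batch dimension, e.g. torch.tensor([bbox])"
    for b, sequence in enumerate(bbox):
        for t, box in enumerate(sequence):
            if len(box) != 4:
                return f"box at batch {b}, token {t} has {len(box)} values instead of 4"
            if any(not 0 <= value <= 1000 for value in box):
                return f"box at batch {b}, token {t} is outside the 0-1000 range: {box}"
    return None
```

Run on the 2D `bbox` tensor printed above (after `.tolist()`), this reports the missing batch dimension rather than a range problem.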

eschaffn commented 12 months ago

Thanks, I fixed that as well. I'm now wondering about this line here:

encoding = tokenizer(" ".join(words), truncation=True, max_length=512, return_tensors="pt")

It seems that joining all of the input words with spaces causes inconsistent tokenization, and it makes the output hard to visualize: the original OCR boxes are cleaner to use for overlays than the token-level boxes fed to the layout model.

I'd like to visualize the label outputs from LiLT with the original bounding boxes from OCR. Is there a good way to do visualizations even though the original bounding boxes and words won't match up with the tokenized boxes and words?

I've tried removing all of the extra bounding boxes created here:

for word, box in zip(words, boxes):
    box = normalize_bbox(box, width, height)
    n_word_tokens = len(tokenizer.tokenize(word))
    bbox.extend([box] * n_word_tokens)

But I still get more labels and boxes than the original OCR input.
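One way to get exactly one label per OCR entry (a suggestion, not something confirmed in this thread) is to let the tokenizer track word boundaries: pass the OCR words as a list with `is_split_into_words=True` and read `encoding.word_ids()`, which a fast tokenizer (what `AutoTokenizer` loads by default when one exists) returns with one entry per token and `None` for special tokens. That replaces the separate `tokenizer.tokenize(word)` loop, so boxes and tokens cannot drift apart. The helper names below are hypothetical, and majority voting is just one choice; taking each word's first-token label is the other common option.

```python
from collections import Counter


def boxes_from_word_ids(word_ids, word_boxes, special_box=(0, 0, 0, 0)):
    """Give each token the box of the OCR word it came from.

    word_ids: encoding.word_ids() output, with None for special tokens.
    word_boxes: one (already normalized) box per OCR word.
    """
    return [list(special_box) if i is None else list(word_boxes[i]) for i in word_ids]


def labels_per_ocr_word(word_ids, token_labels, n_words):
    """Collapse token-level labels to one label per OCR word by majority vote.

    Words whose tokens were all truncated away get None.
    """
    votes = [Counter() for _ in range(n_words)]
    for word_id, label in zip(word_ids, token_labels):
        if word_id is not None:
            votes[word_id][label] += 1
    return [counter.most_common(1)[0][0] if counter else None for counter in votes]
```

With `encoding = tokenizer(words, is_split_into_words=True, truncation=True, max_length=512, return_tensors="pt")` and `word_ids = encoding.word_ids()`, the model input would be `torch.tensor([boxes_from_word_ids(word_ids, [normalize_bbox(b, width, height) for b in boxes])])`. After the forward pass, `labels_per_ocr_word(word_ids, predicted_classes, len(words))` is aligned with the original `boxes`, so each OCR box can be drawn with its label.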

Thanks

mellahysf commented 8 months ago

@eschaffn, how did you resolve the issue "IndexError: The `bbox` coordinate values should be within 0-1000 range."? Could you share the complete working inference code, please? Thanks. @NielsRogge, any update on this, please?

philmas commented 3 weeks ago

I'm running into the same issue and I'm not sure how to resolve it.

eschaffn commented 3 weeks ago

Sorry guys, I don't remember. This was a while ago, and I don't think I ended up using the model or saving this code anywhere.

NielsRogge commented 3 weeks ago

You just have to make sure to normalize your bounding boxes, as the model only has embeddings for coordinates between 0 and 1000. See here: https://huggingface.co/docs/transformers/en/model_doc/layoutlm#usage-tips (it's equivalent for LiLT).
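The `normalize_bbox` function posted earlier in the thread only rescales. If an OCR box reaches past the image border, or the width and height you divide by don't belong to the image the OCR ran on, some values land outside the 0-1000 range the model expects (and, far enough out, raise the IndexError seen above). A defensive variant that also clamps is sketched below; `normalize_and_clamp_bbox` is a hypothetical helper, not part of Transformers, and clamping hides rather than fixes a genuine size mismatch.

```python
def normalize_and_clamp_bbox(bbox, width, height):
    """Scale (x0, y0, x1, y1) pixel coordinates to 0-1000, then clamp into that range.

    width, height: size of the image the OCR boxes were measured on (e.g. image.size).
    """
    scaled = [
        1000 * bbox[0] / width,
        1000 * bbox[1] / height,
        1000 * bbox[2] / width,
        1000 * bbox[3] / height,
    ]
    # int() truncates toward zero; clamping keeps every value a valid embedding index
    return [min(max(int(value), 0), 1000) for value in scaled]
```

It can be used as a drop-in replacement for `normalize_bbox` in any of the snippets above.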