NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License
8.45k stars 1.32k forks source link

LayoutLMv3 inference issue: predicted bounding boxes are not plotted correctly #392

Open · rajasekarkrish opened this issue 4 months ago

rajasekarkrish commented 4 months ago

# LayoutLMv3 inference: predicted bounding boxes are not plotted correctly.
from transformers import AutoModelForTokenClassification
from datasets import load_dataset
import torch
from transformers import AutoProcessor
import matplotlib.pyplot as plt
from updatetrain import id2label
import matplotlib.patches as patches

# Fine-tuned token-classification checkpoint used for inference.
model = AutoModelForTokenClassification.from_pretrained("/new_dataset/new_layoutlmv3/checkpoint-3000")

# Local loading script; it yields words, normalized boxes, ner_tags and the image.
dataset = load_dataset(r"/new_layoutlmv3_dataset/new_dataset/updateddataset.py")

# apply_ocr=False: words and boxes come from the dataset instead of the built-in OCR.
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

# Pick one test document and inspect what the loading script produced.
example = dataset["test"][1]
print(example["image"])
print(example.keys())

image = example["image"]
words = example["words"]
boxes = example["bboxes"]  # per-word boxes as produced by the loading script
word_labels = example["ner_tags"]

# NOTE: truncation=True silently drops every token past the model's maximum
# sequence length (512 for layoutlmv3-base, TODO confirm for this checkpoint).
# Words near the end of a long page then get no prediction and no box.
# `stride` only has an effect together with return_overflowing_tokens=True.
encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=word_labels,
    truncation=True,
    stride=128,
    return_tensors="pt",
)

for k, v in encoding.items():
    print(k, v.shape)

with torch.no_grad():
    outputs = model(**encoding)

logits = outputs.logits
print(logits.shape)

# The batch holds a single document, so take row 0 explicitly. squeeze() would
# also collapse the sequence axis if it ever had length 1.
predictions = logits.argmax(-1)[0].tolist()
print(predictions)

labels = encoding.labels[0].tolist()
print(labels)

print("printing five labels", labels[:5])  # Print the first 5 labels


def unnormalize_box(bbox, width, height):
    """Map a box from the 0-1000 LayoutLM coordinate space back to pixels.

    Args:
        bbox: [x0, y0, x1, y1] with each coordinate in the 0..1000 range.
        width: image width in pixels (scales the x coordinates).
        height: image height in pixels (scales the y coordinates).

    Returns:
        [x0, y0, x1, y1] as floats in pixel coordinates of that image.
    """
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]

token_boxes = encoding.bbox[0].tolist()
width, height = image.size
print("Image size:", width, "x", height)

print("printing boxes", boxes[:5])

print("printing token boxes", token_boxes[:5])  # Print the first 5 bounding boxes

# Keep only positions that carry a word label. -100 marks special tokens and,
# with the processor's default only_label_first_subword=True, the non-first
# sub-tokens of a word. Building the three lists in one pass keeps them aligned.
true_predictions, true_labels, true_boxes = [], [], []
for pred, label, box in zip(predictions, labels, token_boxes):
    if label == -100:
        continue
    true_predictions.append(model.config.id2label[pred])
    true_labels.append(model.config.id2label[label])
    true_boxes.append(unnormalize_box(box, width, height))

print("printing true Predictions:", true_predictions[:5])  # Print the first 5 predictions
print("printing true Labels:", true_labels[:5])  # Print the first 5 labels

from PIL import ImageFont
from PIL import ImageDraw

# Pillow's built-in default font, used for the caption next to each box.
font = ImageFont.load_default()

# Drawing handle that paints directly onto `image`; the image itself is modified.
draw = ImageDraw.Draw(image)

def iob_to_label(label):
    """Return the entity name of a label, without an IOB-style prefix.

    "B-HEADER" / "I-HEADER" -> "HEADER"; "O" or an empty name -> "other".
    Labels that carry no "B-"/"I-"/"E-"/"S-" prefix (like this dataset's
    "base_tax_header") are returned unchanged.
    """
    if label.upper() == "O":
        return 'other'
    # Strip only a real prefix. This dataset's class names have none, and
    # dropping two characters unconditionally turns "year" into "ar".
    if label[:2].upper() in ("B-", "I-", "E-", "S-"):
        label = label[2:]
    if not label:
        return 'other'
    return label


# Box/caption colour per class. Keys are the dataset's ner_tags class names in
# lower case, i.e. what iob_to_label(name).lower() returns for each of them.
label2color = {
    'irrelevant': 'red',
    'base_tax_header': 'green',
    'base_tax_due_header': 'green',
    'year_header': 'green',
    'base_tax_total': 'blue',
    'base_tax_due_total': 'blue',
    'base_tax': 'yellow',
    'base_tax_due': 'orange',
    'year': 'orange',
}

print(label2color)
print(model.config.id2label)

# Classes the model can predict that have no colour assigned. The drawing loop
# skips boxes for these. Generic names such as "LABEL_0" here would mean the
# checkpoint was saved without a named id2label.
uncoloured = sorted(
    {iob_to_label(name).lower() for name in model.config.id2label.values()} - set(label2color)
)
if uncoloured:
    print("Labels without a colour in label2color:", uncoloured)

# Per-token comparison of predicted vs. gold labels, for debugging.
# NOTE(review): predictions are decoded with model.config.id2label, but gold
# labels use `id2label` imported from updatetrain. If the two maps differ, this
# printout will look wrong even when the model is right.
for i, (prediction, label) in enumerate(zip(predictions, labels)):
    if label == -100:
        # No word label at this position: special tokens and, by default,
        # non-first sub-tokens of a word.
        continue
    predicted_label = model.config.id2label.get(prediction, "Label not found")
    actual_label = id2label.get(label, "Label not found")
    print(f"Token {i}: Predicted - {predicted_label}, Actual - {actual_label}")

# Draw one rectangle and caption per word, coloured by its predicted class.
for prediction, box in zip(true_predictions, true_boxes):
    predicted_label = iob_to_label(prediction).lower()
    color = label2color.get(predicted_label)
    if color is None:
        print(f"Label {predicted_label} not in label2color dictionary.")
        continue
    # Order the corners: recent Pillow versions raise ValueError when x1 < x0
    # or y1 < y0, e.g. for boxes exported in an unexpected format.
    x0, x1 = sorted((box[0], box[2]))
    y0, y1 = sorted((box[1], box[3]))
    draw.rectangle([x0, y0, x1, y1], outline=color)
    # Caption just above the box, kept on-canvas for boxes at the top edge.
    draw.text((x0 + 10, max(y0 - 10, 0)), text=predicted_label, fill=color, font=font)

# Display the image; the boxes were drawn onto it in place via `draw`.
plt.imshow(image)
plt.show()

[Screenshot attached: bounding_box_not_ploted_accordingly_layoutlmv3]

# Dataset preparation code: a Hugging Face `datasets` loading script.
import json
import os
import numpy as np
from PIL import Image
import datasets

import torch

# Module-level logger, named after this module.
logger = datasets.logging.get_logger(__name__)

def normalize_bbox(bbox, size):
    """Scale a pixel box to the 0-1000 coordinate range used by LayoutLM models.

    Args:
        bbox: [x0, y0, x1, y1] in pixels of the image.
        size: (width, height) of the image in pixels.

    Returns:
        [x0, y0, x1, y1] as ints, each clamped to 0..1000 so that annotations
        slightly outside the image stay in the model's valid range.

    NOTE(review): this assumes `bbox` holds two corners. If the annotation tool
    exports [x, y, width, height], convert to [x, y, x + width, y + height]
    first. Otherwise every box is plotted in the wrong place at inference.
    """
    width, height = size

    def _scale(value, extent):
        # Truncate toward zero like the plain int() conversion, then clamp.
        return min(max(int(1000 * value / extent), 0), 1000)

    return [
        _scale(bbox[0], width),
        _scale(bbox[1], height),
        _scale(bbox[2], width),
        _scale(bbox[3], height),
    ]

def load_image(image_path):
    """Open an image file and return it in RGB mode together with its size.

    Returns:
        (image, (w, h)): the RGB PIL image and its pixel width and height.
    """
    # The `with` block closes the file handle. convert() loads the pixels into
    # a new image first, so the returned image remains usable afterwards.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    # NOTE(review): EXIF orientation is not applied. If the annotation tool
    # showed rotated photos upright, annotated boxes will not line up with this
    # image. Confirm for phone-captured documents.
    w, h = image.size
    return image, (w, h)

class CustomDatasetConfig(datasets.BuilderConfig):
    """BuilderConfig for CustomDataset"""

    def __init__(self, **kwargs):
        """BuilderConfig for CustomDataset.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)

class CustomDataset(datasets.GeneratorBasedBuilder):
    """Custom dataset for document understanding.

    Each example is one document image together with its word-level
    annotations: the words, one box per word (scaled to 0-1000 by
    ``normalize_bbox``) and one class label per word.
    """

    BUILDER_CONFIGS = [
        CustomDatasetConfig(name="custom_dataset", version=datasets.Version("1.0.0"), description="Custom dataset"),
    ]

    # Directory containing train.json and test.json. The data is expected to be
    # available locally already; override in a subclass to use another location.
    DATA_DIR = '/new_layoutlmv3_dataset/new_dataset/data/'

    # Directory that each entry's "file_name" (with "../" removed) is resolved
    # against. Update this path to your base directory.
    IMAGE_BASE_DIR = r"D:\new_layoutlmv3_dataset\new_dataset"

    def _info(self):
        """Describe the example schema: words, boxes, per-word labels and image."""
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "words": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                'irrelevant',
                                'base_tax_header',
                                'base_tax_due_header',
                                'year_header',
                                'base_tax_total',
                                'base_tax_due_total',
                                'base_tax',
                                'base_tax_due',
                                'year',
                                # Add more labels as per your requirement
                            ]
                        )
                    ),
                    "image": datasets.features.Image(),
                    "image_path": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators.

        ``dl_manager`` is unused: the JSON files are read from ``DATA_DIR``.
        """
        data_dir = self.DATA_DIR
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(data_dir, "train.json")},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": os.path.join(data_dir, "test.json")},
            ),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(key, example)`` pairs from one annotation JSON file.

        The file holds a list of entries. Each entry has a ``file_name`` (image
        path) and an ``annotations`` list whose items provide ``text``, ``box``
        and ``label``. Entries without ``file_name`` are skipped with a warning.
        """
        logger.info("⏳ Generating examples from = %s", filepath)
        with open(filepath, "r", encoding="utf8") as f:
            data = json.load(f)

        base_dir = self.IMAGE_BASE_DIR

        for guid, item in enumerate(data):
            if 'file_name' not in item:
                logger.warning("Skipping entry %s due to missing 'file_name'", guid)
                continue
            # Drop "../" and let normpath use the current OS's separator.
            # Hard-coding '\\' only produced valid paths on Windows.
            image_relative_path = item['file_name'].replace('../', '')
            image_path = os.path.normpath(os.path.join(base_dir, image_relative_path))

            image, size = load_image(image_path)
            words, bboxes, ner_tags = [], [], []

            for annotation in item["annotations"]:
                words.append(annotation["text"])
                # NOTE(review): "box" is treated as [x0, y0, x1, y1] in pixels
                # of the image loaded above. If the tool exports [x, y, w, h], or
                # coordinates of a resized/rotated view, the boxes will be
                # plotted in the wrong place at inference. Confirm the format.
                bboxes.append(normalize_bbox(annotation["box"], size))
                ner_tags.append(annotation["label"])

            yield guid, {
                "id": str(guid),
                "words": words,
                "bboxes": bboxes,
                "ner_tags": ner_tags,
                "image": image,  # PIL image; encoded by the Image feature
                "image_path": image_path,
            }