microsoft / table-transformer

Table Transformer (TATR) is a deep learning model for extracting tables from unstructured documents (PDFs and images). This is also the official repository for the PubTables-1M dataset and GriTS evaluation metric.

Many short rows with financial table structure model #178

Open · abielr opened this issue 2 months ago

abielr commented 2 months ago

I am finding that with the latest TATR-v1.1-Fin model, the initial table structure recognition (TSR) phase detects many very thin rows in the middle of a table, even when the table is simple and comes from the original training set. The TATR-v1.1-All model, however, works fine on the same table. Am I doing something incorrect when preprocessing the data?

The screenshots below demonstrate the problem on the image AAL_2014_page_192_table_0.jpg from the FinTabNet.c training set. The example code below reproduces the issue; just toggle between model_type = 'all' and model_type = 'fin' at the top of the script.

Original table: [screenshot]

Fin model: [screenshot]

All model: [screenshot]

from PIL import Image, ImageDraw
from pathlib import Path
import torch
from transformers import TableTransformerForObjectDetection
from torchvision import transforms

model_type = 'fin' # can be 'fin' or 'all'
PATH = Path(__file__).parent.resolve()
cropped_table = Image.open(PATH / "data/AAL_2014_page_192_table_0.jpg").convert("RGB")

#############
# Setup and utility functions
#############

device = "cuda" if torch.cuda.is_available() else "cpu"

class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
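        # Scale so the longest side equals max_size, preserving aspect ratio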
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))

        return resized_image

# Preprocessing for the structure model: resize the longest side to 1000 px,
# convert to a tensor, and normalize with ImageNet statistics
structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Per-class confidence thresholds (note: defined here but not applied below)
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10
}

def box_cxcywh_to_xyxy(x):
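    # Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2)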
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
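    # Scale normalized (cx, cy, w, h) boxes to absolute-pixel (x1, y1, x2, y2)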
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def outputs_to_objects(outputs, img_size, class_idx2name):
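    # For each query, keep the highest-scoring class; discard "no object" hits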
    m = outputs['logits'].softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = class_idx2name[int(label)]
        if not class_label == 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects

def draw_bboxes(bboxes, page_image, color='red'):
    page_image = page_image.copy()
    draw = ImageDraw.Draw(page_image)

    for bbox in bboxes:
        draw.rectangle(bbox, outline=color)

    return page_image

structure_models = {
    'all': TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all"),
    'fin': TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-fin")
}
for k in structure_models:
    structure_models[k] = structure_models[k].to(device)
structure_id2label = structure_models['all'].config.id2label
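# DETR-style heads reserve the last class index for "no object"; add it to the label map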
structure_id2label[len(structure_id2label)] = "no object"

#############
# Running the model
#############

pixel_values = structure_transform(cropped_table).unsqueeze(0)
pixel_values = pixel_values.to(device)

with torch.no_grad():
    structure_outputs = structure_models[model_type](pixel_values)

structure_objects = outputs_to_objects(structure_outputs, cropped_table.size, structure_id2label)

# Save the visualization so the script produces output outside a notebook
annotated = draw_bboxes([x['bbox'] for x in structure_objects], cropped_table)
annotated.save(PATH / f"result_{model_type}.png")
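As an aside: the structure_class_thresholds dict above is defined but never applied. A minimal filter that would use it, reusing the names from this script (the 0.5 fallback for unlisted labels is an assumption, not something from the repo), could look like:

def filter_objects(objects, thresholds=structure_class_thresholds):
    # Drop predictions scoring below their class-specific threshold
    return [o for o in objects if o['score'] >= thresholds.get(o['label'], 0.5)]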
Oleksii94 commented 1 month ago

@bsmock @msftgits @themanojkumar @NielsRogge

NielsRogge commented 1 month ago

That's an interesting, weird issue. Could NMS (non-maximum suppression) help here?
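A minimal sketch of that idea, operating on the objects list produced by outputs_to_objects above (torchvision.ops.nms keeps the highest-scoring box among mutually overlapping ones; the iou_threshold value here is a guess to tune, not a value from the repo):

import torch
from torchvision.ops import nms

def suppress_overlapping_rows(objects, iou_threshold=0.3):
    # Apply NMS to 'table row' boxes only; all other classes pass through
    rows = [o for o in objects if o['label'] == 'table row']
    others = [o for o in objects if o['label'] != 'table row']
    if not rows:
        return objects
    boxes = torch.tensor([o['bbox'] for o in rows], dtype=torch.float32)
    scores = torch.tensor([o['score'] for o in rows], dtype=torch.float32)
    keep = nms(boxes, scores, iou_threshold)
    return others + [rows[i] for i in keep.tolist()]

Whether plain IoU-based NMS catches the very thin rows depends on how much area they share with the full-height rows: a thin sliver inside a tall row can have low IoU with it, so an intersection-over-minimum criterion might be needed instead.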