Table Transformer (TATR) is a deep learning model for extracting tables from unstructured documents (PDFs and images). This is also the official repository for the PubTables-1M dataset and GriTS evaluation metric.
MIT License
Many short rows with financial table structure model #178
I am finding that, when using the latest TATR-v1.1-Fin model, the initial table structure recognition (TSR) step detects many very thin rows in the middle of a table, even when the table looks simple and comes from the original training set. The TATR-v1.1-All model, however, works fine on the same table. Am I doing something wrong when preprocessing the data?
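One way to make "many very thin rows" measurable, rather than judging it from the overlays, is to compare each detected row's height with the median row height. The sketch below assumes the object format returned by `outputs_to_objects` in the repro code (dicts with `label`, `score`, and `bbox` as `[x1, y1, x2, y2]`). The helper name `find_thin_rows` and the 0.5 ratio are hypothetical choices for illustration, not part of TATR.

```python
from statistics import median


def find_thin_rows(objects, ratio=0.5):
    """Return 'table row' detections whose height is below ratio * median row height.

    Hypothetical helper. Assumes each object is a dict with 'label', 'score', and
    'bbox' = [x1, y1, x2, y2]; the ratio is an arbitrary illustrative cutoff.
    """
    rows = [obj for obj in objects if obj["label"] == "table row"]
    if not rows:
        return []
    heights = [obj["bbox"][3] - obj["bbox"][1] for obj in rows]
    cutoff = ratio * median(heights)
    # Sort flagged rows by their top edge so they are reported top to bottom.
    return sorted(
        (obj for obj, h in zip(rows, heights) if h < cutoff),
        key=lambda obj: obj["bbox"][1],
    )


# Synthetic detections: three ~20 px rows plus one 3 px sliver.
detections = [
    {"label": "table row", "score": 0.98, "bbox": [0, 0, 400, 20]},
    {"label": "table row", "score": 0.97, "bbox": [0, 20, 400, 41]},
    {"label": "table row", "score": 0.61, "bbox": [0, 41, 400, 44]},
    {"label": "table row", "score": 0.96, "bbox": [0, 44, 400, 64]},
    {"label": "table column", "score": 0.99, "bbox": [0, 0, 200, 64]},
]
thin = find_thin_rows(detections)
print([row["bbox"] for row in thin])  # [[0, 41, 400, 44]]
```

Running it on the `fin` and `all` outputs for the same table would give a count of suspicious rows per model instead of a visual comparison.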
The screenshots below demonstrate the problem on the image `AAL_2014_page_192_table_0.jpg` from the FinTabNet.c training set. The example code below reproduces the issue; just toggle between `model_type = 'all'` and `model_type = 'fin'` at the top of the code.
```python
from pathlib import Path

import torch
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import TableTransformerForObjectDetection

model_type = 'fin'  # can be 'fin' or 'all'

PATH = Path(__file__).parent.resolve()
cropped_table = Image.open(PATH / "data/AAL_2014_page_192_table_0.jpg").convert("RGB")

#############
# Setup and utility functions
#############

device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize(object):
    """Resize so the longer side equals max_size, preserving the aspect ratio."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image


# Longer side resized to 1000 px, then ImageNet mean/std normalization.
structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Per-class score thresholds (defined here but not applied below).
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10
}


def box_cxcywh_to_xyxy(x):
    # Convert (center x, center y, width, height) boxes to (x1, y1, x2, y2).
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # Predicted boxes are normalized to [0, 1]; scale them to the original image size.
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


def outputs_to_objects(outputs, img_size, class_idx2name):
    # Take the highest-scoring class for each query and drop "no object" predictions.
    m = outputs['logits'].softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = class_idx2name[int(label)]
        if not class_label == 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})
    return objects


def draw_bboxes(bboxes, page_image, color='red'):
    # Draw box outlines on a copy so the original image is left untouched.
    page_image = page_image.copy()
    draw = ImageDraw.Draw(page_image)
    for bbox in bboxes:
        draw.rectangle(bbox, outline=color)
    return page_image


structure_models = {
    'all': TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all"),
    'fin': TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-fin")
}
for k in structure_models:
    structure_models[k] = structure_models[k].to(device)

# The 'all' model's id2label is used for both models; the extra last logit index is "no object".
structure_id2label = structure_models['all'].config.id2label
structure_id2label[len(structure_id2label)] = "no object"

#############
# Running the model
#############

pixel_values = structure_transform(cropped_table).unsqueeze(0)
pixel_values = pixel_values.to(device)
with torch.no_grad():
    structure_outputs = structure_models[model_type](pixel_values)
structure_outputs = outputs_to_objects(structure_outputs, cropped_table.size, structure_id2label)

# When run as a script (not a notebook), the returned image must be shown or saved explicitly.
annotated = draw_bboxes([x['bbox'] for x in structure_outputs], cropped_table)
annotated.show()
```
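The script defines `structure_class_thresholds` but never applies them, so every query whose top class is not "no object" gets drawn, however low its score. A minimal sketch of applying per-class thresholds before drawing is below; `filter_by_class_threshold` is a hypothetical helper, and whether thresholding alone removes the thin rows from the Fin model's output is untested here.

```python
from collections import Counter


def filter_by_class_threshold(objects, class_thresholds):
    """Keep detections whose score meets the threshold for their class.

    Hypothetical helper. Objects use the outputs_to_objects format (dicts with
    'label', 'score', 'bbox'). Labels missing from class_thresholds are dropped,
    which is an assumption of this sketch.
    """
    return [
        obj
        for obj in objects
        if obj["label"] in class_thresholds
        and obj["score"] >= class_thresholds[obj["label"]]
    ]


thresholds = {"table row": 0.5, "table column": 0.5}
detections = [
    {"label": "table row", "score": 0.98, "bbox": [0, 0, 400, 20]},
    {"label": "table row", "score": 0.31, "bbox": [0, 20, 400, 23]},
    {"label": "table column", "score": 0.99, "bbox": [0, 0, 200, 64]},
]
kept = filter_by_class_threshold(detections, thresholds)
# Tally surviving labels, e.g. to compare row counts between the 'fin' and 'all' models.
print(Counter(obj["label"] for obj in kept))
```

In the repro this would slot in as `structure_outputs = filter_by_class_threshold(structure_outputs, structure_class_thresholds)` just before drawing. The `"no object": 10` entry never matches, since `outputs_to_objects` already discards those predictions.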
Original table:
![image](https://github.com/microsoft/table-transformer/assets/168236/81a74da3-f51a-4df8-be46-b071b98be5d2)

Fin model:
![image](https://github.com/microsoft/table-transformer/assets/168236/bdcdc751-be8e-4662-a6a1-cbde4daf982a)

All model:
![image](https://github.com/microsoft/table-transformer/assets/168236/607dd112-819d-4f3c-8690-a3d09fbce018)