poloclub / unitable

UniTable: Towards a Unified Table Foundation Model
https://arxiv.org/abs/2403.04822
MIT License

Limited Number of Cells #2

Closed: Omar280x closed this issue 3 months ago

Omar280x commented 3 months ago

Hello, thanks for your great work. I've tried your notebook on small tables and it works great, but on large tables (e.g., 10 columns and 30 rows) the model only detects 256 cells. Is there a specific parameter I can modify so that the model can extract more cells, or is the length capped at 256?

Filimoa commented 3 months ago

I'm not 100% sure about this, but the PositionEmbedding class creates an embedding layer with a fixed size (max_seq_len). To handle longer sequences, you would need a larger positional embedding matrix, which would require retraining the model.
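To illustrate why a fixed max_seq_len caps the output, here is a minimal pure-Python sketch of an absolute position table. `FixedPositionEmbedding` is a hypothetical stand-in, not UniTable's actual PositionEmbedding class: it only shows that positions are looked up by index, so there is nothing to read past row max_seq_len - 1.

```python
import random


class FixedPositionEmbedding:
    """Hypothetical absolute position table with exactly max_seq_len rows.

    Each position index selects one vector; positions >= max_seq_len
    have no row, so longer sequences cannot be embedded.
    """

    def __init__(self, max_seq_len, dim, seed=0):
        rng = random.Random(seed)
        self.max_seq_len = max_seq_len
        # One vector per position: random here, learned during training in a real model.
        self.table = [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(max_seq_len)]

    def __call__(self, seq_len):
        # Return the position vectors for positions 0 .. seq_len - 1.
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}; "
                "a larger table (and retraining) would be needed"
            )
        return self.table[:seq_len]


emb = FixedPositionEmbedding(max_seq_len=1024, dim=8)
print(len(emb(1024)))  # 1024 position vectors are available
```

Calling `emb(1025)` raises, since the table has no row for position 1024. Enlarging the table would add new vectors that were never trained, which is why supporting longer sequences means retraining.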

Omar280x commented 3 months ago

OK, thanks for your answer.

ShengYun-Peng commented 3 months ago

Thanks @Filimoa! @Omar280x: max_seq_len is 1024 for bbox detection, which allows 256 bboxes in total since each bbox takes four coordinate tokens. In our experiments, we found that sufficient for most of our tables, and the model was still trainable within our GPU memory (80 GB). An easy hack is to slice a large table horizontally, query UniTable with the resulting smaller tables, and concatenate all HTML sequences in order.
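A minimal pure-Python sketch of the capacity arithmetic and the slicing hack follows. It assumes boxes in [xmin, ymin, xmax, ymax] pixel order and PIL-style (left, upper, right, lower) crop boxes; `max_cells`, `strips_needed`, `strip_boxes`, and `merge_strip_results` are hypothetical helpers, not part of UniTable. Only the bbox side is shown; the per-strip HTML sequences would be concatenated in the same top-to-bottom order. By the same arithmetic, a 1000-cell table would need 4000 coordinate tokens plus special tokens.

```python
MAX_SEQ_LEN = 1024   # bbox decoder length quoted in this thread
TOKENS_PER_BBOX = 4  # xmin, ymin, xmax, ymax


def max_cells(max_seq_len=MAX_SEQ_LEN, tokens_per_bbox=TOKENS_PER_BBOX):
    """How many boxes fit in one decoded sequence."""
    return max_seq_len // tokens_per_bbox


def strips_needed(num_cells):
    """Ceiling division: fewest strips so each holds <= max_cells() cells,
    assuming cells are spread evenly over the table height."""
    return -(-num_cells // max_cells())


def strip_boxes(width, height, num_strips):
    """Split an image into equal-height horizontal strips,
    returned as (left, upper, right, lower) crop boxes."""
    edges = [round(k * height / num_strips) for k in range(num_strips + 1)]
    return [(0, edges[k], width, edges[k + 1]) for k in range(num_strips)]


def merge_strip_results(boxes, per_strip_bboxes):
    """Shift each strip's boxes down by the strip's upper edge and
    concatenate them in top-to-bottom order."""
    merged = []
    for (_, upper, _, _), bboxes in zip(boxes, per_strip_bboxes):
        merged.extend([x0, y0 + upper, x1, y1 + upper] for x0, y0, x1, y1 in bboxes)
    return merged


print(max_cells())          # 1024 // 4
print(strips_needed(1000))  # strips for a 1000-cell table
```

Equal-height strips are the simplest choice but can cut through a row; cutting at row boundaries instead avoids splitting cells across two queries. Each strip must also stay under the 256-cell limit on its own.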

Omar280x commented 3 months ago

@ShengYun-Peng That's a great idea. I've also considered slicing, but the problem is that this will be an automated process over a huge dataset of tables. I'll try a workaround, though. Thanks for your reply and help!

lerndeep commented 1 month ago

@Omar280x @ShengYun-Peng I'm facing the same problem: the model can't detect all cells when a table has more than 256 cells. Could you please let me know how I can train and run inference when a table has more than 1000 cells?

lerndeep commented 1 month ago

@Omar280x Have you tried the slicing-and-combining technique? Could you please share the code for it?

yumikim381 commented 3 weeks ago

Same! I would be very interested too.

yumikim381 commented 3 weeks ago

I wrote a script that crops the table horizontally and adds the extra bounding boxes:

```python
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Runs inside the same class/method as the UniTable notebook pipeline:
# `images` is the full table image (PIL.Image), `pred_html` the predicted
# structure tokens, and `pred_bbox` the predicted boxes as [xmin, ymin, xmax, ymax].
# modelB, vocabB, VALID_BBOX_TOKEN and bbox_str_to_token_list come from the notebook.

# Each '[]' token in the predicted HTML is one cell.
countcells = 0
for elem in pred_html:
    if elem == '[]':
        countcells += 1
print(countcells)

if countcells > len(pred_bbox):
    # The bbox decoder stopped early, so the row holding the last detected
    # box is probably incomplete. That box's top edge (ymin, index 1) is
    # where the cut-off row starts.
    cut_off = pred_bbox[-1][1]

    width, height = images.size

    # PIL's crop() takes (left, upper, right, lower), i.e. (xmin, ymin, xmax, ymax).
    bbox = (0, cut_off, width, height)
    cropped_image = images.crop(bbox)
    cropped_image.save("./res/cropped_image_for_extra_bbox_det.png")

    # Run bbox detection again on the lower part of the table only.
    image_tensor = self.image_to_tensor(cropped_image, (448, 448))
    pred_bbox_extra = self.autoregressive_decode(
        model=modelB,
        image=image_tensor,
        prefix=[vocabB.token_to_id("[bbox]")],
        max_decode_len=1024,
        eos_id=vocabB.token_to_id("<eos>"),
        token_whitelist=[vocabB.token_to_id(i) for i in VALID_BBOX_TOKEN[:449]],
        token_blacklist=None,
    )

    # Convert token ids to token text, then to a list of boxes.
    pred_bbox_extra = pred_bbox_extra.detach().cpu().numpy()[0]
    pred_bbox_extra = vocabB.decode(pred_bbox_extra, skip_special_tokens=False)
    pred_bbox_extra = bbox_str_to_token_list(pred_bbox_extra)

    # Keep only the missing cells (the last ones in reading order).
    # This assumes the cropped part itself has at most 256 cells.
    num_cells_to_add = countcells - len(pred_bbox)
    pred_bbox_extra = pred_bbox_extra[-num_cells_to_add:]

    # Map from the 448x448 model input back to crop pixels, then shift
    # down by cut_off into full-image coordinates.
    pred_bbox_extra = self.rescale_bbox(pred_bbox_extra, src=(448, 448), tgt=cropped_image.size)
    pred_bbox_extra = [[i[0], i[1] + cut_off, i[2], i[3] + cut_off] for i in pred_bbox_extra]

    pred_bbox = pred_bbox + pred_bbox_extra

    # Draw all boxes (xmin, ymin, xmax, ymax) on top of the full image.
    fig, ax = plt.subplots(figsize=(12, 10))
    ax.imshow(images)
    for box in pred_bbox:
        rect = patches.Rectangle(box[:2], box[2] - box[0], box[3] - box[1],
                                 linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
    ax.set_axis_off()
    fig.savefig('table_drawn_bbox_with_extra.png', bbox_inches='tight', dpi=300)
```
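The script above does a single extra pass, so it covers at most about 512 cells. For tables like the 1000+ cell case asked about earlier, one option is to repeat the crop-and-redetect step. The sketch below is a hypothetical generalization, not code from UniTable or from this thread: `detect_strip(top)` is an assumed callback that crops the image from y=top to the bottom, runs the bbox decoder on the crop, rescales, and returns boxes already shifted to full-image coordinates in reading order. Instead of taking the last N boxes, it keeps only the cells lying fully above the cut line and appends everything re-detected below it, so it still works when a crop is itself capped at 256 boxes.

```python
from typing import Callable, List, Sequence

BBox = List[float]  # [xmin, ymin, xmax, ymax] in full-image pixels


def detect_all_cells(
    num_cells: int,
    detect_strip: Callable[[float], Sequence[BBox]],
    max_passes: int = 10,
) -> List[BBox]:
    """Repeat crop-and-redetect until num_cells boxes are collected.

    detect_strip(top) is a hypothetical callback (see lead-in).
    """
    bboxes: List[BBox] = list(detect_strip(0.0))
    prev_top = float("-inf")
    for _ in range(max_passes):
        if len(bboxes) >= num_cells or not bboxes:
            break
        # Top edge of the last detected cell: the row it sits in may be cut off.
        top = bboxes[-1][1]
        if top <= prev_top:
            break  # no progress: one row holds more cells than a pass can decode
        prev_top = top
        # Keep cells fully above the cut line, then re-detect everything below it.
        kept = [b for b in bboxes if b[3] <= top]
        bboxes = kept + list(detect_strip(top))
    return bboxes[:num_cells]
```

Cells that span the cut line (for example, rowspans) would be dropped and re-detected from a partial view, so they would need extra handling in a real pipeline.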