microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI
MIT License

Unclear input data structure of layoutreader #464

Open 7fantasysz opened 3 years ago

7fantasysz commented 3 years ago

Model I am using: LayoutReader. How should I prepare my input data to use the script decode_seq2seq.py with the pretrained model? It's not clearly stated anywhere; please help.

HYPJUDY commented 3 years ago

Please follow the instructions. Specifically:

1. Download the data as described in step 1 (wget https://layoutlm.blob.core.windows.net/readingbank/dataset/ReadingBank.zip) and extract it (unzip ReadingBank.zip); you should get the test data under ReadingBank/test. Pass that directory as the --input_folder of the decoding script in step 3.
2. Download the pretrained model as described in step 2 (wget https://layoutlm.blob.core.windows.net/readingbank/model/layoutreader-base-readingbank.zip), extract it (unzip layoutreader-base-readingbank.zip), and pass the resulting directory as the --model_path of the decoding script in step 3.

SimeonZhang commented 3 years ago

@HYPJUDY It is reported that LayoutReader was evaluated under several different settings, such as text only, layout only, or different shuffle rates during training. Could you release your model zoo with these different settings?

zlwang-cs commented 3 years ago

@SimeonZhang Hi, it is easy to run LayoutReader in the text-only setting by using the corresponding model_name and model_name_or_path in the args. You can also run the layout-only setting with the layoutlm_only_layout flag in the args. As for different shuffle rates, you can try sentence_shuffle_rate in the args.

ManuelFay commented 3 years ago

To help you out with custom data, here is a minimal working example (not optimized at all; you'll have to switch things around a bit).

Note 1: the boxes correspond to the normalized bounding boxes one would use for LayoutLM, LayoutLMv2, etc.
Note 2: for pages with more than 511 words, a more complex approach is needed (a sliding window would be the safest bet?). Curious how the authors dealt with this?

import logging
from typing import Tuple, List
import os
import numpy as np
import torch
from transformers import AutoTokenizer

import s2s_ft.s2s_loader as seq2seq_loader
from s2s_ft.modeling_decoding import LayoutlmForSeq2SeqDecoder, BertConfig
from s2s_ft.s2s_loader import Preprocess4Seq2seqDecoder

class TextOrderer:
    def __init__(self, model_path: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # pylint: disable=no-member
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
        config_file = os.path.join(model_path, "config.json")

        self.config = BertConfig.from_json_file(config_file, layoutlm_only_layout_flag=True)
        self.model = LayoutlmForSeq2SeqDecoder.from_pretrained(model_path, config=self.config).to(self.device)
        self.max_len = 511
        self.preprocessor = Preprocess4Seq2seqDecoder(
            list(self.tokenizer.vocab.keys()),
            self.tokenizer.convert_tokens_to_ids,
            1024,
            max_tgt_length=self.max_len,
            layout_flag=True
        )

    def __call__(self, *args, **kwargs):
        return self.reconstruct(*args, **kwargs)

    def forward(self, words, boxes) -> List[int]:
        """
        :param words: Word list [sorted in top-down / left-right fashion for best performance)
        :param boxes: Normalized bounding box list (layoutlm format)
        :return: Re-ordered index list
        """
        assert len(words) == len(boxes)

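        # Preprocess4Seq2seqDecoder expects ([[token, x0, y0, x1, y1], ...], seq_len)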
        instance = [[x[0], *x[1]] for x in list(zip(words, boxes))], len(boxes)
        instances = [self.preprocessor(instance)]
        with torch.no_grad():
            batch = seq2seq_loader.batch_list_to_batch_tensors(
                instances)
            batch = [
                t.to(self.device) if t is not None else None for t in batch]
            input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch

            traces = self.model(input_ids, token_type_ids,
                                position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
            output_ids = traces.squeeze().tolist()
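            # shift the model's 1-based output positions back to 0-based word indices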
            output_ids = list(np.array(output_ids) - 1)
            return output_ids

    def reconstruct(self, words: List[str], boxes: List[List[int]]) -> Tuple[List[str], List[List[int]]]:

        assert len(words) == len(boxes)

        if len(words) > self.max_len:
            logging.warning(
                f"Page contains {len(words)} words. Exceeds the {self.max_len} limit and will not be reordered.")
            return words, boxes

        try:
            idx = self.forward(words, boxes)
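            # drop duplicate predicted indices while preserving first-seen order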
            processed_idx = list(dict.fromkeys(idx))
            if len(processed_idx) != len(words):
                processed_idx = [idx for idx in processed_idx if idx < len(words)]
                unused_idx = sorted(list(set(range(len(words))) - set(processed_idx[:len(words)])))
                logging.info(
                    f"There are {len(words)} words but only {len(processed_idx)} indices. "
                    f"Unmatched indices: {unused_idx}")
                processed_idx.extend(unused_idx)
                logging.info(f"There are now {len(words)} words and {len(processed_idx)} indices.")
                assert len(processed_idx) == len(words)

            words = list(np.array(words)[processed_idx])
            boxes = [elem.tolist() for elem in np.array(boxes)[processed_idx]]
            return words, boxes

        except Exception as exception: # pylint: disable=broad-except
            logging.warning(exception)
            return words, boxes

zlwang-cs commented 3 years ago

@ManuelFay Thank you for the great work! As for your question, we ran LayoutReader on the ReadingBank dataset and filtered out the pages with more than 511 tokens. For other datasets, I agree that a sliding window is a feasible solution.
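
For reference, a minimal sketch of that chunking idea on top of the TextOrderer class above (this assumes the input is already roughly pre-sorted top-down; the chunk size is an arbitrary choice, and a true sliding window with overlap handling would be more robust at chunk boundaries):

def reorder_long_page(orderer: TextOrderer, words, boxes, chunk_size=400):
    # Split the page into consecutive chunks, reorder each chunk
    # independently with the model, and concatenate the results.
    out_words, out_boxes = [], []
    for start in range(0, len(words), chunk_size):
        w, b = orderer.reconstruct(words[start:start + chunk_size],
                                   boxes[start:start + chunk_size])
        out_words.extend(w)
        out_boxes.extend(b)
    return out_words, out_boxes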

myh12138 commented 2 years ago

> @SimeonZhang Hi, it is easy to run LayoutReader in the text-only setting by using the corresponding model_name and model_name_or_path in the args. You can also run the layout-only setting with the layoutlm_only_layout flag in the args. As for different shuffle rates, you can try sentence_shuffle_rate in the args.

Hi, when I change --input_folder to my data and set layoutlm_only_layout, the process still shows the original data. Can you give me some suggestions? Thanks!

HoomanKhosravi commented 2 years ago

@HYPJUDY @ManuelFay Thank you for sharing your code. I was wondering: what is your approach to sorting top-down / left-to-right? Do you use the bottom and left edges of the bboxes? And do you know whether the ReadingBank data is sorted in this manner?

ManuelFay commented 2 years ago

The LayoutReader paper mentions that it is best to do so. It's then just a matter of sorting the list with a double key in Python; take any reference point from the boxes, as long as it's consistent. Note that some libraries (pdfplumber with textflow, pytesseract) include a text orderer that is usually better than the top-down/left-right heuristic and might be easier to use! Cheers!
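
Concretely, the double-key sort might look like this, assuming LayoutLM-style [x0, y0, x1, y1] boxes (the top-left corner is just one consistent choice of reference point):

# sort top-down first, then left-to-right, using each box's top-left corner
order = sorted(range(len(words)), key=lambda i: (boxes[i][1], boxes[i][0]))
words = [words[i] for i in order]
boxes = [boxes[i] for i in order]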

HoomanKhosravi commented 2 years ago

> The LayoutReader paper mentions that it is best to do so. It's then just a matter of sorting the list with a double key in Python; take any reference point from the boxes, as long as it's consistent. Note that some libraries (pdfplumber with textflow, pytesseract) include a text orderer that is usually better than the top-down/left-right heuristic and might be easier to use! Cheers!

Thanks!

animebing commented 2 years ago

> The LayoutReader paper mentions that it is best to do so. It's then just a matter of sorting the list with a double key in Python; take any reference point from the boxes, as long as it's consistent. Note that some libraries (pdfplumber with textflow, pytesseract) include a text orderer that is usually better than the top-down/left-right heuristic and might be easier to use! Cheers!

@ManuelFay I use bboxes = sorted(bboxes, key=lambda x: [x[0], x[1]]) to get a left-right-top-down order and adjust the texts accordingly, but I find the result is much worse than the order from Tesseract, which confuses me a lot. Can you give me some suggestions about it? Thanks.

tengerye commented 2 years ago

Hi, maybe there is something wrong in the function load_and_cache_line_order_examples as well?

It tries to load a single file only, whereas it should load several files from a folder, as the function load_and_cache_layoutlm_examples does.

zlwang-cs commented 2 years ago

Actually, load_and_cache_line_order_examples is deprecated. You can implement a similar function yourself if you need to conduct such experiments.

alejandrojcastaneira commented 2 years ago

Hi @ManuelFay, thanks for sharing the code. I tried it and it works really well! One question: do you have any suggestions on how to optimize the inference time? I notice it takes about 7 seconds for a single doc.

ManuelFay commented 2 years ago

> Hi @ManuelFay, thanks for sharing the code. I tried it and it works really well! One question: do you have any suggestions on how to optimize the inference time? I notice it takes about 7 seconds for a single doc.

It is really slow... The authors recommend a few things in the repo (CUDA libs to go faster and run ops in a quantized way). Apart from that, the simplest option is probably to use heuristics to split the page into "safe" subpages and reorder them one at a time, since runtime scales quadratically with sequence length in transformers. It does lose most of its purpose that way, though... Maybe the authors are working on a more efficient model or have tips.

logan-markewich commented 1 year ago

Thank you for the code @ManuelFay! Unfortunately, the results are not great with Tesseract output as input. At least I could test it, though :)

Jesteinbe commented 1 year ago

This may be a naive question, but are the layout-only results reported in the paper obtained by running inference with the model trained on layout+text but without any token embeddings, or is it a different model trained on layout-only features? I thought it was the latter, but the above posts make me think it might be the former.

Also, based on what I'm seeing, should we only be using --layoutlm_only_layout but not --layout_only_dataset? If the latter is used, it forces the use of the deprecated load_and_cache_line_order_examples. However, if we don't specify --layout_only_dataset, then it loads all the text-based data. Is this the right thing to do?

zlwang-cs commented 1 year ago

Hi @Jesteinbe, sorry for the confusion. Let me specify the meaning of each argument.

In our paper, we have two sets of experiments: 1) the token-level reading order (Tables 2, 3, 4); 2) the line-level reading order (Tables 5, 6).

In the first setting (token level), we make two changes to the standard setting: the shuffle rate, and w/ or w/o layout. I assume the shuffle rate is easy to understand from the code. As for "w/ or w/o layout", it is realized by --layoutlm_only_layout (standing for: the LayoutLM model is used, but only layout features are considered). This argument removes the word embeddings from the model, so the model can only learn from the layout information (see Line 203 in s2s_ft/modeling.py).
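
In other words, the flag boils down to something like this paraphrase (a sketch of the idea, not the repo's actual code):

import torch

def combine_embeddings(words_emb: torch.Tensor, layout_emb: torch.Tensor,
                       only_layout: bool) -> torch.Tensor:
    # With --layoutlm_only_layout, the token embeddings are dropped so the
    # encoder sees only the 2-D position (layout) embeddings.
    return layout_emb if only_layout else layout_emb + words_emb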

In the second setting (line level), the text lines are extracted by OCR engines. Since each bounding box now corresponds to a line instead of a token, it is not straightforward to model this in the standard setting, where bounding boxes and words are matched one-to-one. Our solution is to use placeholders as the text input for each text line (see Line 370 in s2s_ft/utils.py). This is triggered by --layout_only_dataset, standing for a dataset that only has layout features.
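
A line-level instance would then look roughly like the instance in ManuelFay's snippet above, with one placeholder "word" per OCR line paired with that line's box (the placeholder string here is an assumption, not necessarily the repo's exact choice):

# one placeholder token per OCR line, paired with that line's bounding box
line_boxes = [[10, 12, 480, 30], [10, 34, 320, 52]]  # normalized line boxes
instance = [["[UNK]", *box] for box in line_boxes], len(line_boxes)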

I think it is not hard to tell the difference between load_and_cache_line_order_examples and load_and_cache_layoutlm_examples from the code, and these two functions can also be used in other cases, depending on the exact usage scenario.

If you have any other questions, please let us know :)

Jesteinbe commented 1 year ago

Thanks @zlwang-cs! I'll play around with things a bit more and let you know if I have any other questions.