huggingface / pixparse

Pixel Parsing. A reproduction of OCR-free end-to-end document understanding models with open data

unified abstractions for collate functions #26

Open: molbap opened this issue 11 months ago


In PR #24, collate_fn is

The same applies to task_finetune_docvqa and the three associated eval tasks.

Suggestion: make a template Collator base class with one subclass per task, and instantiate collate_fn from these subclasses.

For example:

```python
class CollateTask:
    """Base collator: holds state shared by all tasks; subclasses implement __call__."""

    def __init__(self, tokenizer, image_preprocess, start_token):
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.start_token = start_token

    def tokenizer_fn(self, x, max_length=512):
        # Tokenize one string into a fixed-length (padded/truncated) tensor of token ids.
        return self.tokenizer(
            x,
            add_special_tokens=False,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        ).input_ids[0]

    def __call__(self, batch):
        raise NotImplementedError("This method should be overridden by child classes")
```
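If the goal is to make a missing override fail early, one option (not part of the original suggestion) is to declare `__call__` abstract with Python's `abc` module. A subclass that forgets to implement it then raises `TypeError` when it is constructed, instead of `NotImplementedError` on the first batch. `Incomplete` and `Complete` are hypothetical classes for illustration, and `tokenizer_fn` is left out for brevity.

```python
from abc import ABC, abstractmethod


class CollateTask(ABC):
    """Base collator holding state shared by every task."""

    def __init__(self, tokenizer, image_preprocess, start_token):
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.start_token = start_token

    @abstractmethod
    def __call__(self, batch):
        """Subclasses turn a list of samples into a model-ready batch."""


class Incomplete(CollateTask):
    """Hypothetical subclass that forgets __call__, so it cannot be instantiated."""


class Complete(CollateTask):
    def __call__(self, batch):
        return len(batch)  # placeholder batching logic


try:
    Incomplete(tokenizer=None, image_preprocess=None, start_token="<s>")
except TypeError as err:
    print("rejected at construction:", err)
```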

And for RVL-CDIP:

```python
class CollateRVLCDIP(CollateTask):
    def __init__(self, tokenizer, image_preprocess, start_token, label_int2str):
        super().__init__(tokenizer, image_preprocess, start_token)
        self.label_int2str = label_int2str  # maps integer class ids to label names

    def __call__(self, batch):
        ...  # task-specific batching logic (elided in the original)
```
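One possible shape for `CollateRVLCDIP.__call__`, as a sketch only. It assumes each sample is a dict with hypothetical `"image"` and `"label"` keys and that the target text is `start_token` followed by the label name. So that it runs with the standard library alone, it swaps the Hugging Face tokenizer for a hypothetical `ToyTokenizer` that imitates `padding="max_length"` and `truncation=True` on whitespace-split words. The key names and output dict layout are illustrative, not pixparse's actual format.

```python
class ToyTokenizer:
    """Hypothetical stand-in for a Hugging Face tokenizer: maps words to integer ids."""

    def __init__(self):
        self.vocab = {"<pad>": 0}

    def encode(self, text, max_length):
        ids = [self.vocab.setdefault(w, len(self.vocab)) for w in text.split()]
        ids = ids[:max_length]  # mimic truncation=True
        return ids + [self.vocab["<pad>"]] * (max_length - len(ids))  # mimic padding="max_length"


class CollateTask:
    def __init__(self, tokenizer, image_preprocess, start_token):
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.start_token = start_token

    def tokenizer_fn(self, x, max_length=8):
        # Toy equivalent of the fixed-length input_ids produced by the real tokenizer_fn.
        return self.tokenizer.encode(x, max_length)

    def __call__(self, batch):
        raise NotImplementedError("This method should be overridden by child classes")


class CollateRVLCDIP(CollateTask):
    def __init__(self, tokenizer, image_preprocess, start_token, label_int2str):
        super().__init__(tokenizer, image_preprocess, start_token)
        self.label_int2str = label_int2str

    def __call__(self, batch):
        # Target sequence per sample: the task start token followed by the class name.
        texts = [f"{self.start_token} {self.label_int2str[s['label']]}" for s in batch]
        return {
            "image": [self.image_preprocess(s["image"]) for s in batch],
            "label": [self.tokenizer_fn(t) for t in texts],
        }


collate = CollateRVLCDIP(
    ToyTokenizer(),
    image_preprocess=lambda img: img,  # identity stand-in for real image transforms
    start_token="<s_rvlcdip>",
    label_int2str={0: "letter", 1: "invoice"},
)
out = collate([{"image": "img0", "label": 1}, {"image": "img1", "label": 0}])
print(out["label"])
```

Only `__call__` differs between tasks; the tokenizer, image transform, and start token all live in the shared base class.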

And so on for CORD, DocVQA, and the other tasks, so that collate_fn simply returns an instance of the corresponding collator class.
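The final step, where `collate_fn` just returns an instance of the right collator, could look like the following sketch. The `COLLATORS` registry, the `build_collate_fn` helper, and the task keys are hypothetical names, and the `__call__` bodies are placeholders rather than real batching logic.

```python
class CollateTask:
    """Shared state for all task collators (tokenizer_fn omitted for brevity)."""

    def __init__(self, tokenizer, image_preprocess, start_token):
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.start_token = start_token

    def __call__(self, batch):
        raise NotImplementedError("This method should be overridden by child classes")


class CollateRVLCDIP(CollateTask):
    def __init__(self, tokenizer, image_preprocess, start_token, label_int2str):
        super().__init__(tokenizer, image_preprocess, start_token)
        self.label_int2str = label_int2str

    def __call__(self, batch):
        # Placeholder body; the real batching logic would go here.
        return {"task": "rvlcdip", "num_samples": len(batch)}


class CollateCORD(CollateTask):
    def __call__(self, batch):
        return {"task": "cord", "num_samples": len(batch)}


class CollateDocVQA(CollateTask):
    def __call__(self, batch):
        return {"task": "docvqa", "num_samples": len(batch)}


# Hypothetical registry mapping a task name to its collator class.
COLLATORS = {
    "rvlcdip": CollateRVLCDIP,
    "cord": CollateCORD,
    "docvqa": CollateDocVQA,
}


def build_collate_fn(task_name, tokenizer, image_preprocess, start_token, **task_kwargs):
    """Return a collator instance; task-specific arguments pass through **task_kwargs."""
    if task_name not in COLLATORS:
        raise ValueError(f"unknown task {task_name!r}; expected one of {sorted(COLLATORS)}")
    return COLLATORS[task_name](tokenizer, image_preprocess, start_token, **task_kwargs)


collate_fn = build_collate_fn(
    "rvlcdip",
    tokenizer=None,
    image_preprocess=None,
    start_token="<s_rvlcdip>",
    label_int2str={0: "letter"},
)
print(collate_fn([{}, {}]))
```

Because each collator instance is an ordinary callable, it can be passed anywhere a plain function was used before, for example as the `collate_fn` argument of `torch.utils.data.DataLoader`. Unlike a closure, it also keeps its configuration inspectable as attributes.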