Open molbap opened 11 months ago
In PR #24 , collate_fn is
not defined in task_cruller_pretrain.py
task_cruller_pretrain.py
in task_finetune_rvlcdip.py, defined as
task_finetune_rvlcdip.py
def collate_fn(self, batch):
    """Return the RVL-CDIP collator built from this task's components.

    NOTE(review): ``batch`` is not used here — the method returns a
    collator *instance* rather than a collated batch; presumably the
    caller invokes the returned object. Confirm against the dataloader
    wiring.
    """
    return CollateRVLCDIP(
        self.tokenizer,
        self.image_preprocess_train,
        self.task_start_token,
        self.label_int2str,
    )
where
class CollateRVLCDIP:
    """Basic collator for PIL images, as returned by the RVL-CDIP
    dataloader (among others).

    Each sample dict must carry an ``"image"`` (PIL image) and a
    ``"label"`` (int class index). The collator renders the label as a
    pseudo-markup token sequence, tokenizes it to a fixed length of 5,
    and produces teacher-forcing inputs/targets shifted by one position.
    """

    def __init__(
            self,
            tokenizer,
            image_preprocess,
            start_token: str,
            label_int2str: dict,
    ):
        self.tokenizer = tokenizer
        # Fixed-length (5) tokenization of a single string -> 1-D id tensor.
        self.tokenizer_fn = lambda x: self.tokenizer(
            x,
            add_special_tokens=False,
            return_tensors='pt',
            max_length=5,
            padding='max_length',
            truncation=True).input_ids[0]
        self.image_preprocess = image_preprocess
        self.start_token = start_token
        self.int2str = label_int2str

    def __call__(self, batch):
        token_rows = []
        for sample in batch:
            # e.g. "<task><letter/></s>" — start token + class name markup + EOS.
            text = f"{self.start_token}<{self.int2str[sample['label']]}/>{self.tokenizer.eos_token}"
            token_rows.append(self.tokenizer_fn(text))
        pixel_batch = torch.stack([self.image_preprocess(sample["image"]) for sample in batch])
        label_ids = torch.stack(token_rows)
        target_ids = torch.stack([text_input_to_target(row, self.tokenizer) for row in label_ids])
        # Shift for teacher forcing: inputs drop the last token, targets the first.
        return {
            "image": pixel_batch,
            "label": label_ids[:, :-1],
            "text_target": target_ids[:, 1:],
        }
In task_finetune_cord.py, defined directly as
task_finetune_cord.py
def collate_fn(self, batch):
    """Basic collator for PIL images, as returned by the rvlcdip
    dataloader (among others).

    Parses each sample's ``"ground_truth"`` JSON-like string, converts
    the ``"gt_parse"`` tree to a token sequence, tokenizes it to a fixed
    length of 512, and returns teacher-forcing shifted inputs/targets
    plus the preprocessed image tensor batch.
    """
    # FIXME move this batcher/tokenizer elsewhere
    def _tokenize(text):
        encoded = self.tokenizer(
            text,
            add_special_tokens=False,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )
        return encoded.input_ids[0]

    parse_trees = [literal_eval(sample["ground_truth"])["gt_parse"] for sample in batch]
    token_rows = []
    for tree in parse_trees:
        seq, _ = json2token(tree, self.tokenizer.all_special_tokens, sort_json_key=False)
        # NOTE: bos_token deliberately not prepended (matches prior behavior).
        token_rows.append(_tokenize(self.task_start_token + seq + self.tokenizer.eos_token))
    text_inputs = torch.stack(token_rows)
    targets = torch.stack([self.text_input_to_target(row) for row in text_inputs])
    images = torch.stack([self.image_preprocess_train(sample["image"]) for sample in batch])
    # Shift for teacher forcing: inputs drop the last token, targets the first.
    return {
        "image": images,
        "label": text_inputs[:, :-1],
        "text_target": targets[:, 1:],
    }
And the same applies for task_finetune_docvqa and the 3 associated eval tasks.
task_finetune_docvqa
Suggestion: make a template collator base class with task-specific subclasses inheriting from it, and have each task's collate_fn simply instantiate the matching subclass.
Like
class CollateTask:
    """Template base class for per-task batch collators.

    Holds the pieces every task collator needs (tokenizer, image
    preprocessing transform, task start token) and provides a shared
    fixed-length tokenization helper. Subclasses must override
    ``__call__`` to collate a batch for their specific task.
    """

    def __init__(self, tokenizer, image_preprocess, start_token):
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.start_token = start_token

    def tokenizer_fn(self, x, max_length=512):
        """Tokenize string ``x`` to a padded/truncated id tensor of
        ``max_length`` tokens (first row of the returned batch)."""
        encoded = self.tokenizer(
            x,
            add_special_tokens=False,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        return encoded.input_ids[0]

    def __call__(self, batch):
        # Abstract: each task subclass supplies its own collation logic.
        raise NotImplementedError("This method should be overridden by child classes")
And for RVLCDIP
class CollateRVLCDIP(CollateTask): def __init__(self, tokenizer, image_preprocess, start_token, label_int2str): super().__init__(tokenizer, image_preprocess, start_token) self.label_int2str = label_int2str def __call__(self, batch): ....
And so on for CORD and DocVQA and others so that the collate_fn simply returns the instantiation of the collator class
In PR #24 , collate_fn is
not defined in
task_cruller_pretrain.py
in
task_finetune_rvlcdip.py
, defined as shown above, where
In
task_finetune_cord.py
, defined directly inline, as above. And the same applies for
task_finetune_docvqa
and the 3 associated eval tasks. Suggestion: make a template collator base class with task-specific subclasses inheriting from it, and instantiate collate_fn from these
Like
And for RVLCDIP
And so on for CORD and DocVQA and others so that the collate_fn simply returns the instantiation of the collator class