clovaai / donut

Official Implementation of OCR-free Document Understanding Transformer (Donut) and Synthetic Document Generator (SynthDoG), ECCV 2022
https://arxiv.org/abs/2111.15664
MIT License

Working very well on training but doing pretty poorly at inference #65

Open WaterKnight1998 opened 1 year ago

WaterKnight1998 commented 1 year ago

Hi,

Thank you for publishing this model :)

I want to use this model for document parsing. I have annotations for two kinds of PDF, with 20 images per type.

During training it achieves very good edit distances (0.08, 0.0, 0.1, 0.13...), but at inference the edit distance is very bad (0.8).

In addition, several predictions come out as {"text_sequence": "word1 word1 word1 ... word1"}, where word1 is a random word repeated over and over.

Thanks in advance

satheeshkatipomu commented 1 year ago

What prompt are you using at inference? Most probably it is messing up the JSON output. If you have fine-tuned the base model for your task, try using <s_REPLACE_YOUR_DATASET_NAME> as the prompt at inference time.
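
For reference, the task start token in the Donut fine-tuning examples follows a simple naming convention: the dataset name wrapped in <s_...>. A minimal sketch of that convention (the dataset names below are placeholders):

```python
def task_start_token(dataset_name: str) -> str:
    # Donut's fine-tuning examples register "<s_{dataset_name}>" as an added
    # special token and feed it to the decoder as the generation prompt.
    return f"<s_{dataset_name}>"

# The pre-trained CORD model, for example, uses:
print(task_start_token("cord-v2"))  # <s_cord-v2>
```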

WaterKnight1998 commented 1 year ago

Hi @satheeshkatipomu , thank you for answering.

I was following this training tutorial.

If you check his Lightning module:

# Depends on: `import re`, `import torch`, and `edit_distance` (e.g. from
# nltk.metrics); `max_length` is defined elsewhere in the notebook.
def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)

        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)

        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

He uses

    decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)

as the decoder_input_ids, instead of

    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

If at inference time I use decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device), the performance improves.
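
The two forms are only equivalent when the tokenizer maps the task prompt string to exactly the single id stored in model.config.decoder_start_token_id. A sanity check worth running before inference, sketched here with placeholder ids (with transformers you would compare processor.tokenizer(task_prompt, add_special_tokens=False).input_ids against [[model.config.decoder_start_token_id]]):

```python
def prompts_match(prompt_input_ids, decoder_start_token_id):
    """True if the tokenized task prompt is exactly the single decoder start id,
    i.e. the tokenizer route and the torch.full(...) route build the same prompt."""
    return prompt_input_ids == [[decoder_start_token_id]]

# Hypothetical ids standing in for whatever the fine-tuned tokenizer assigned:
assert prompts_match([[57579]], 57579)        # consistent: same prompt both ways
assert not prompts_match([[3, 1200]], 57579)  # prompt string split into other pieces
```

If the check fails, the prompt string was never registered as a single added token, and generation starts from tokens the model never saw during training.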

satheeshkatipomu commented 1 year ago

Did you fine-tune the model on your custom dataset, or are you trying to use the off-the-shelf cord-v2 model?

WaterKnight1998 commented 1 year ago

Yes, I fine-tuned the model on my custom dataset, but I used that code, @satheeshkatipomu

satheeshkatipomu commented 1 year ago

I see some changes compared to the code in this repo, but I still suggest checking your prompt at inference time, given that the edit distance is very low during training but very high at inference, and assuming your test images are similar to the training data.

For example, in the fine-tuning notebook they hard-coded the start and end token as <s_cord-v2>.

Please check whether it is the same during training and inference. If yes, sorry, I am unable to help solve your problem.

outday29 commented 1 year ago

I had the same issue as you: the model performed very well in validation during training, but very poorly during inference (it just output gibberish).

Check the added_tokens.json file in your trained model path. Your task prompt should be among the added tokens. For some reason, my task prompt was just <s_> instead of <s_{DATASET_NAME}>.
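
A small helper for running that check before inference; the directory and prompt passed in are placeholders for your own checkpoint:

```python
import json
from pathlib import Path

def has_task_prompt(model_dir: str, task_prompt: str) -> bool:
    """Return True if task_prompt was saved as an added token in the
    checkpoint's added_tokens.json (a token-string -> id mapping)."""
    added = json.loads(Path(model_dir, "added_tokens.json").read_text())
    return task_prompt in added

# e.g. has_task_prompt("path/to/finetuned-donut", "<s_my_dataset>")
```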

Mohtadrao commented 7 months ago

I am having the same issue: the model gives very good accuracy and predictions during validation, but gibberish when predicting on the same image at inference. My added_tokens.json is fine. Kindly help. @WaterKnight1998 @satheeshkatipomu @outday29

CarlosSerrano88 commented 2 months ago

@Mohtadrao @WaterKnight1998 have you found a solution? I have the same error. Thanks

Mohtadrao commented 1 month ago

@CarlosSerrano88 Yeah. My accuracy increased when I raised the number of epochs along with the amount of training data. You also need to adjust the config file accordingly.

Ruxin124 commented 1 month ago

During training, in the early epochs the predictions look like this (see attached screenshot). Is this normal? Do I just need to wait for the 200 epochs to finish, or is something wrong?

Ruxin124 commented 1 month ago

> @Mohtadrao @WaterKnight1998 have you found a solution? I have the same error. Thanks

Same problem here, predictions are all (s) (s) (s) ... (see attached screenshot). Has anyone solved this problem?

Mohtadrao commented 1 month ago

@Ruxin124 This is a normal error. You just need to keep training. I don't know how much data you have, but mine was approximately 7-8k images, and after 45 epochs it was already giving me 90% accuracy. You also need to write the config file accurately according to your data.