herobd / dessurt

Official implementation for Dessurt
MIT License

Inference vs training phase discrepancy #13

Open jlerouge opened 10 months ago

jlerouge commented 10 months ago

Hi @herobd,

I'm trying to fine-tune a Dessurt model on my own VQA task (predicting a few fields on proof-of-address documents, such as the person's name, address, city, zip code, ...).

I've set "print_pred_every": 100 to monitor how the model behaves during training. While not perfect, the model seems to give answers close to the ground truth during training, e.g.

iter 498800
0 [Q]:natural_q~Quelle est la ville du consommateur ?   [A]:MONTIGNE LE BRILLANT    [P]:STTRENY LE BRILLANT
Train iteration: 498800,           mean abs grad: 0.000,
      loss: 0.466,            answerLoss: 7.456,
      score_ed: 0.308,            natural_q~_ED: 0.308,
      sec_per_iter: 0.593,            avg_mean abs grad: 0.000,
      avg_loss: 0.257,            avg_answerLoss: 4.117,
      avg_score_ed: 0.347,            avg_natural_q~_ED: 0.347,

(...)

iter 499400
0 [Q]:natural_q~Quelle est la ville du consommateur ?   [A]:AIGREFEUILLE    [P]:BEIGREFEUILLE
Train iteration: 499400,           mean abs grad: 0.000,
      loss: 0.205,            answerLoss: 3.288,
      score_ed: 0.160,            natural_q~_ED: 0.160,
      sec_per_iter: 0.589,            avg_mean abs grad: 0.000,
      avg_loss: 0.226,            avg_answerLoss: 3.621,
      avg_score_ed: 0.310,            avg_natural_q~_ED: 0.310,

However, when I use the latest weights in prediction mode with the run.py script, I get very different results. For example, on the training sample whose correct answer to the city question is "MONTIGNE LE BRILLANT", here is the result:

>>> main("/home/qsuser/src/dessurt/saved/dessurt_qs_dom_qa_fra_finetune/checkpoint-iteration500000.pth", "/home/qsuser/src/dessurt/configs/cf_dessurt_qs_dom_qa_finetune.json", "/home/qsuser/Work/ProofOfAddress/Data/JDD_2023_05_09/images_dessurt_questions_fra/train/875fe8f4c8e3ad9e8559956a3fcbf058/image.jpg", [], True, default_task_token="natural_q~", dont_output_mask=False)
loaded dessurt_qs_dom_qa_fra_finetune iteration 500000
Using default task token: natural_q~
 (if another token is entered with the query, the default is overridden)
Query: Quelle est la ville du consommateur ?
Answer: ST PIERRE DES CORPS

The model clearly hallucinates an answer, and the output mask appears to be completely random (the actual answer is located in the upper right corner, inside the address block). [image]

I have no clue why it seems to produce better answers during training. Do you have any idea?

I'm also sharing my configuration file for reference: cf_dessurt_qs_dom_qa_finetune.json

I would be thankful for any help on training dessurt :)

herobd commented 10 months ago

Your config shows the number of iterations set to 73k, but the samples you posted above are at almost 500k iterations. Are you sure this is the right config getting used?

jlerouge commented 10 months ago

Oh, you're right, I didn't send the right config file, but I guess the difference doesn't really affect my issue. Here is the correct configuration file: cf_dessurt_qs_dom_qa_fra_finetune.json

I have 735 training samples, 91 validation samples (and 91 test samples...). I plan on having more labeled samples but this is not ready yet.

I wonder whether this training run is anywhere near convergence. The loss still seems high during the last iterations... Might I have more success with even more iterations? Or with a different learning rate strategy?

My main question is: why does the model seem to produce different results during training and at test time?

I believe the ground-truth label is somehow used as an input to the model during training to help it produce its text answer, but I can't prove it since I still don't fully understand what you've done. It may be related to the RUN parameter and what you call "teacher-forcing". This is probably where the difference I see comes from: https://github.com/herobd/dessurt/blob/main/model/dessurt.py#L289

So maybe my issue is that I lack training samples, and the loss is still way too high.
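To make the suspicion above concrete, here is a toy sketch of how teacher forcing can make training-time predictions look closer to the ground truth than free-running inference. It is not Dessurt's code: the lookup table `NEXT_CHAR`, the markers `BOS`/`EOS`, and both function names are made up for illustration, and the "model" is just a previous-character lookup standing in for a real autoregressive decoder.

```python
# Toy next-character "decoder": a lookup keyed on the previous character.
# Purely hypothetical; it only stands in for a real autoregressive model.
BOS, EOS = "^", "$"
NEXT_CHAR = {
    BOS: "L",  # the model's first guess is wrong for the target "PARIS"
    "P": "A", "A": "R", "R": "I", "I": "S", "S": EOS,
    "L": "Y", "Y": "O", "O": "N", "N": EOS,
}


def teacher_forced(target: str) -> str:
    """Predict each character from the ground-truth previous character."""
    inputs = BOS + target[:-1]
    return "".join(NEXT_CHAR[c] for c in inputs)


def free_running(max_len: int = 20) -> str:
    """Predict each character from the model's own previous output."""
    out, prev = [], BOS
    for _ in range(max_len):
        nxt = NEXT_CHAR[prev]
        if nxt == EOS:
            break
        out.append(nxt)
        prev = nxt
    return "".join(out)


print(teacher_forced("PARIS"))  # LARIS
print(free_running())           # LYON
```

With teacher forcing, the wrong first character does not propagate, because the next step is fed the correct character anyway, so the printed prediction stays close to the target. In free-running mode the same first mistake steers every later step, and the output becomes a different but plausible-looking string.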

Anyway, thanks for the hard work :+1:

herobd commented 10 months ago

You can keep training as long as the validation ED doesn't get worse. 0.310 seems a bit high. You can use graph.py to see what it's doing. Are you working with French documents? Dessurt was pre-trained only on English, so you may hit a roadblock there, particularly because the tokenization is English.
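For a sense of scale on the ED values, here is a minimal sketch of a normalized edit distance. It assumes ED is the Levenshtein distance divided by the mean length of the target and the prediction. That normalization is inferred from the numbers in the training log earlier in this thread, not taken from the Dessurt source, and the function names are illustrative.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute, or free if equal
            ))
        prev = curr
    return prev[-1]


def normalized_ed(target: str, pred: str) -> float:
    # Assumption: normalize by the mean length of the two strings.
    mean_len = (len(target) + len(pred)) / 2
    return levenshtein(target, pred) / mean_len if mean_len else 0.0


print(round(normalized_ed("MONTIGNE LE BRILLANT", "STTRENY LE BRILLANT"), 3))  # 0.308
print(round(normalized_ed("AIGREFEUILLE", "BEIGREFEUILLE"), 3))                # 0.16
```

Under this assumption, both printed predictions from the log reproduce their logged natural_q~_ED values, and an ED around 0.3 corresponds to roughly one wrong character in three.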

I'm not sure why using run.py would drop performance. Do you have any issues resuming training from a checkpoint?