This is the code for the paper: End-to-end Document Recognition and Understanding with Dessurt (https://arxiv.org/abs/2203.16618).
(Accepted to TiE@ECCV 2022)
"We introduce Dessurt, a relatively simple document understanding transformer capable of being fine-tuned on a greater variety of document tasks than prior methods. It receives a document image and task string as input and generates arbitrary text autoregressively as output. Because Dessurt is an end-to-end architecture that performs text recognition in addition to the document understanding, it does not require an external recognition model as prior methods do, making it easier to fine-tune to new visual domains. We show that this model is effective at 9 different dataset-task combinations."
I find it helpful to use pip, not conda, to install the dependencies.
My own module synthetic_text_gen (https://github.com/herobd/synthetic_text_gen) also needs to be installed, as it is used for text generation.
train.py is the script that executes training based on a configuration file. The training code itself is found in trainer/.
The usage is: `python train.py -c CONFIG.json [-r CHECKPOINT.pth]` (see configs/ for example configuration files and below for an explanation).
A training session can be resumed with the `-r` flag (using the "checkpoint-latest.pth"). This is also the flag for starting from a pre-trained model.
If you want to override the config file on a resume, just use the `-c` flag in addition to `-r` and be sure the config has `"override": true` (all mine do).
The configs/ directory has configs for doing the pre-training and fine-tuning of Dessurt.
When fine-tuning, I first reset the pre-trained checkpoint with `python change_checkpoint_reset_for_training.py -c pretrained/checkpoint.pth -o output/directory` (or `-o output/checkpoint.pth`). This resets the iteration count and optimizer and automatically names the output "checkpoint-latest.pth" so you can start training from it with the `-r` flag.
If you resume training from a snapshot with differently shaped weight tensors (or extra or missing weight tensors), the base_trainer will cut and paste weights to make things work (with random initialization for new weights). This is particularly useful for defining new tokens (no problem) or resizing the input image (if it's smaller you may not even need to fine-tune).
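A rough sketch of the idea (the real logic lives in base/base_trainer.py; the helper name and details below are illustrative, not the repo's actual code):

```python
def merge_state_dict(model, old_state):
    """Illustrative sketch: copy matching weights, paste the overlapping slice of
    mismatched ones, and leave any newly added weights at their random init."""
    new_state = model.state_dict()
    for name, old_tensor in old_state.items():
        if name not in new_state:
            continue                          # weight no longer exists in the new model
        target = new_state[name]
        if target.shape == old_tensor.shape:
            new_state[name] = old_tensor      # exact match: copy directly
        elif target.dim() == old_tensor.dim():
            # shapes differ (e.g. new tokens or a resized image): paste the overlap
            region = tuple(slice(0, min(a, b)) for a, b in zip(target.shape, old_tensor.shape))
            target[region] = old_tensor[region]
    model.load_state_dict(new_state)
```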
You can override the GPU specified in the config file using the `-g` flag (including `-g -1` to use the CPU).
run.py allows interactive running of Dessurt.
Usage: `python run.py -c CHECKPOINT.pth [-i image/path (default: ask for path)] [-g gpu#] [-a add=or,change=things=in,config=v] [-S (get saliency map)]`
It will prompt you for the image path (if not provided) and for the queries. You'll have to draw in highlights and masks when prompted. Here are some helpful task tokens (the start of the query); tokens always end with "~" or ">":
- `re~text to start from`
- `w0>`
- `read_block>` (you can also use `read_block0>` if you want to provide the highlight)
- `mlm>`
- `mk>`
- `natural_q~question text`
- `json>`
- `json~{"JSON": "to start from"}`
- `linkdown-both~text of element` or `linkdown-box~` or `linkdown-text~text of element`
qa_eval.py is for evaluating Dessurt on all datasets other than FUNSD and NAF.
Usage: `python qa_eval.py -c CHECKPOINT.pth -d DATASETNAME [-g GPU#] [-T (do test set)]`
The `-d` flag allows running a model on a dataset it was not fine-tuned for.
To evaluate on the FUNSD/NAF datasets we use three scripts. funsd_eval_json.py / naf_eval_json.py generate the output JSON and produce the Entity Fm and Relationship Fm. get_GAnTED_for_Dessurt.py produces the GAnTED score. The *_eval_json.py scripts handle correcting Dessurt's output into valid JSON and aligning the output to the GT for entity detection and linking.
Usage: `python funsd_eval_json.py -c the/checkpoint.pth [-T (do test set)] [-g GPU#] [-w the/output/predictions.json] [-E entityMatchThresh (default 0.6)] [-L linkMatchThresh (default 0.6)] [-b doBeamSearchWithThisNumberBeams]` (the same usage applies to naf_eval_json.py).
Usage: `python get_GAnTED_for_Dessurt.py -p the/predictions.json -d FUNSD/NAF (dataset name) [-T (do test set)] [-P doParallelThisManyThreads] [-2 (run twice)] [-s (shuffle order before)]`
graph.py will display graphs of statistics logged during training.
Usage: `python graph.py -c the/checkpoint.pth -o metric_name`
The `-o` flag can accept part of the metric name. Generally the key validation metrics start with "val_E", the exceptions being full-page recognition ("val_read_block>") and NER ("val_F_Measure_MACRO").
If you omit the `-o` flag it will try to draw all the metrics, but there are too many.
You can also use graph.py to export a stripped checkpoint that doesn't contain the model, using the `-e output.pth` option.
The current config files expect all datasets to be in a `data` directory that sits in the same parent directory as the project directory.
(Pre-training additionally uses text generated by gpt_forms.py and Wikipedia data loaded through the HuggingFace `datasets` library: https://huggingface.co/datasets/wikipedia.)
To train on your own data, you first need to set up the data and then a config file. You can see configs/ for a number of example fine-tuning config files.
For setting up the data you have two options. If you can define your dataset as images with a set of queries and text answers, you can use the MyDataset class. If you need something fancier, you can define your own dataset class.
Note: The Swin implementation requires the image dimensions to be multiples of 8 and that (dim / 8) % window_size == 0. (The 8 comes from the 4x downsample of the CNN and the 2x downsample from the Swin downsample; after that, the feature map needs to fit the windows evenly.)
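A quick sanity check of that constraint (using the default 1152x768 image size and window size 12 from the example config later in this README; the helper name is made up):

```python
def fits_dessurt(dim, window_size=12):
    # dim must survive the 4x CNN downsample and the 2x Swin downsample (8x total),
    # and the resulting feature map must tile evenly into Swin windows
    return dim % 8 == 0 and (dim // 8) % window_size == 0

assert fits_dessurt(1152) and fits_dessurt(768)   # the default 1152x768 input works
assert not fits_dessurt(1000)                     # 1000/8 = 125, which 12 does not divide
```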
See configs/cf_test_cats_each_finetune.json and configs/cf_test_cats_qa_finetune.json and their respective data in example_data/ for examples of how to use MyDataset.
MyDataset expects `data_dir` to be a directory with "train", "valid", and possibly "test" subdirectories. Each of these holds the images (nesting in subdirectories is allowed). Then there either needs to be a JSON for each image ('this/image.png' and 'this/image.json') or a single 'qa.json'.
'this/image.json' has the list of Q-A pairs:
```
[
    {"question": "TOK~context text",
     "answer": "text response"},
    ...
]
```
"TOK~" will be the special task token string. See the Task Tokens section. Answers can also be a list of strings, such as how DocVQA has multiple right answers.
If you use the 'qa.json' format, it has a map from each image path to that image's list of Q-A pairs:
```
{"an/imagefile.png": [ {"question": "TOK~context text",
                        "answer": "response text"},
                       {"question": "TOK2>",
                        "answer": "other response text"},
                       ...
                     ],
 ...
}
```
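For example, a small script along these lines (the image names, questions, and answers are made up) could write a 'qa.json' in the format above:

```python
import json
import os

# Hypothetical data: image path -> list of question/answer pairs.
# "natural_q~" is one of the task tokens described in the Task Tokens section.
qa = {
    "forms/page_001.png": [
        {"question": "natural_q~What is the invoice number?", "answer": "12345"},
        {"question": "natural_q~Who is the customer?", "answer": ["ACME Corp", "ACME Corporation"]},
    ],
    "forms/page_002.png": [
        {"question": "natural_q~What is the total?", "answer": "$99.00"},
    ],
}

out_dir = "../data/my_dataset/train"   # a "train" subdirectory under MyDataset's data_dir
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "qa.json"), "w") as f:
    json.dump(qa, f, indent=2)
```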
All of the datasets used in training and evaluating Dessurt are defined as their own classes, so you have many examples in data_sets/. Most are descendants of QADataset (qa.py), and that is probably going to be the easiest route for you.
A demo colab on using a custom dataset class to train Dessurt on MNIST is available here: TODO add url
The constructor of your child class will need to populate `self.images` as an array of dicts with:
- `'imagePath'`: the path to the image; can also be None if the image is returned from `self.parseAnn`
- `'imageName'`: optional, defaults to the path
- `'annotationPath'`: if this is a path to a json, the json will be read and passed to `self.parseAnn`; otherwise, whatever this is will be passed to `self.parseAnn`
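A minimal sketch of such a constructor (the class name, directory layout, and constructor signature are assumptions modeled on the existing classes in data_sets/, not code taken from the repo):

```python
import os
from data_sets.qa import QADataset

class MyFormsDataset(QADataset):   # hypothetical custom dataset
    def __init__(self, dirPath=None, split=None, config=None, images=None):
        super(MyFormsDataset, self).__init__(dirPath, split, config, images)
        self.images = []
        split_dir = os.path.join(dirPath, split)              # e.g. data_dir/train
        for fname in sorted(os.listdir(split_dir)):
            if not fname.lower().endswith(('.png', '.jpg')):
                continue
            self.images.append({
                'imagePath': os.path.join(split_dir, fname),
                'imageName': fname,                           # optional, defaults to the path
                'annotationPath': os.path.join(split_dir, os.path.splitext(fname)[0] + '.json'),
            })
```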
Your child class will also need to implement the `parseAnn` function, which takes the "annotation" as input and returns the parsed information QADataset needs: the bounding boxes/IDs and the Query-Answer pairs described below (see the existing classes in data_sets/ for the exact return signature).
The bounding boxes/IDs are to allow the QADataset to crop the image and then remove possible QA pairs that have been cropped off of the image. If you aren't cropping, you don't need to worry about it.
To prepare the Query-Answer pairs, use the `self.qaAdd` function. It can take the lists of box coordinates (either for highlighting or masking) and QADataset will handle everything for these.
Task tokens are always at the beginning of the query string and end with either "~" or ">".
They are defined in model/special_token_embedder.py. If you need to add some of your own, just add them at the end of the "tokens" list, and that's all you need to do (I guess you could also replace a "not used" token).
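For example, adding a token might look like this (a sketch; the existing entries shown are placeholders, so check the actual "tokens" list in model/special_token_embedder.py):

```python
# In model/special_token_embedder.py (sketch, not the actual file contents):
tokens = [
    'natural_q~',
    'read_block>',
    # ... the rest of the existing tokens, kept in their original order ...
    'my_new_task~',   # new tokens go at the end so the existing ones keep their positions
]
```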
Most tasks have the model add '‡' to the end of what it's reading to make it obvious it has reached the end.
If you are doing the same thing as a pre-training task, it would be helpful to reuse the same task token.
The current tokens used in pre-training are defined in model/special_token_embedder.py ("not used" tokens are defined as tasks in the code, but weren't used in the final training).
```
├── train.py - Training script
├── qa_eval.py - Evaluation script
├── run.py - Interactive run script
├── funsd_eval_json.py - Evaluation for FUNSD
├── naf_eval_json.py - Evaluation for NAF
├── get_GAnTED_for_Dessurt.py - Compute GAnTED given the JSON output of one of the above two scripts
├── graph.py - Display graphs of logged stats given a snapshot
├── check_checkpoint.py - Print iterations for a snapshot
├── change_checkpoint_reset_for_training.py - Reset iterations and optimizer for a snapshot
├── change_checkpoint_cf.py - Change the config of a snapshot
├── change_checkpoint_rewrite.py - Rearrange state_dict
├── gpt_forms.py - This script uses GPT2 to generate label-value pair groups
│
├── base/ - abstract base classes
│   ├── base_data_loader.py - abstract base class for data loaders
│   ├── base_model.py - abstract base class for models
│   └── base_trainer.py - abstract base class for trainers
│
├── configs/ - where the config jsons are
│
├── data_loader/
│   └── data_loaders.py - This provides access to all the dataset objects
│
├── data_sets/ - default datasets folder
│   ├── bentham_qa.py - For the Bentham question answering dataset
│   ├── cdip_cloud_qa.py - IIT-CDIP dataset with handling of downloading/unpacking parts of the dataset
│   ├── cdip_download_urls.csv - Where IIT-CDIP is stored (if I don't find a hosting solution these may be bad)
│   ├── census_qa.py - Used for pre-training on FamilySearch's census indexes
│   ├── distil_bart.py - Dataset which handles setting everything up for distillation
│   ├── docvqa.py - DocVQA dataset
│   ├── form_qa.py - Parent dataset for FUNSD, NAF, and synthetic forms
│   ├── funsd_qa.py - FUNSD (query-response)
│   ├── gen_daemon.py - Handles text rendering
│   ├── hw_squad.py - HW-SQuAD
│   ├── iam_Coquenet_splits.json - IAM splits used for page recognition, following "End-to-end Handwritten Paragraph Text Recognition Using a Vertical Attention Network", which we compare against for handwriting recognition
│   ├── iam_mixed.py - Mixes 3 IAM pages' word images into two lists. Used for IAM pre-training in NER experiments
│   ├── iam_ner.py - IAM NER
│   ├── iam_qa.py - IAM page recognition
│   ├── iam_standard_splits.json
│   ├── long_naf_images.txt - List of images in the NAF training set with a long JSON parse (more than 800 characters)
│   ├── multiple_dataset.py - Allows training to sample from a collection of datasets
│   ├── my_dataset.py - Allows code-less definition of a custom query-response dataset
│   ├── NAF_extract_lines.py - Extracts all text/handwriting lines from the NAF dataset (to compile a dataset for a standard line recognition model)
│   ├── naf_qa.py - NAF dataset (query-response)
│   ├── naf_read.py - Recognition only on the NAF dataset. Special resizing to be sure things are legible
│   ├── para_qa_dataset.py - Parent dataset for IAM, IIT-CDIP, and synthetic paragraphs (rendered Wikipedia)
│   ├── qa.py - Parent class for all query-response datasets (everything used by Dessurt except the distillation dataset)
│   ├── record_qa.py - Parent class for the Census dataset
│   ├── rvl_cdip_class.py - Classification on the RVL-CDIP dataset (query-response)
│   ├── squad.py - Font-rendered SQuAD
│   ├── sroie.py - SROIE key information retrieval dataset
│   ├── synth_form_dataset.py - Synthetic forms dataset. Renders and arranges forms
│   ├── synth_hw_qa.py - Synthetic handwriting dataset. Loads pre-synthesized handwriting lines
│   ├── synth_para_qa.py - Synthetic Wikipedia dataset. Renders articles with fonts
│   ├── wiki_text.py - For loading Wikipedia data (singleton, as multiple datasets use it)
│   ├── wordsEn.txt - List of English words used by para_qa_dataset.py
│   │
│   ├── graph_pair.py - base class for FUDGE pairing
│   ├── forms_graph_pair.py - pairing for the NAF dataset
│   ├── funsd_graph_pair.py - pairing for the FUNSD dataset
│   │
│   └── test_*.py - scripts to test the datasets and display the images for visual inspection
│
├── logger/ - for training process logging
│   └── logger.py
│
├── model/ - models, losses, and metrics
│   ├── attention.py - Defines attention functions
│   ├── dessurt.py - The Dessurt model
│   ├── loss.py - All losses defined here
│   ├── pos_encode.py - Position encoding functions
│   ├── special_token_embedder.py - Defines task tokens
│   ├── swin_transformer.py - Code of the Swin Transformer, modified for Dessurt
│   └── unlikelihood_loss.py - Not used as it didn't improve results
│
├── saved/ - default checkpoints folder
│
├── trainer/ - trainers
│   └── qa_trainer.py - Actual training code. Handles loops and computation of metrics
│
└── utils/
    ├── augmentation.py - brightness augmentation
    ├── crop_transform.py - Cropping and rotation augmentation. Also tracks movement and clipping of bounding boxes
    ├── filelock.py - Used by CDIPCloud
    ├── forms_annotations.py - Helper functions for parsing and preparing the NAF dataset
    ├── funsd_annotations.py - Helper functions for parsing and preparing the FUNSD dataset
    ├── GAnTED.py - Defines the GAnTED metric
    ├── grid_distortion.py - Curtis's warp grid augmentation from "Data augmentation for recognition of handwritten words and lines using a CNN-LSTM network"
    ├── img_f.py - Image helper functions. Wraps scikit-image behind an OpenCV interface
    ├── parseIAM.py - Helper functions for parsing IAM XMLs
    ├── read_order.py - Estimates the read order of text lines
    ├── saliency_qa.py - Produces a saliency map for the input image and tokens
    └── util.py - misc functions
```
Config files are in .json format.
Note that in train.py I force the naming convention to be "cf_NAME.json", where NAME is the "name" in the json. This was to catch various naming errors I often made.
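The check is roughly equivalent to this sketch (not the actual train.py code; the helper name is made up):

```python
import json
import os

def check_config_name(config_path):
    # enforce the "cf_NAME.json" convention, where NAME is the config's "name" field
    with open(config_path) as f:
        name = json.load(f)['name']
    expected = 'cf_' + name + '.json'
    assert os.path.basename(config_path) == expected, \
        'config file should be named ' + expected
```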
Example config file:
```
{
"name": "pairing", # Checkpoints will be saved in saved/name/checkpoint-...pth
"cuda": true, # Whether to use GPU
"gpu": 0, # GPU number. (use -g to override with train.py)
"save_mode": "state_dict", # Whether to save/load just state_dict, or whole object in checkpoint (recommended to use state_dict)
"override": true, # Override a checkpoints config (generally what you want to happen)
"super_computer":false, # Whether to mute training info printed, also changes behavoir or CDIPCloudDataset
"data_loader": {
"data_set_name": "DocVQA", # Class of dataset (many datasets will have their own special parameters, the ones here are general for anything inheriting from QADatset)
"data_dir": "../data/DocVQA", # Directory of dataset
"batch_size": 1,
"shuffle": true,
"num_workers": 4,
"rescale_to_crop_size_first": true, # Resize image to fit in crop_size before applying random rescale with rescale_range. Generally what you want unless you change the model size to fit the data
"rescale_range": [0.9,1.1], # images are randomly resized in this range (scale augmentation)
"crop_params": {
"crop_size":[1152,768], # Crop size (needs to match model image size)
"pad":0,
"rot_degree_std_dev": 1 # Rotation augmentation
}
},
"validation": { # Enherits/copies all values from data_loader, specified values are changed
"shuffle": false,
"batch_size": 3 # Generally can use larger batch size in validation
"rescale_range": [1,1], # No scale augmentation
"crop_params": {
"crop_size":[1152,768], # Crop size (needs to match model image size)
"pad":0,
"random": false # Ensure non-stochastic placement
}
},
"optimizer_type": "AdamW",
"optimizer": { # Any parameters of the optimizer object go here
"lr": 0.0001,
"weight_decay": 0.01
},
"loss": { # Losses are in model/loss.py
"answer": "label_smoothing", # Loss on text output
"mask": "focalLoss" # Loss on pixel mask output
},
"loss_weights": {
"answer": 1,
"mask": 1
},
"loss_params": { # Parameters used my loss functions
"answer": {
"smoothing": 0.1,
"padding_idx": 1
}
},
"metrics": [],
"trainer": {
"class": "QATrainer",
"iterations": 421056, # Number of iterations, not weight update steps
"accum_grad_steps": 64, # How many iterations to accumulate the gradient before weight update
"save_dir": "saved/", # saves in save_dir/name
"val_step": 10000, # Validate every X iterations. Set arbitrary high to turn off validation
"save_step": 2000000, # Every X iterations save "checkpoint-iterationY.pth"
"save_step_minor": 1024, # Every X iterations save "checkpoint-latest.pth"
"log_step": 1024, # Averages metrics over this many iterations
"print_pred_every": 1024, # Prints the Queries, GT answers, and predicted answers
"verbosity": 1,
"monitor": "val_E_ANLS", # Save "model_best.pth" whenever this metric improves
"monitor_mode": "max", # Whether bigger or smaller is better for metric
"retry_count": 0,
"use_learning_schedule": "multi_rise then ramp_to_lower", # use "multi_rise" if LR drop is not needed (ramps the LR from 0 over warmup_steps iterations)
"warmup_steps": [
1000
],
"lr_down_start": 350000, # when LR drop happens for ramp_to_lower
"ramp_down_steps": 10000, # How many iterations to lower LR over
"lr_mul": 0.1 # How much LR is dropped at the end of ramp_down_steps
},
"arch": "Dessurt",
"model": {
"image_size": [
1152,768 # Input image size
],
"window_size": 12, # Swin window size. Swin implementation requires (image size / 8)%window_size==0 (8 is from 4x downsample from CNN and 2x downsample from Swin downsample)
"decode_dim": 768, # Text tokens hidden size
"dim_ff": 3072, # reverse bottleneck on text tokens
"decode_num_heads": 8, # num heads on text tokens
"blocks_per_level": [ # how many layers before and after Swin downsample
4,
6
],
"use_swin": [ # Whether visual tokens are updated at each layer
true,
true,
true,
true,
true,
true,
true,
true,
false,
false
],
"swin_cross_attention": [ # Whether visual tokens cross attend to query
true,
true,
true,
true,
true,
true,
true,
true,
false,
false
],
"swin_nheads": [ # Number of heads for Swin attention before and after Swin downsample
4,
8
],
"im_embed_dim": 128 # Initial image token size (doubled at Swin downsample)
}
}
```
The checkpoints will be saved in `save_dir/name`. The config file is saved in the same folder (as a reference only; the config is actually loaded from the checkpoint).
Note: checkpoints contain:
```python
{
    'arch': arch,
    'epoch': epoch,
    'logger': self.train_logger,
    'state_dict': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
    'monitor_best': self.monitor_best,
    'config': self.config,
    # and optionally
    'swa_state_dict': self.swa_model.state_dict()  # I didn't find SWA helped Dessurt
}
```
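So a saved checkpoint can be inspected directly, for example (the path here is hypothetical):

```python
import torch

ckpt = torch.load('saved/pairing/checkpoint-latest.pth', map_location='cpu')
print(ckpt['arch'], ckpt['epoch'])               # architecture and iteration/epoch counter
print(sorted(ckpt['config']['model'].keys()))    # the "model" section of the stored config
print(len(ckpt['state_dict']))                   # number of weight tensors
```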