Our pretraining dataset is divided into two primary components:
We release the checkpoints for our pretrained models, as well as the checkpoints finetuned on the different downstream tasks:

| Task | Checkpoint Path |
|---|---|
| Pretrained | unichart-base-960 |
| ChartQA | unichart-chartqa-960 |
| Chart2Text-Statista | unichart-chart2text-statista-960 |
| Chart2Text-Pew | unichart-chart2text-pew-960 |
| OpenCQA | unichart-opencqa-960 |
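The table gives checkpoint names only. The inference and finetuning examples in this README load them from the Hugging Face Hub under the `ahmed-masry/` namespace (for example `ahmed-masry/unichart-chartqa-960`). The sketch below turns a task name into a full Hub ID. It assumes that every checkpoint in the table lives under that same namespace. The helpers `checkpoint_id` and `load_checkpoint` are hypothetical conveniences and are not part of the released code.

```python
# Hypothetical helper: maps a task from the checkpoint table to a Hub ID.
# ASSUMPTION: all checkpoints share the "ahmed-masry/" namespace seen in the
# inference and finetuning examples of this README.
HUB_NAMESPACE = "ahmed-masry"

CHECKPOINTS = {
    "pretrained": "unichart-base-960",
    "chartqa": "unichart-chartqa-960",
    "chart2text-statista": "unichart-chart2text-statista-960",
    "chart2text-pew": "unichart-chart2text-pew-960",
    "opencqa": "unichart-opencqa-960",
}


def checkpoint_id(task: str) -> str:
    """Return the assumed Hub ID for a task name (case-insensitive)."""
    key = task.strip().lower()
    if key not in CHECKPOINTS:
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(CHECKPOINTS)}")
    return f"{HUB_NAMESPACE}/{CHECKPOINTS[key]}"


def load_checkpoint(task: str):
    """Download the model and processor for a task (needs network access)."""
    # Imported lazily so the lookup above works without transformers installed.
    from transformers import DonutProcessor, VisionEncoderDecoderModel

    name = checkpoint_id(task)
    return VisionEncoderDecoderModel.from_pretrained(name), DonutProcessor.from_pretrained(name)
```

For example, `checkpoint_id("ChartQA")` yields `ahmed-masry/unichart-chartqa-960`, which is the `model_name` used in the inference example.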
If you wish to quickly try our models, you can use our public web demos, which are hosted on Hugging Face Spaces and offer a friendly interface!

| Task | Web Demo |
|---|---|
| Base Model (Best for Chart Summarization and Data Table Generation) | UniChart-Base |
| Chart Question Answering | UniChart-ChartQA |
The input prompt for Chart summarization is
```
transformers==4.28.1
pytorch-lightning==1.8.5
datasets
sentencepiece
```
Please make sure to use exactly the version of the Transformers library listed above (4.28.1). We have observed a drop in performance when using other versions of the library!
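Because results depend on the pinned Transformers version, you may want to check your environment before running anything. The sketch below compares the installed versions of the pinned packages against the requirements list, using the standard-library `importlib.metadata`. The function name `find_mismatches` is hypothetical and is not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional, Tuple

# Exact pins copied from the requirements list; unpinned packages are skipped.
PINNED = {"transformers": "4.28.1", "pytorch-lightning": "1.8.5"}


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each mismatched package to (required, installed or None)."""
    mismatches = {}
    for package, required in pins.items():
        try:
            installed: Optional[str] = get_version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != required:
            mismatches[package] = (required, installed)
    return mismatches


if __name__ == "__main__":
    for package, (required, installed) in find_mismatches(PINNED).items():
        print(f"{package}: expected {required}, found {installed or 'not installed'}")
```

The `get_version` parameter exists only so the check can be exercised without touching the real environment. By default it reads the installed distributions.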
You can easily run inference with our models using the Hugging Face Transformers library! Choose a checkpoint, point the script to your chart image, and write the input prompt for your task as listed below:

| Task | Input Prompt |
|---|---|
| Chart Question Answering | \<chartqa> question |
| Open Chart Question Answering | \<opencqa> question |
| Chart Summarization | |
| Data Table Extraction | |
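For the two question-answering tasks, the prompt is the task token followed by the question. The inference example in this README also appends `<s_answer>`, after which the model writes its answer. The sketch below builds such prompts. It assumes that OpenCQA prompts end with `<s_answer>` in the same way as the ChartQA example does. It deliberately leaves out summarization and data table extraction, because their prompt tokens are not given here. `build_prompt` is a hypothetical helper.

```python
# Hypothetical prompt builder for the question-based tasks only.
# ASSUMPTION: the "<opencqa>" prompt ends with "<s_answer>" like the ChartQA
# example in the inference code.
QUESTION_TASKS = {"chartqa", "opencqa"}
ANSWER_TOKEN = "<s_answer>"


def build_prompt(task: str, question: str) -> str:
    """Return '<task> question <s_answer>' with whitespace normalized."""
    if task not in QUESTION_TASKS:
        raise ValueError(f"build_prompt only supports {sorted(QUESTION_TASKS)}, got {task!r}")
    question = " ".join(question.split())  # collapse stray spaces and newlines
    return f"<{task}> {question} {ANSWER_TOKEN}"
```

`build_prompt("chartqa", "What is the lowest value in blue bar?")` reproduces the `input_prompt` string used in the inference example.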
```python
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

# Download an example chart from the ChartQA validation set.
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')

model_name = "ahmed-masry/unichart-chartqa-960"  # swap in another checkpoint to change the task
image_path = "chart_example_1.png"  # the file downloaded above; replace with your own chart
input_prompt = "<chartqa> What is the lowest value in blue bar? <s_answer>"  # see the prompt table

model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Encode the chart image, and tokenize the prompt that the decoder continues from.
image = Image.open(image_path).convert("RGB")
decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values

# Beam-search decoding; the unknown token is banned from the output.
outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=4,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

# Drop EOS/padding tokens and keep only the text generated after <s_answer>.
sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = sequence.split("<s_answer>")[1].strip()
print(sequence)
```
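The last three lines of the example extract the answer from the decoded text. The sketch below factors that step into a reusable function, so the same logic can be applied to every sequence in a batch. It also raises a clear error when `<s_answer>` is missing, whereas the original `[1]` index would fail with an `IndexError`. The defaults `"</s>"` and `"<pad>"` are assumptions about the tokenizer's special tokens. In practice, pass `processor.tokenizer.eos_token` and `processor.tokenizer.pad_token`, as the example does. `extract_answer` is a hypothetical name.

```python
def extract_answer(
    decoded: str,
    eos_token: str = "</s>",   # ASSUMPTION: use processor.tokenizer.eos_token in practice
    pad_token: str = "<pad>",  # ASSUMPTION: use processor.tokenizer.pad_token in practice
    answer_token: str = "<s_answer>",
) -> str:
    """Return the text the model generated after the answer token."""
    cleaned = decoded.replace(eos_token, "").replace(pad_token, "")
    _, found, answer = cleaned.partition(answer_token)  # split at the first answer token
    if not found:
        raise ValueError(f"decoded sequence has no {answer_token!r} token: {decoded!r}")
    return answer.strip()
```

With beam search over a batch, you could apply it as `[extract_answer(s, tok.eos_token, tok.pad_token) for s in processor.batch_decode(outputs.sequences)]`, where `tok = processor.tokenizer`.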
In order to finetune the model on the ChartQA dataset, you can edit and run the following command:
```bash
python finetune_chartqa.py --data-path "ahmed-masry/chartqa_without_images" --train-images '/content/ChartQA/ChartQA Dataset/train/png/' \
    --valid-images '/content/ChartQA/ChartQA Dataset/val/png' --max-steps 40000 --batch-size 8 --valid-batch-size 1 --num-workers 12 --lr 5e-5 \
    --check-val-every-n-epoch 1 --warmup-steps 100 --checkpoint-steps 7000 --checkpoint-path "ahmed-masry/unichart-base-960"
```
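If you launch finetuning from Python, for example from a notebook or a sweep script, it can help to keep the flags above as data and override only what changes. The sketch below rebuilds the same command as an argument list. Passing a list avoids shell-quoting problems with image paths that contain spaces, such as `ChartQA Dataset`. `build_finetune_command` is a hypothetical helper. It only assembles the flags shown above and does not know which other options `finetune_chartqa.py` accepts.

```python
import shlex
from typing import Dict, List, Optional

# Flag values copied from the command above; image folders are passed in.
DEFAULT_FLAGS: Dict[str, str] = {
    "--data-path": "ahmed-masry/chartqa_without_images",
    "--max-steps": "40000",
    "--batch-size": "8",
    "--valid-batch-size": "1",
    "--num-workers": "12",
    "--lr": "5e-5",
    "--check-val-every-n-epoch": "1",
    "--warmup-steps": "100",
    "--checkpoint-steps": "7000",
    "--checkpoint-path": "ahmed-masry/unichart-base-960",
}


def build_finetune_command(
    train_images: str, valid_images: str, overrides: Optional[Dict[str, str]] = None
) -> List[str]:
    """Return the finetuning command as an argv list; overrides replace defaults."""
    flags = {"--train-images": train_images, "--valid-images": valid_images}
    flags.update(DEFAULT_FLAGS)
    flags.update(overrides or {})
    cmd = ["python", "finetune_chartqa.py"]
    for flag, value in flags.items():
        cmd.extend([flag, value])
    return cmd


if __name__ == "__main__":
    cmd = build_finetune_command(
        "/content/ChartQA/ChartQA Dataset/train/png/",
        "/content/ChartQA/ChartQA Dataset/val/png",
        overrides={"--batch-size": "4"},
    )
    print(shlex.join(cmd))  # shell-safe rendering, handy for logging
```

To actually start training, pass the list to `subprocess.run(cmd, check=True)`. Keeping values as strings (e.g. `"5e-5"`) means they reach the script exactly as they appear in the original command.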
If you have any questions about this work, please contact Ahmed Masry using the following email addresses: amasry17@ku.edu.tr or ahmed.elmasry24653@gmail.com.
Please cite our paper if you use our models or dataset in your research.
```bibtex
@misc{masry2023unichart,
      title={UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning},
      author={Ahmed Masry and Parsa Kavehzadeh and Xuan Long Do and Enamul Hoque and Shafiq Joty},
      year={2023},
      eprint={2305.14761},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```