vis-nlp / UniChart

MIT License

UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning

UniChart Pretraining Dataset

Our pretraining dataset is divided into two primary components:

  1. A zip file containing all the images. You can access the images through this Hugging Face dataset: Images
  2. A Hugging Face dataset containing the input/output pairs used for model pretraining. You can find the dataset here: Hugging Face Dataset

UniChart Model Checkpoints

We release the checkpoints for our pretrained models as well as the finetuned checkpoints on the different downstream tasks:

| Task | Checkpoint Path |
| --- | --- |
| Pretrained | unichart-base-960 |
| ChartQA | unichart-chartqa-960 |
| Chart2Text-Statista | unichart-chart2text-statista-960 |
| Chart2Text-Pew | unichart-chart2text-pew-960 |
| OpenCQA | unichart-opencqa-960 |
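
For scripted use, the checkpoint list above can be mirrored as a small lookup. This is a minimal sketch, not part of the repository: `CHECKPOINTS` and `checkpoint_for` are hypothetical names, and the `ahmed-masry/` namespace is taken from the inference and finetuning examples in this README (which use it for the base and ChartQA checkpoints) and is only assumed to apply to the other three.

```python
# Task -> checkpoint name, copied from the checkpoint list above.
CHECKPOINTS = {
    "pretrained": "unichart-base-960",
    "chartqa": "unichart-chartqa-960",
    "chart2text-statista": "unichart-chart2text-statista-960",
    "chart2text-pew": "unichart-chart2text-pew-960",
    "opencqa": "unichart-opencqa-960",
}


def checkpoint_for(task: str, namespace: str = "ahmed-masry") -> str:
    """Return '<namespace>/<checkpoint>' for a task key (case-insensitive).

    The default namespace is an assumption for the Chart2Text and OpenCQA
    checkpoints; pass a different one if they live elsewhere.
    """
    key = task.strip().lower()
    if key not in CHECKPOINTS:
        raise KeyError(f"Unknown task {task!r}; choose from {sorted(CHECKPOINTS)}")
    return f"{namespace}/{CHECKPOINTS[key]}"


print(checkpoint_for("ChartQA"))
```

The returned string can be passed as `model_name` in the inference example.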

Web Demo

If you wish to quickly try our models, you can access our public web demos hosted on the Hugging Face Spaces platform with a friendly interface!

| Tasks | Web Demo |
| --- | --- |
| Base Model (Best for Chart Summarization and Data Table Generation) | UniChart-Base |
| Chart Question Answering | UniChart-ChartQA |

The input prompt for Chart Summarization is \<summarize_chart> and for Data Table Generation is \<extract_data_table>.

Requirements

transformers==4.28.1
pytorch-lightning==1.8.5
datasets
sentencepiece

Please make sure to use the exact same version of the Transformers library. We have noticed that there might be a drop in performance when using different versions of the library!
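
Since the pinned versions matter, it may help to check the environment before running anything. The following is a minimal sketch using only the standard library; `PINNED` and `find_mismatches` are hypothetical names, and only the two packages with explicit pins in the list above are checked.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the requirements list above.
PINNED = {"transformers": "4.28.1", "pytorch-lightning": "1.8.5"}


def find_mismatches(pinned, get_version=version):
    """Return {package: (installed_or_None, required)} for every unmet pin."""
    mismatches = {}
    for package, required in pinned.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != required:
            mismatches[package] = (installed, required)
    return mismatches


if __name__ == "__main__":
    for package, (installed, required) in find_mismatches(PINNED).items():
        print(f"{package}: installed={installed}, required={required}")
```

An empty result means both pins are satisfied; anything printed should be fixed with `pip install <package>==<required>` before comparing results.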

Inference

You can easily use our models for inference with the Hugging Face Transformers library! You just need to do the following:

  1. Change model_name to your preferred checkpoint.
  2. Change image_path to the path of your chart image on your system.
  3. Write the input_prompt for your preferred task as shown in the table below.
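
| Task | Input Prompt |
| --- | --- |
| Chart Question Answering | \<chartqa> question \<s_answer> |
| Open Chart Question Answering | \<opencqa> question \<s_answer> |
| Chart Summarization | \<summarize_chart> \<s_answer> |
| Data Table Extraction | \<extract_data_table> \<s_answer> |
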
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')

model_name = "ahmed-masry/unichart-chartqa-960"
image_path = "chart_example_1.png"  # the sample chart downloaded above
input_prompt = "<chartqa> What is the lowest value in blue bar? <s_answer>"

model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

image = Image.open(image_path).convert("RGB")
decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=4,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)
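# Decode the generated token ids back to text.
sequence = processor.batch_decode(outputs.sequences)[0]
# Drop the end-of-sequence and padding tokens.
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
# Keep only the model's answer, which follows the <s_answer> token.
sequence = sequence.split("<s_answer>")[1].strip()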
print(sequence)
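
The prompt construction and output post-processing in the snippet above can be wrapped into two small helpers. This is a hedged sketch rather than code from the repository: `TASK_TOKENS`, `build_prompt`, and `extract_answer` are hypothetical names, the `<summarize_chart>` and `<extract_data_table>` tokens are assumed, and the default `</s>` and `<pad>` strings are assumed values for the tokenizer's special tokens (pass `processor.tokenizer.eos_token` and `processor.tokenizer.pad_token` in practice).

```python
ANSWER_TOKEN = "<s_answer>"

# <chartqa> and <opencqa> come from the prompt table; the last two tokens
# are assumptions and should be checked against the model's tokenizer.
TASK_TOKENS = {
    "chartqa": "<chartqa>",
    "opencqa": "<opencqa>",
    "summarize": "<summarize_chart>",
    "extract_table": "<extract_data_table>",
}


def build_prompt(task: str, question: str = "") -> str:
    """Build '<task_token> [question] <s_answer>' for the given task."""
    parts = [TASK_TOKENS[task], question.strip(), ANSWER_TOKEN]
    return " ".join(part for part in parts if part)


def extract_answer(decoded: str, eos_token: str = "</s>", pad_token: str = "<pad>") -> str:
    """Strip special tokens and return the text after <s_answer>.

    Unlike split(...)[1], this falls back to the whole cleaned string
    when <s_answer> is missing instead of raising IndexError.
    """
    text = decoded.replace(eos_token, "").replace(pad_token, "")
    _, found, answer = text.partition(ANSWER_TOKEN)
    return answer.strip() if found else text.strip()


print(build_prompt("chartqa", "What is the lowest value in blue bar?"))
```

With the inference snippet, `input_prompt = build_prompt("chartqa", question)` replaces the hand-written prompt, and `extract_answer(processor.batch_decode(outputs.sequences)[0], processor.tokenizer.eos_token, processor.tokenizer.pad_token)` replaces the three post-processing lines.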

Finetuning

To finetune the model on the ChartQA dataset, edit the paths and hyperparameters in the following command and run it:

python finetune_chartqa.py --data-path "ahmed-masry/chartqa_without_images" --train-images '/content/ChartQA/ChartQA Dataset/train/png/' \
    --valid-images '/content/ChartQA/ChartQA Dataset/val/png' --max-steps 40000 --batch-size 8 --valid-batch-size 1 --num-workers 12 --lr 5e-5 \
    --check-val-every-n-epoch 1 --warmup-steps 100 --checkpoint-steps 7000 --checkpoint-path "ahmed-masry/unichart-base-960"
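
When launching several runs with different hyperparameters, it can be convenient to assemble the command above programmatically. The sketch below only builds the argument list; `build_finetune_command` is a hypothetical helper, and the flags and default values are copied from the command above rather than from the script's argument parser.

```python
import shlex


def build_finetune_command(train_images, valid_images, **overrides):
    """Return the finetuning command as an argument list.

    Keyword overrides use underscores (batch_size=4 -> --batch-size 4).
    """
    options = {
        "data-path": "ahmed-masry/chartqa_without_images",
        "train-images": train_images,
        "valid-images": valid_images,
        "max-steps": 40000,
        "batch-size": 8,
        "valid-batch-size": 1,
        "num-workers": 12,
        "lr": "5e-5",
        "check-val-every-n-epoch": 1,
        "warmup-steps": 100,
        "checkpoint-steps": 7000,
        "checkpoint-path": "ahmed-masry/unichart-base-960",
    }
    options.update({key.replace("_", "-"): value for key, value in overrides.items()})

    command = ["python", "finetune_chartqa.py"]
    for flag, value in options.items():
        command += [f"--{flag}", str(value)]
    return command


cmd = build_finetune_command(
    "/content/ChartQA/ChartQA Dataset/train/png/",
    "/content/ChartQA/ChartQA Dataset/val/png",
    batch_size=4,
)
print(shlex.join(cmd))  # shell-quoted, so paths containing spaces stay intact
```

The list can be passed directly to `subprocess.run(cmd, check=True)` from the repository root, which avoids shell quoting issues with the space in `ChartQA Dataset`.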

Contact

If you have any questions about this work, please contact Ahmed Masry using the following email addresses: amasry17@ku.edu.tr or ahmed.elmasry24653@gmail.com.

Reference

Please cite our paper if you use our models or dataset in your research.

@misc{masry2023unichart,
      title={UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning}, 
      author={Ahmed Masry and Parsa Kavehzadeh and Xuan Long Do and Enamul Hoque and Shafiq Joty},
      year={2023},
      eprint={2305.14761},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}