allenai / vila

Incorporating VIsual LAyout Structures for Scientific Text Classification
Apache License 2.0

Better predict api design #21

Closed lolipopshock closed 2 years ago

lolipopshock commented 2 years ago

This PR introduces two improvements to the predict APIs:

  1. The return type of the predict API can now be specified: either a layoutparser Layout or a plain list of category predictions. This makes the API more general and supports downstream uses such as mmda.
  2. It adds a predict_page API dedicated to the vila data models. The prediction process is reduced to a single call:

    for idx, page_token in enumerate(page_tokens):

        # New: a single call runs visual group detection and prediction
        predicted_tokens1 = pdf_predictor.predict_page(
            page_token, page_image=page_images[idx], visual_group_detector=vision_model
        )

        # Previous: detect blocks, annotate the tokens, convert, then predict
        blocks = vision_model.detect(page_images[idx])
        page_token.annotate(blocks=blocks)
        pdf_data = page_token.to_pagedata().to_dict()
        predicted_tokens2 = pdf_predictor.predict(pdf_data, page_token.page_size)

        # Both paths should yield the same predictions
        assert predicted_tokens1 == predicted_tokens2
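
To illustrate the first improvement (selectable return types), here is a minimal, self-contained sketch of the idea. It does not use the vila or layoutparser code: `Token`, `format_predictions`, and the `return_type` values `"layout"` and `"list"` are hypothetical names chosen for illustration, and the real signature in this PR may differ.

```python
from dataclasses import dataclass, replace
from typing import List, Sequence, Union


@dataclass(frozen=True)
class Token:
    """Hypothetical stand-in for a layout token (not the layoutparser class)."""

    text: str
    type: str = ""


def format_predictions(
    tokens: Sequence[Token],
    labels: Sequence[str],
    return_type: str = "layout",  # assumed option names: "layout" or "list"
) -> Union[List[Token], List[str]]:
    """Return predictions as annotated tokens or as a bare list of labels."""
    if len(tokens) != len(labels):
        raise ValueError("tokens and labels must have the same length")
    if return_type == "list":
        # Only the category predictions, one per input token
        return list(labels)
    if return_type == "layout":
        # Copies of the input tokens with the predicted category attached
        return [replace(tok, type=label) for tok, label in zip(tokens, labels)]
    raise ValueError(f"unknown return_type: {return_type!r}")


tokens = [Token("Abstract"), Token("We"), Token("propose")]
labels = ["section", "paragraph", "paragraph"]

as_list = format_predictions(tokens, labels, return_type="list")
as_layout = format_predictions(tokens, labels, return_type="layout")
```

The list form is convenient for downstream libraries that keep their own token objects and only need the labels, while the layout-style form keeps each label attached to its token.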