poloclub / unitable

UniTable: Towards a Unified Table Foundation Model
https://arxiv.org/abs/2403.04822
MIT License

add gptfast decoder #11

[Open] Sanster opened 1 month ago


Adds a decoder with a static KV cache, adapted from gpt-fast. I manually checked the outputs for the images in dataset/mini_pubtabnet/val, but have not yet run the accuracy/TEDS metrics on the test set.
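For readers unfamiliar with the idea, here is a minimal sketch of a static KV cache, modeled on the KVCache pattern in gpt-fast; it is not the code from this PR. The buffers are allocated once at their maximum size, and each step writes new keys/values into the slots given by input_pos. Unwritten slots are excluded by a mask, so tensor shapes stay fixed across decode steps. The names StaticKVCache and cached_self_attention are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


class StaticKVCache(nn.Module):
    """Preallocated key/value buffers (sketch modeled on gpt-fast's KVCache)."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float32):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        # Fixed-size buffers: their shapes never change between decode steps.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S] positions being written; k_val / v_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


def cached_self_attention(q, k, v, cache, input_pos):
    """Causal self-attention for the new positions, reading the whole static cache."""
    k_all, v_all = cache.update(input_pos, k, v)
    max_len = k_all.shape[2]
    # A query at position p may attend to cache slots 0..p; later (unwritten) slots are masked.
    mask = torch.arange(max_len, device=q.device)[None, :] <= input_pos[:, None]  # [S, max_len]
    return F.scaled_dot_product_attention(q, k_all, v_all, attn_mask=mask)


# Usage: prefill positions 0..2, then decode position 3 one token at a time.
B, H, D, T = 2, 4, 8, 16
cache = StaticKVCache(B, T, H, D)
x = torch.randn(B, H, 3, D)
out = cached_self_attention(x, x, x, cache, torch.arange(3))
print(out.shape)  # torch.Size([2, 4, 3, 8])
```

Because every step sees tensors of the same shape, this layout is what lets gpt-fast-style decoders work well with torch.compile.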

Benchmark of the cell detection model:

[Image: benchmark results]

The main modifications are in full_pipeline.ipynb:

1. Specify which class to use for the decoder:
backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
encoder = Encoder(
    d_model=d_model,
    nhead=nhead,
    dropout=dropout,
    activation="gelu",
    norm_first=True,
    nlayer=12,
    ff_ratio=4,
)
# Choose the decoder implementation: the original Decoder or the new GPTFastDecoder.
# decoder_class = Decoder
decoder_class = GPTFastDecoder
2. Initialize the decoder in load_vocab_and_model, and call map_state_dict on the loaded checkpoint when GPTFastDecoder is used.
from typing import Type

import torch
from torch import nn

def load_vocab_and_model(..., decoder_class: Type[nn.Module]):
    # Build whichever decoder class was selected in step 1.
    decoder = decoder_class(
        d_model=d_model,
        nhead=nhead,
        dropout=dropout,
        activation="gelu",
        norm_first=True,
        nlayer=4,
        ff_ratio=4,
    )
    model = EncoderDecoder(
        backbone=backbone,
        encoder=encoder,
        decoder=decoder,
        ...
    )

    state_dict = torch.load(model_weights, map_location="cpu")
    if isinstance(model.decoder, GPTFastDecoder):
        # Adapt the original checkpoint's state dict to GPTFastDecoder.
        state_dict = map_state_dict(state_dict)

    model.load_state_dict(state_dict)
    model = model.to(device)
    return vocab, model
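The body of map_state_dict is not shown here. As a rough illustration of why a remapping step is needed, note that load_state_dict is strict by default and raises on missing or unexpected keys. The hypothetical sketch below renames decoder keys by substring rules. The source names follow PyTorch's nn.MultiheadAttention (in_proj_weight, in_proj_bias, out_proj.*). The target names (attention.wqkv, attention.wo) are assumptions borrowed from gpt-fast's naming and may not match GPTFastDecoder.

```python
from collections import OrderedDict

import torch

# Hypothetical rename rules (old substring -> new substring). The real mapping depends
# on GPTFastDecoder's module names, which are not shown in this PR description.
RENAME_RULES = [
    ("self_attn.in_proj_", "attention.wqkv."),
    ("self_attn.out_proj.", "attention.wo."),
]


def remap_state_dict(state_dict, rules=RENAME_RULES):
    """Return a copy of state_dict with decoder keys renamed; tensors are not modified."""
    remapped = OrderedDict()
    for key, tensor in state_dict.items():
        new_key = key
        if key.startswith("decoder."):  # leave backbone/encoder keys untouched
            for old, new in rules:
                new_key = new_key.replace(old, new)
        remapped[new_key] = tensor
    return remapped


sd = {
    "decoder.layers.0.self_attn.in_proj_weight": torch.zeros(3, 3),
    "encoder.layers.0.self_attn.in_proj_weight": torch.zeros(3, 3),
}
print(list(remap_state_dict(sd)))
# ['decoder.layers.0.attention.wqkv.weight', 'encoder.layers.0.self_attn.in_proj_weight']
```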
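3. In autoregressive_decode, when GPTFastDecoder is used, call setup_caches first so the static KV cache is allocated for the current batch size and max_decode_len.
def autoregressive_decode(...):
    model.eval()
    is_gpt_fast = isinstance(model.decoder, GPTFastDecoder)
    if is_gpt_fast:
        # Allocate the static KV cache on the same device as the input images,
        # sized for this batch and the maximum decode length.
        with torch.device(image.device):
            model.decoder.setup_caches(
                max_batch_size=image.shape[0],
                max_seq_length=max_decode_len,
                dtype=image.dtype,
            )
    memory = model.encode(image)
    ...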
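Because the cache is preallocated, setup_caches must run before the first decode step, and its sizes bound how far decoding can go. The generic sketch below shows greedy decoding against such a cache: each step feeds only the newest token and its position (input_pos), and the decoder reuses cached keys/values for earlier positions. It is not the notebook's autoregressive_decode. step_fn is a hypothetical stand-in for a decoder call that already has the encoder memory bound, and the toy step function exists only for the demo.

```python
import torch
import torch.nn.functional as F


def greedy_decode_with_static_cache(step_fn, bos_id, eos_id, batch_size, max_decode_len):
    """Greedy decoding that feeds one new token per step.

    step_fn(tokens, input_pos) -> logits [B, S, vocab] is hypothetical; it stands for a
    decoder whose cache was set up with max_batch_size >= batch_size and
    max_seq_length >= max_decode_len.
    """
    seq = torch.full((batch_size, 1), bos_id, dtype=torch.long)
    finished = torch.zeros(batch_size, dtype=torch.bool)
    for pos in range(max_decode_len - 1):  # cache slots 0..max_decode_len-2 are written
        input_pos = torch.tensor([pos])
        logits = step_fn(seq[:, -1:], input_pos)  # only the newest token is fed
        next_tok = logits[:, -1].argmax(dim=-1)
        # Sequences that already emitted EOS keep emitting EOS.
        next_tok = torch.where(finished, torch.full_like(next_tok, eos_id), next_tok)
        seq = torch.cat([seq, next_tok[:, None]], dim=1)
        finished |= next_tok == eos_id
        if finished.all():
            break
    return seq


def toy_step(tokens, input_pos):
    # Toy "decoder": always predicts the previous token + 1 (vocab of 10).
    return F.one_hot((tokens + 1) % 10, 10).float()


print(greedy_decode_with_static_cache(toy_step, bos_id=0, eos_id=5, batch_size=2, max_decode_len=16))
# tensor([[0, 1, 2, 3, 4, 5],
#         [0, 1, 2, 3, 4, 5]])
```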