Open Sanster opened 1 month ago
Add a decoder with a static KV cache from gpt-fast. I manually checked the results on the images in `dataset/mini_pubtabnet/val`, but have not yet run the acc/TEDS metrics on the test set.
Benchmark cell detection model with:
The main modifications are in `full_pipeline.ipynb`.
    backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
    encoder = Encoder(
        d_model=d_model,
        nhead=nhead,
        dropout=dropout,
        activation="gelu",
        norm_first=True,
        nlayer=12,
        ff_ratio=4,
    )
    # decoder_class = Decoder
    decoder_class = GPTFastDecoder
In `load_vocab_and_model`, call `map_state_dict` when using `GPTFastDecoder`.
    def load_vocab_and_model(..., decoder_class: Type[nn.Module]):
        decoder = decoder_class(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=4,
            ff_ratio=4,
        )
        model = EncoderDecoder(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            ...
        )
        state_dict = torch.load(model_weights, map_location="cpu")
        # The checkpoint was saved with the original Decoder, so its keys
        # are remapped before loading into GPTFastDecoder.
        if isinstance(model.decoder, GPTFastDecoder):
            state_dict = map_state_dict(state_dict)
        model.load_state_dict(state_dict)
        model = model.to(device)
        return vocab, model
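The body of `map_state_dict` is not shown above. As a rough illustration of the kind of key renaming such a function might perform, here is a minimal sketch in plain Python. The function name `remap_keys`, the rule table, and all key names are hypothetical and chosen only for the example; the real mapping depends on how `Decoder` and `GPTFastDecoder` name their submodules.

```python
from collections import OrderedDict

# Hypothetical rename rules: (old substring, new substring).
# These names are illustrative assumptions, not the PR's actual mapping.
RENAME_RULES = [
    ("self_attn.in_proj_weight", "attention.wqkv.weight"),
    ("self_attn.out_proj.weight", "attention.wo.weight"),
]


def remap_keys(state_dict, rules=RENAME_RULES):
    """Return a new ordered dict with the first matching rule applied to each key."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in rules:
            if old in key:
                new_key = key.replace(old, new)
                break
        # Guard against two source keys collapsing onto one target key.
        if new_key in remapped:
            raise KeyError(f"two source keys map to {new_key!r}")
        remapped[new_key] = value
    return remapped


# Strings stand in for weight tensors so the example needs no dependencies.
old_state = OrderedDict([
    ("decoder.layers.0.self_attn.in_proj_weight", "W_qkv"),
    ("decoder.layers.0.self_attn.out_proj.weight", "W_o"),
    ("encoder.layers.0.norm1.weight", "gamma"),
])
new_state = remap_keys(old_state)
```

Keys that match no rule pass through unchanged, so encoder and backbone weights would load exactly as before.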
In `autoregressive_decode`, if `GPTFastDecoder` is used, `setup_caches` needs to be called first.
    def autoregressive_decode(...):
        model.eval()
        is_gpt_fast = isinstance(model.decoder, GPTFastDecoder)
        if is_gpt_fast:
            # Allocate the static KV cache on the same device as the input.
            with torch.device(image.device):
                model.decoder.setup_caches(
                    max_batch_size=image.shape[0],
                    max_seq_length=max_decode_len,
                    dtype=image.dtype,
                )
        memory = model.encode(image)
        ...
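To make the reason for calling `setup_caches` up front more concrete, here is a toy sketch of a static KV cache. It uses nested Python lists instead of tensors, and the class `StaticKVCache` with its `update` method is a hypothetical stand-in, not the `GPTFastDecoder` code. The idea it illustrates is that the buffers are allocated once for `max_batch_size` and `max_seq_length`, then written in place at each decoding position rather than grown by concatenation.

```python
class StaticKVCache:
    """Toy static cache: buffers are allocated once and written in place."""

    def __init__(self, max_batch_size, max_seq_length, head_dim):
        self.max_seq_length = max_seq_length
        # Shape: [batch][position][head_dim], zero-filled up front.
        self.k = [[[0.0] * head_dim for _ in range(max_seq_length)]
                  for _ in range(max_batch_size)]
        self.v = [[[0.0] * head_dim for _ in range(max_seq_length)]
                  for _ in range(max_batch_size)]

    def update(self, pos, k_new, v_new):
        """Write one decoding step (one row per batch item) at position pos."""
        if not 0 <= pos < self.max_seq_length:
            raise IndexError(f"position {pos} is outside the preallocated cache")
        for b, (k_row, v_row) in enumerate(zip(k_new, v_new)):
            self.k[b][pos] = list(k_row)
            self.v[b][pos] = list(v_row)
        return self.k, self.v


# Two sequences, room for four decoding steps, head dimension of three.
cache = StaticKVCache(max_batch_size=2, max_seq_length=4, head_dim=3)
cache.update(0, k_new=[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
             v_new=[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]])
cache.update(1, k_new=[[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
             v_new=[[0.3, 0.3, 0.3], [0.4, 0.4, 0.4]])
```

Because the buffer shapes never change between steps, the cache has to be sized before decoding starts, which is why the setup call comes before `model.encode(image)` in the excerpt above.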