jpWang / LiLT

Official PyTorch implementation of LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding (ACL 2022)

Pre-training code? #30

Open · logan-markewich opened this issue 1 year ago

logan-markewich commented 1 year ago

Are you able to provide the pre-training code?

I would like to try and pre-train using roberta-large, or a similar language model :)

jordanparker6 commented 1 year ago

I would like to do the same but with bigbird-roberta-en-base if possible...

logan-markewich commented 1 year ago

If the hidden size is the same as roberta-base, you can probably use the weight generation script in the repo
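For reference, a rough sketch of what that fusion step amounts to, based on the merge line quoted later in this thread (the paths and surrounding code are illustrative, not the script's actual interface):

```python
import torch

# Illustrative paths; substitute the real checkpoint locations.
lilt_only_path = "lilt-only-base/pytorch_model.bin"    # layout-stream weights
text_model_path = "roberta-base/pytorch_model.bin"     # text-stream weights
output_path = "lilt-roberta-base/pytorch_model.bin"

lilt_model = torch.load(lilt_only_path, map_location="cpu")
text_model = torch.load(text_model_path, map_location="cpu")

# Merge the two state dicts; keys present in both are taken from lilt_model
# because it is unpacked last (this is the line quoted later in the thread).
total_model = {**text_model, **lilt_model}

torch.save(total_model, output_path)
```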

MaveriQ commented 1 year ago

@logan-markewich, @jordanparker6 I am coding up the collator for the masking in the three pretraining strategies. Maybe we can work together, and share it here afterwards for everyone else to use?
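To get the ball rolling, here is a minimal sketch of the masked-language-modeling part of such a collator. It assumes every feature is already tokenized and padded to a fixed length with one bounding box per token; the other two pre-training objectives would need their own labels on top of this, and the usual 80/10/10 mask/random/keep scheme is omitted for brevity:

```python
import torch
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase

@dataclass
class MaskedLmCollator:
    """Sketch of a collator for the masked-language-modeling objective.

    Assumes fixed-length, pre-padded features with keys
    "input_ids", "attention_mask", and "bbox" (one box per token).
    """
    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15

    def __call__(self, features):
        input_ids = torch.tensor([f["input_ids"] for f in features], dtype=torch.long)
        attention_mask = torch.tensor([f["attention_mask"] for f in features], dtype=torch.long)
        bbox = torch.tensor([f["bbox"] for f in features], dtype=torch.long)
        labels = input_ids.clone()

        # Sample positions to mask, skipping special and padding tokens.
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(f["input_ids"], already_has_special_tokens=True)
             for f in features],
            dtype=torch.bool,
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        probability_matrix.masked_fill_(attention_mask == 0, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # loss is computed on masked tokens only

        # Replace masked text tokens with [MASK]; the boxes stay untouched so
        # the layout stream still sees the true token positions.
        input_ids[masked_indices] = self.tokenizer.mask_token_id
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "bbox": bbox,
            "labels": labels,
        }
```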

jordanparker6 commented 1 year ago

> @logan-markewich, @jordanparker6 I am coding up the collator for the masking in the three pretraining strategies. Maybe we can work together, and share it here afterwards for everyone else to use?

Happy to help out as needed.

jordanparker6 commented 1 year ago

> If the hidden size is the same as roberta-base, you can probably use the weight generation script in the repo

I don't think it is... I posted my error message in a separate issue.


I was able to use the provided script to create a lilt-roberta-base-en-style model from https://huggingface.co/google/bigbird-roberta-base. If I can get this working, I will post it to the Hugging Face Hub.

BigBird (google/bigbird-roberta-base) uses the same tokenizer as RoBERTa, so tokenization is not an issue.
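A quick way to sanity-check that assumption (the vocabulary sizes and tokenized output should line up if the two tokenizers really are interchangeable):

```python
from transformers import AutoTokenizer

roberta_tok = AutoTokenizer.from_pretrained("roberta-base")
bigbird_tok = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")

sample = "LiLT handles structured document understanding."
print(roberta_tok.vocab_size, bigbird_tok.vocab_size)
print(roberta_tok.tokenize(sample))
print(bigbird_tok.tokenize(sample))
```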

However, the following error occurs when loading the model.

RuntimeError: Error(s) in loading state_dict for LiltForTokenClassification:
    size mismatch for lilt.layout_embeddings.box_position_embeddings.weight: copying a param with shape torch.Size([514, 192]) from checkpoint, the shape in current model is torch.Size([4096, 192]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
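As the error message suggests, one way to get past the size mismatch is to pass `ignore_mismatched_sizes=True` when loading; the mismatched box_position_embeddings matrix is then re-initialised randomly and would need training before the model is useful. The checkpoint path below is a placeholder:

```python
from transformers import LiltForTokenClassification

# Placeholder path to the fused LiLT + BigBird checkpoint.
model = LiltForTokenClassification.from_pretrained(
    "path/to/lilt-bigbird-roberta-base",
    ignore_mismatched_sizes=True,  # skip loading the mismatched embedding
)
```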

I think this error arises when the PyTorch state dicts are merged with the following line:

total_model = {**text_model, **lilt_model}

Because lilt_model is unpacked last, its tensors override the incoming BigBird tensors wherever the keys overlap, so the merged checkpoint keeps LiLT's dimensions instead of BigBird's.

Would it be problematic to switch the merge order to this?

total_model = {**lilt_model, **text_model}
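For reference, in a dict built from two `**` unpackings the right-hand operand wins whenever a key appears in both, so swapping the order only affects keys that exist in both state dicts; keys unique to one checkpoint are unaffected either way. A tiny illustration:

```python
text_model = {"embeddings.weight": "from BigBird", "shared.key": "from BigBird"}
lilt_model = {"layout.weight": "from LiLT", "shared.key": "from LiLT"}

# The right-hand unpacking overrides duplicate keys.
print({**text_model, **lilt_model}["shared.key"])  # from LiLT
print({**lilt_model, **text_model}["shared.key"])  # from BigBird
```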