AlibabaResearch / AdvancedLiterateMachinery

A collection of original, innovative ideas and algorithms towards Advanced Literate Machinery. This project is maintained by the OCR Team in the Language Technology Lab, Tongyi Lab, Alibaba Group.
Apache License 2.0

Where do I load pre-trained DiT and GiT in the VGT code? #89

Status: Open · YuNie24 opened 10 months ago

YuNie24 commented 10 months ago

I downloaded two weight files for fine-tuning VGT.

When fine-tuning VGT, where should I specify the paths to the pre-trained ViT and DiT weight files? If possible, please tell me which file loads these two weight paths.

bavo96 commented 7 months ago

I have the same issue: I can't find exactly where the GiT and ViT models are loaded. Can you please help me with this? @yashsandansing @alibaba-oss @Wangsherpa

Harsh19012003 commented 5 months ago

There are command-line arguments, including --opts, where you can pass extra settings as key-value pairs, like this:
MODEL.WEIGHTS <path to model weights>
Alternatively, in the config file for your dataset (D4LA, docbank, doclaynet, or publaynet) in the Configs/cascade directory, you can set the model weights directly:
WEIGHTS: "<yourdirectorystructure>/<weightfilename>"

NOTE: you can manually download the DiT (base or large) weights from the DiT repository, and LayoutLM's pytorch_model.bin from Hugging Face.
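As a rough illustration of what a trailing `MODEL.WEIGHTS <path>` pair does: in detectron2-style scripts such KEY VALUE pairs are typically applied to the loaded config with `cfg.merge_from_list(...)`, which sets each dotted key. The stdlib-only sketch below mimics that on a plain nested dict. `apply_opts` is a hypothetical helper, not part of VGT or detectron2, and the real yacs implementation additionally parses and type-checks values against the defaults.

```python
from typing import Any, Dict, List


def apply_opts(cfg: Dict[str, Any], opts: List[str]) -> Dict[str, Any]:
    """Hypothetical stand-in for yacs/detectron2 `merge_from_list`.

    `opts` is a flat list of alternating dotted keys and values,
    e.g. ["MODEL.WEIGHTS", "/weights/vgt_model.pth"].
    """
    if len(opts) % 2 != 0:
        raise ValueError("opts expects KEY VALUE pairs")
    for key, value in zip(opts[0::2], opts[1::2]):
        node = cfg
        *parents, leaf = key.split(".")
        # Walk down to the parent node, e.g. cfg["MODEL"] for MODEL.WEIGHTS.
        for part in parents:
            node = node[part]
        if leaf not in node:
            # yacs likewise rejects keys that are not already defined.
            raise KeyError(f"Non-existent config key: {key}")
        node[leaf] = value
    return cfg


# Defaults as they might look after loading a YAML file from Configs/cascade
# (placeholder values for illustration only).
cfg = {"MODEL": {"WEIGHTS": "", "DEVICE": "cuda"}}
apply_opts(cfg, ["MODEL.WEIGHTS", "/weights/vgt_model.pth"])
print(cfg["MODEL"]["WEIGHTS"])
```

Setting `WEIGHTS:` under `MODEL:` in the YAML file and passing the pair on the command line end up in the same place; the command-line pair is applied after the file is loaded, so it wins.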

ritutweets46 commented 2 weeks ago


Hi @bavo96, were you able to figure this out?

bavo96 commented 2 weeks ago

Hi @ritutweets46, based on my current understanding, the path you need to change for the ViT pre-trained model is hard-coded in the MyDetectionCheckpointer class in VGTcheckpointer.py:

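from typing import Any
import torch
from detectron2.checkpoint import DetectionCheckpointer
from fvcore.common.checkpoint import _IncompatibleKeys
class MyDetectionCheckpointer(DetectionCheckpointer):
    def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:
        ...
        # Hard-coded DiT checkpoint path: edit this string to point at your DiT weights.
        DiT_checkpoint_state_dict = torch.load("/path/dit-base-224-p16-500k-62d53a.pth", map_location=torch.device("cpu"))["model"]
        ...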
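Because the DiT path is a string literal inside `_load_model`, one option is to read it from an environment variable and fall back to the original literal. This is a minimal sketch, assuming a hypothetical variable named `VGT_DIT_WEIGHTS` (VGT does not define it); the resolved path would then replace the literal in the existing `torch.load(..., map_location=torch.device("cpu"))["model"]` call.

```python
import os
from pathlib import Path

# Fallback copied from the hard-coded literal in MyDetectionCheckpointer._load_model.
DEFAULT_DIT_PATH = "/path/dit-base-224-p16-500k-62d53a.pth"


def resolve_dit_path(env_var: str = "VGT_DIT_WEIGHTS") -> Path:
    """Return the DiT checkpoint path, preferring a (hypothetical) env var."""
    raw = os.environ.get(env_var, DEFAULT_DIT_PATH)
    # expanduser lets values such as "~/weights/dit.pth" work.
    return Path(raw).expanduser()


# Inside _load_model, the load could then become:
# DiT_checkpoint_state_dict = torch.load(
#     resolve_dit_path(), map_location=torch.device("cpu")
# )["model"]
```

This keeps the default behavior unchanged while letting you switch between, say, the base and large DiT checkpoints without editing the source.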

and the path for the GiT pre-trained model comes from cfg.MODEL.WEIGHTS, which the VGTTrainer class passes to the checkpointer during training:

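from detectron2.engine import TrainerBase
class VGTTrainer(TrainerBase):
    ...
    def resume_or_load(self, resume=True):
        ...
        # Loads cfg.MODEL.WEIGHTS unless resume=True and a last checkpoint exists in the output dir.
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        ...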

and also in the DefaultPredictor class for inference

class DefaultPredictor:
    def __init__(self, cfg):
        ...
        checkpointer.load(cfg.MODEL.WEIGHTS)
        ...
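Putting the thread together: the DiT weights come from the literal path in VGTcheckpointer.py, while cfg.MODEL.WEIGHTS is used both by VGTTrainer.resume_or_load and by DefaultPredictor. A small preflight check can catch a wrong path before a long run starts. This is a sketch assuming both paths are plain local files (detectron2 also accepts URLs for MODEL.WEIGHTS, which this check would flag as missing); `missing_weight_files` is a hypothetical helper.

```python
from pathlib import Path
from typing import Dict, List


def missing_weight_files(paths: Dict[str, str]) -> List[str]:
    """Return the labels whose configured path is empty or not an existing file."""
    missing = []
    for label, path in paths.items():
        # An empty path makes the checkpointer skip loading (the model starts
        # from scratch), which is usually not intended when fine-tuning.
        if not path or not Path(path).expanduser().is_file():
            missing.append(label)
    return missing


# Example with placeholder paths; substitute your own.
problems = missing_weight_files({
    "DiT (VGTcheckpointer.py)": "/path/dit-base-224-p16-500k-62d53a.pth",
    "cfg.MODEL.WEIGHTS": "",
})
if problems:
    print("Missing weight files:", ", ".join(problems))
```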

Please feel free to correct me if I’m mistaken :D

ritutweets46 commented 2 weeks ago

Thank you @bavo96 !