bdaiinstitute / theia

Theia: Distilling Diverse Vision Foundation Models for Robot Learning
http://theia.theaiinstitute.com

Change the student backbone #16

Open zhlhlhlhl opened 1 month ago

zhlhlhlhl commented 1 month ago

Hi, I'm curious whether the backbone of Theia's student model can be changed. You currently use DeiT-base-patch16-224; I'd like to use a larger pre-trained model such as clip-vit-base-patch32 as the student to distill other VFMs. Do you think that's feasible?

jshang-bdai commented 1 month ago

Hi @zhlhlhlhl

Go ahead! It's highly flexible.

  1. Add your backbone here: https://github.com/bdaiinstitute/theia/blob/main/src/theia/models/backbones.py. Check the spatial size of your backbone's output to see whether you need to modify the existing feature translators or create a new one.
  2. Create a corresponding backbone config file here: https://github.com/bdaiinstitute/theia/tree/main/src/theia/configs/model/backbone. If you create a new feature translator, also add its configuration here: https://github.com/bdaiinstitute/theia/tree/main/src/theia/configs/model/translator.
  3. Modify your overall training config file to use your new backbone.
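
For the spatial-size check in step 1, here is a minimal sketch of the arithmetic, assuming a plain ViT-style backbone with square inputs and non-overlapping patches. The helper name `patch_grid` is hypothetical and not part of Theia; it only counts patch tokens and ignores any class or other special tokens.

```python
from typing import Tuple


def patch_grid(image_size: int, patch_size: int) -> Tuple[int, int]:
    """Return (grid_side, num_patch_tokens) for a square ViT-style input.

    Hypothetical helper for illustration only; assumes non-overlapping
    patches and does not count class or other special tokens.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side, side * side


# Current student: patch size 16 at 224x224 input.
print(patch_grid(224, 16))  # (14, 196)

# Proposed student: patch size 32 at 224x224 input.
print(patch_grid(224, 32))  # (7, 49)
```

Under these assumptions a patch-32 backbone yields a 7x7 token grid instead of 14x14, so a translator written for the 14x14 case would probably need to be adapted or replaced.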

zhlhlhlhl commented 1 month ago

I really appreciate your prompt response! I'll try this ASAP. BTW, do you think the features generated by a larger student backbone will perform better in vision tasks?

zhlhlhlhl commented 1 month ago

After training, I only got .pth files. How do I load these weights?

jshang-bdai commented 1 month ago

> larger student backbone for better vision tasks

It's highly possible!

> load weights

See https://github.com/bdaiinstitute/theia/blob/9ee7548829088e1a7dae6a033dfb6b520656c1f2/src/theia/models/rvfm.py#L77

zhlhlhlhl commented 1 month ago

I mean, if I want to use AutoModel.from_pretrained to load the model, my output lacks files such as config.json and theia_model.py. Where can I get these?

jshang-bdai commented 1 month ago

If I understand correctly, you want to use AutoModel.from_pretrained but load local weights. Here are the steps:

  1. Create the model with model = AutoModel.from_pretrained(). It doesn't matter which weights it loads.
  2. Call model.load_pretrained_weights(<checkpoint_path>). This is the same method as https://github.com/bdaiinstitute/theia/blob/9ee7548829088e1a7dae6a033dfb6b520656c1f2/src/theia/models/rvfm.py#L77
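
The two steps can be sketched roughly as follows. The wrapper `load_local_weights` is a hypothetical convenience, not part of Theia; it assumes `load_pretrained_weights` accepts a checkpoint path as described above. The Hub ID in the commented usage is an assumption taken from the project README, and the model it creates presumably has to match the architecture of your checkpoint.

```python
from pathlib import Path


def load_local_weights(model, checkpoint_path):
    """Overwrite the weights of an already-constructed model with a local checkpoint.

    Hypothetical wrapper: it only checks that the file exists and then calls
    the model's own load_pretrained_weights method (the one linked above).
    """
    path = Path(checkpoint_path)
    if not path.is_file():
        raise FileNotFoundError(f"checkpoint not found: {path}")
    model.load_pretrained_weights(str(path))
    return model


# Usage sketch (needs transformers and network access, so it is left commented out):
#
#   from transformers import AutoModel
#   # Step 1: create the model; the weights loaded here get replaced in step 2.
#   # The Hub ID below is an assumption, substitute the one that fits your setup.
#   model = AutoModel.from_pretrained(
#       "theaiinstitute/theia-base-patch16-224-cdiv", trust_remote_code=True
#   )
#   # Step 2: load the local .pth checkpoint produced by training.
#   model = load_local_weights(model, "path/to/your_checkpoint.pth")
```

If the student backbone was changed, a released model's architecture may no longer match the checkpoint; in that case, constructing the model from the training code and calling `load_pretrained_weights` on it directly is likely the safer route.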