insitro / ChannelViT

Channel Vision Transformers: An Image Is Worth C x 16 x 16 Words
https://arxiv.org/abs/2309.16108

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule` #8

Closed: JenniferTsang911 closed this issue 5 months ago

JenniferTsang911 commented 6 months ago

Hi there,

Thank you so much for this model. I am trying to apply it to a separate dataset and would love to use the pre-trained models you recently uploaded. However, I am struggling to figure out how to load them: each time, I run into this TypeError:

TypeError: model must be a LightningModule or torch._dynamo.OptimizedModule , got ChannelVisionTransformer

Please let me know exactly how to load these trained models; I would greatly appreciate it.

Thanks!

srinivasans-insitro commented 6 months ago

Hi,

If you're using the pre-trained models for inference, you can load them with torch hub as follows:

import torch

model = torch.hub.load('insitro/ChannelViT', 'imagenet_channelvit_small_p16_with_hcs_supervised', pretrained=True)

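From the error message, it looks like you're trying to fine-tune the pre-trained models on your own data. The Lightning `Trainer` expects a `LightningModule`, but it received the bare `ChannelVisionTransformer` backbone. Instead, build the `Supervised` module and load the pre-trained weights into its backbone, as in the following example: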

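# in main_supervised.py

from torch import hub

...

# Build the Lightning wrapper as usual; this is what gets passed to the Trainer
model = Supervised(cfg)

# Download the pre-trained weights (a state dict for the backbone only)
model_state_dict = hub.load_state_dict_from_url(
    "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/imagenet_channelvit_small_p16_with_hcs_supervised.pth",
    progress=True,
)

# Load the weights into the ChannelViT backbone, not into the wrapper itself
model.backbone.load_state_dict(model_state_dict)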
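The reason for calling `load_state_dict` on `model.backbone` rather than on `model` is that a wrapper module prefixes its submodule's parameter names (e.g. `backbone.`), while the downloaded file holds only the backbone's own keys. Below is a minimal, self-contained sketch of that behavior in plain PyTorch. `TinyBackbone` and `TinyWrapper` are hypothetical stand-ins for `ChannelVisionTransformer` and the `Supervised` module, and the randomly initialized `pretrained` dict stands in for the downloaded checkpoint; no Lightning install or network access is assumed.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for ChannelVisionTransformer."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 4)


class TinyWrapper(nn.Module):
    """Hypothetical stand-in for `Supervised`: backbone submodule plus a task head."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = TinyBackbone()
        self.head = nn.Linear(4, 2)


def load_backbone_weights(wrapper: nn.Module, state_dict: dict) -> None:
    # strict=True (the default) raises if any key is missing or unexpected
    wrapper.backbone.load_state_dict(state_dict)


# Stand-in for the downloaded checkpoint: a backbone-only state dict
pretrained = TinyBackbone().state_dict()
print(sorted(pretrained))  # ['proj.bias', 'proj.weight']

model = TinyWrapper()
print(sorted(model.state_dict()))
# ['backbone.proj.bias', 'backbone.proj.weight', 'head.bias', 'head.weight']

# Keys match the submodule exactly, so this loads cleanly
load_backbone_weights(model, pretrained)

# Loading into the wrapper itself fails: no "backbone." prefix, head keys missing
try:
    model.load_state_dict(pretrained)
except RuntimeError:
    print("strict load into the wrapper failed")
```

The task head keeps its fresh initialization, which is usually what you want when fine-tuning on a new dataset.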

The checkpoint URLs for the other pre-trained models are listed in https://github.com/insitro/ChannelViT/blob/main/hubconf.py

Happy to help if you have any further questions.