facebookresearch / deit

Official DeiT repository
Apache License 2.0

Can't load Cait model from torch hub #126

Closed: AlessioGalluccio closed this issue 2 years ago

AlessioGalluccio commented 2 years ago

I tried to load a CaiT model using this code:

import torch
feature_extractor = torch.hub.load('facebookresearch/deit:main', 'cait_M48', pretrained=True)

However, I get this error:

raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

I think this error could be solved by adding "from cait_models import *" to hubconf.py, so that cait_M48 and the other CaiT models become callable.
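A minimal sketch of why the missing import matters, assuming torch.hub behaves as the error message suggests: the hub imports hubconf.py as a module, looks the requested entrypoint name up on it, and raises the RuntimeError above when no callable with that name exists. The function find_entrypoint and the in-memory modules below are hypothetical stand-ins, not torch.hub's actual code, and no torch installation is needed.

```python
import sys
import types


def find_entrypoint(hubconf, name):
    # Hypothetical, simplified stand-in for torch.hub's lookup: fetch the
    # attribute from the hubconf module and require it to be callable.
    func = getattr(hubconf, name, None)
    if func is None or not callable(func):
        raise RuntimeError('Cannot find callable {} in hubconf'.format(name))
    return func


# Hypothetical in-memory stand-in for cait_models.py.
cait_models = types.ModuleType("cait_models")
cait_models.cait_M48 = lambda pretrained=False: "cait_M48 model"
cait_models.__all__ = ["cait_M48"]
sys.modules["cait_models"] = cait_models

# hubconf without the star import: cait_M48 is not an attribute.
hubconf_before = types.ModuleType("hubconf_before")

# hubconf with "from cait_models import *": every name in __all__
# becomes an attribute of the hubconf module.
hubconf_after = types.ModuleType("hubconf_after")
exec("from cait_models import *", hubconf_after.__dict__)

try:
    find_entrypoint(hubconf_before, "cait_M48")
except RuntimeError as err:
    print(err)  # Cannot find callable cait_M48 in hubconf

print(find_entrypoint(hubconf_after, "cait_M48")(pretrained=True))
```

The star import is what turns names defined in a separate file into attributes of hubconf itself, which is all the lookup needs.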

TouvronHugo commented 2 years ago

Hi @AlessioGalluccio, thanks for your remark. I have modified the hubconf.py file. Best, Hugo

AlessioGalluccio commented 2 years ago

Thank you. Unfortunately, loading the hubconf.py file now raises an error. Fortunately, I found that it only needs a simple fix.

First, the __all__ variable in cait_models.py lists a "cait_M4" model, which is not implemented. The fix is to remove it from the __all__ variable, like this:

__all__ = [ 'cait_M48', 'cait_M36', 'cait_S36', 'cait_S24', 'cait_S24_224', 'cait_XS24', 'cait_XXS24', 'cait_XXS24_224', 'cait_XXS36', 'cait_XXS36_224' ]
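The sketch below illustrates why a stale entry such as "cait_M4" breaks hubconf.py: "from module import *" looks up every name listed in __all__, so a single undefined name makes the whole import raise AttributeError. The module fake_cait_models is a hypothetical in-memory stand-in for cait_models.py, trimmed to three entries.

```python
import sys
import types

# Hypothetical stand-in for cait_models.py: only two models are defined,
# but __all__ also advertises an unimplemented "cait_M4".
mod = types.ModuleType("fake_cait_models")
for name in ["cait_M48", "cait_M36"]:
    setattr(mod, name, lambda pretrained=False: None)
mod.__all__ = ["cait_M48", "cait_M4", "cait_M36"]
sys.modules["fake_cait_models"] = mod

# The star import fails as soon as it reaches the undefined name.
star_import_failed = False
try:
    exec("from fake_cait_models import *", {})
except AttributeError as err:
    star_import_failed = True
    print("star import failed:", err)

# Dropping the stale entry (the fix from the comment) makes it succeed.
mod.__all__ = [n for n in mod.__all__ if hasattr(mod, n)]
namespace = {}
exec("from fake_cait_models import *", namespace)
exported = sorted(k for k in namespace if not k.startswith("__"))
print(exported)  # ['cait_M36', 'cait_M48']
```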

Similarly, some names in the __all__ variable of resmlp_models.py are misspelled. I fixed it by writing this:

__all__ = [ 'resmlp_12', 'resmlp_24', 'resmlp_36', 'resmlpB_24' ]
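Typos like these can be caught before they break hubconf.py. As a hypothetical helper (suggest_fixes is not part of the DeiT repository), the sketch below lists each __all__ entry that a module does not define and uses difflib.get_close_matches to suggest the closest defined name. The misspelled entries are invented for illustration; the thread does not say what the original typos were.

```python
import difflib
import types


def suggest_fixes(module):
    # Map each __all__ entry the module does not define to the closest
    # defined public name, or None if nothing is similar enough.
    defined = [n for n in vars(module) if not n.startswith("_")]
    fixes = {}
    for name in getattr(module, "__all__", []):
        if not hasattr(module, name):
            match = difflib.get_close_matches(name, defined, n=1)
            fixes[name] = match[0] if match else None
    return fixes


# Hypothetical stand-in for resmlp_models.py with invented typos.
resmlp = types.ModuleType("fake_resmlp_models")
for name in ["resmlp_12", "resmlp_24", "resmlp_36", "resmlpB_24"]:
    setattr(resmlp, name, lambda pretrained=False: None)
resmlp.__all__ = ["resmlp_12", "resMLP_24", "resmlp36", "resmlpB_24"]

print(suggest_fixes(resmlp))
```

An empty dictionary means every exported name exists, so a check like this could run as a quick test over each model file.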

TouvronHugo commented 2 years ago

Thanks for pointing out the typo. I have fixed it. Best, Hugo