martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0

Tfimm for feature extraction #98

Open: AmericaBG opened this issue 1 year ago

AmericaBG commented 1 year ago

Hi! First of all, thank you very much for your work. I think it's a very useful tool for those of us who are more comfortable with TensorFlow :)

I would like to use a pretrained transformer model to extract features from the layer just before the final classification layer. What should I do to get them? Is it enough to create a pretrained model with nb_classes=0?

Thank you very much!! 😊

SEOYUNJE commented 2 weeks ago

Yes, all you have to do is set nb_classes=0:

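 import tfimm
 from tensorflow.keras import Input, layers
 inp = Input(shape=(224, 224, 3))  # vit_base_patch16_224 expects 224x224 RGB images
 base_model = tfimm.create_model("vit_base_patch16_224", pretrained=True, in_channels=3, nb_classes=0)
 x = base_model(inp)  # nb_classes=0: no classification head, so x holds the features that would feed it
 x = layers.Dense(len(TARGET), activation='softmax', dtype='float32')(x)  # TARGET: your list of class labels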