TakeruEndo / kaggle_Cassava

3. VisionTransformer #3

Open TakeruEndo opened 3 years ago

TakeruEndo commented 3 years ago

Usage: the model is available through timm (PyTorch Image Models).

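```python
import timm
import torch
import torch.nn as nn

MODEL_PATH = "path/to/vit_base_patch16_224.pth"  # placeholder: local file with pretrained weights

class ViTBase16(nn.Module):
    def __init__(self, n_classes, pretrained=False):
        super().__init__()
        # Build the architecture only; weights are loaded from a local file below
        self.model = timm.create_model("vit_base_patch16_224", pretrained=False)
        if pretrained:
            self.model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        # Replace the 1000-class ImageNet head with one sized for this task
        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        # x: (batch, 3, 224, 224) -> logits: (batch, n_classes)
        return self.model(x)
```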

notebook: https://www.kaggle.com/abhinand05/vision-transformer-vit-tutorial-baseline