adambielski / siamese-triplet

Siamese and triplet networks with online pair/triplet mining in PyTorch
BSD 3-Clause "New" or "Revised" License
3.11k stars, 633 forks

How to use this siamese model for classification task? #66

Open yuhengShii opened 3 years ago

yuhengShii commented 3 years ago

Hi, if I have trained a siamese model, how can I use it for a classification task?

I changed embedding_net to siamese_net and call its get_embedding to get the feature vectors. Is the following code correct? Looking forward to your reply.

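import torch.nn as nn
import torch.nn.functional as F
from networks import EmbeddingNet  # assumed to be this repo's networks.py

siamese_net = EmbeddingNet()


class ClassificationNet(nn.Module):
    def __init__(self, siamese_net, n_classes):
        super(ClassificationNet, self).__init__()
        self.siamese_net = siamese_net
        self.n_classes = n_classes
        self.nonlinear = nn.PReLU()
        # in_features=2 assumes the embedding is 2-dimensional
        self.fc1 = nn.Linear(2, n_classes)

    def forward(self, x):
        # Feature vectors from the trained network
        output = self.siamese_net.get_embedding(x)
        output = self.nonlinear(output)
        # Log-probabilities over the classes
        scores = F.log_softmax(self.fc1(output), dim=-1)
        return scores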
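
For reference, here is a minimal, self-contained sketch of one possible way to reuse a trained embedding network for classification: freeze it and train only a small head on top of its embeddings. This is not taken from the repository. `TinyEmbeddingNet` is a hypothetical stand-in for the repo's EmbeddingNet, `FrozenEmbeddingClassifier` is an illustrative name, and the input size, embedding size, and random data are assumptions made only so the example runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEmbeddingNet(nn.Module):
    """Hypothetical stand-in for a trained embedding network (not the repo's class)."""

    def __init__(self, in_features=8, embedding_dim=2):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 16), nn.PReLU(), nn.Linear(16, embedding_dim)
        )

    def forward(self, x):
        return self.fc(x)

    def get_embedding(self, x):
        return self.forward(x)


class FrozenEmbeddingClassifier(nn.Module):
    """Linear classification head on top of a frozen embedding network."""

    def __init__(self, embedding_net, embedding_dim, n_classes):
        super().__init__()
        self.embedding_net = embedding_net
        # Freeze the embedding weights so only the head is trained.
        for p in self.embedding_net.parameters():
            p.requires_grad = False
        self.nonlinear = nn.PReLU()
        self.fc1 = nn.Linear(embedding_dim, n_classes)

    def forward(self, x):
        with torch.no_grad():
            emb = self.embedding_net.get_embedding(x)
        return F.log_softmax(self.fc1(self.nonlinear(emb)), dim=-1)


torch.manual_seed(0)
embedding_net = TinyEmbeddingNet()
# In real use, the trained weights would be loaded into embedding_net here.
model = FrozenEmbeddingClassifier(embedding_net, embedding_dim=2, n_classes=3)
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-2
)

# Random data purely for illustration (assumed shapes: 32 samples, 8 features).
x = torch.randn(32, 8)
y = torch.randint(0, 3, (32,))
before = [p.clone() for p in embedding_net.parameters()]

for _ in range(5):
    optimizer.zero_grad()
    loss = F.nll_loss(model(x), y)  # nll_loss pairs with log_softmax outputs
    loss.backward()
    optimizer.step()

scores = model(x)  # log-probabilities, one row per sample
```

Because the forward pass returns log-probabilities, the sketch uses `F.nll_loss`. If the embedding network should be fine-tuned instead, the freezing loop and the `torch.no_grad()` context would have to be removed, and all parameters passed to the optimizer.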