aspamers / siamese

A simple, easy-to-use and flexible siamese neural network implementation for Keras
MIT License

How to train the same network using Euclidean distance? #3

Closed balajib363 closed 4 years ago

balajib363 commented 4 years ago

I am trying to predict the similarity between two images based on their predicted feature vectors. I was trying to replace the final head layer with a distance layer, i.e. distance = Lambda(self.euclidean_distance, output_shape=self.eucl_dist_output_shape)([processed_a, processed_b]). Let me know if you can help. Thanks

aspamers commented 4 years ago

Hi Balaji. Apologies for not responding until now; I had not set up notifications for the repo.

You can train based on Euclidean distance by setting the head model to a static equation using the Keras Lambda layer.

Disclaimer: Did not run this code

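from keras import backend as K
from keras.layers import Activation, BatchNormalization, Dense, Input, Lambda
from keras.models import Model

def euclidean_distance(vects):  # assumed helper; the question references it but does not show it
    a, b = vects
    return K.sqrt(K.maximum(K.sum(K.square(a - b), axis=1, keepdims=True), K.epsilon()))

def eucl_dist_output_shape(shapes):  # assumed helper; one distance value per pair
    return (shapes[0][0], 1)

def create_head_model(embedding_shape):
    embedding_a = Input(shape=embedding_shape)
    embedding_b = Input(shape=embedding_shape)
    head = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([embedding_a, embedding_b])
    head = Dense(4)(head)
    head = BatchNormalization()(head)
    head = Activation(activation='sigmoid')(head)
    head = Dense(1)(head)
    head = BatchNormalization()(head)
    head = Activation(activation='sigmoid')(head)
    return Model([embedding_a, embedding_b], head)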

Note that some extra layers have to be used after the distance function for the network to converge. This is because in this design the rule that turns a distance into a similarity score is not explicitly coded; it is left for the network to learn on its own.

In a more traditional siamese neural network design, a contrastive loss function is used to define that rule explicitly rather than leaving it learnable. That might be what you are after, and it is not something this code directly supports at the moment.
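For reference, here is a minimal, framework-free sketch of the arithmetic behind a contrastive loss. It is not part of this repo. It assumes the common convention that a label of 1 marks a similar pair and 0 a dissimilar pair, and the function names and the margin default of 1.0 are hypothetical choices for illustration. In Keras the same expression would be written with backend tensor ops and applied directly to the distance output.

```python
import math


def euclidean_distance(a, b):
    """Euclidean distance between two equal-length embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def contrastive_loss(labels, distances, margin=1.0):
    """Mean contrastive loss over a batch of pairs.

    labels:    1 for a similar pair, 0 for a dissimilar pair (assumed convention).
    distances: distance between the two embeddings of each pair.
    margin:    hypothetical default; dissimilar pairs at least this far
               apart contribute nothing to the loss.
    """
    total = 0.0
    for y, d in zip(labels, distances):
        similar_term = y * d ** 2                               # pulls similar pairs together
        dissimilar_term = (1 - y) * max(margin - d, 0.0) ** 2   # pushes dissimilar pairs apart
        total += similar_term + dissimilar_term
    return total / len(labels)


# Two toy pairs of 2-D embeddings: one far apart, one identical.
pairs = [([0.0, 0.0], [3.0, 4.0]), ([1.0, 1.0], [1.0, 1.0])]
distances = [euclidean_distance(a, b) for a, b in pairs]
```

With a loss like this the distance itself is the training signal, so the learnable layers after the distance would presumably not be needed; that is an assumption about the design rather than something tested against this code.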

In my experience, however, the siamese network as-is performs quite well when used for similarity measurements between images. That is one of the tasks I use this code for privately.