adambielski / siamese-triplet

Siamese and triplet networks with online pair/triplet mining in PyTorch
BSD 3-Clause "New" or "Revised" License

OnlineTripletLoss #50

Closed: heet2201 closed this issue 4 years ago

heet2201 commented 4 years ago

Hello, I am training on custom data using OnlineTripletLoss, but during the loss calculation I got the following error:

TypeError: forward() takes 3 positional arguments but 4 were given

[Screenshot: Screenshot_2020-04-26_04-16-40]

All the feature functions are shown below:

[Screenshot: Screenshot_2020-04-26_04-32-17]

Can anyone help with this? Thank you.

adambielski commented 4 years ago

You should be using EmbeddingNet with OnlineTripletLoss, not TripletNet. See some examples in the notebooks.

Trotts commented 3 years ago

Hi @adambielski, sorry for reviving a dead issue, but could you explain a bit more why we use EmbeddingNet rather than TripletNet with OnlineTripletLoss? I ran into the same issue as @heet2201: I used a TripletNet and received the same error before going back to the example notebook and realising my mistake. However, looking through the code, I can't find an explanation of why EmbeddingNet is used rather than TripletNet when using OnlineTripletLoss and a triplet selector. Surely we should be using the predefined TripletNet here? Have I missed something obvious?

adambielski commented 3 years ago

@Trotts In my implementation, TripletNet takes a triplet as input (an anchor, a positive and a negative) and returns a triplet of embeddings in the same order. So we need to sample the triplets before feeding them to the network, and then we can compute TripletLoss on those specific triplets. All that TripletNet does is run the EmbeddingNet on the anchor, positive and negative inputs.
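A minimal sketch of the wrapper pattern described above, written in plain Python so it runs without PyTorch. The names `embed`, `TripletWrapper` and `OnlineLossStub` are hypothetical stand-ins, not the repository's classes. The stub only mirrors a two-argument call shape (embeddings, target), which is an assumption inferred from the error message at the top of this thread, and the unpacking of the network outputs into the loss call is likewise assumed. Under those assumptions, it shows how three embeddings passed to such a loss produce the reported TypeError.

```python
def embed(x):
    # Hypothetical stand-in for an embedding network: one input, one output.
    return [2.0 * v for v in x]


class TripletWrapper:
    """Stand-in for a triplet network: applies the same embedding to each input."""

    def __init__(self, embedding_fn):
        self.embedding_fn = embedding_fn

    def forward(self, anchor, positive, negative):
        # Three inputs in, three embeddings out, in the same order.
        return (
            self.embedding_fn(anchor),
            self.embedding_fn(positive),
            self.embedding_fn(negative),
        )


class OnlineLossStub:
    """Stand-in for an online loss whose forward expects (embeddings, target)."""

    def forward(self, embeddings, target):
        return len(embeddings)


outputs = TripletWrapper(embed).forward([1.0], [1.1], [5.0])

try:
    # Assumed call pattern: the tuple of outputs is unpacked into the loss call,
    # so the loss receives self plus three embeddings (four positional arguments).
    OnlineLossStub().forward(*outputs)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The printed message ends with "takes 3 positional arguments but 4 were given": the loss counts `self`, `embeddings` and `target` as its three parameters, while the call supplies `self` plus three embeddings.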

If we want to use OnlineTripletLoss, we do not sample the triplets before feeding them to the network. We simply get the embeddings for a batch of images (that's why we use EmbeddingNet, which takes one input and returns one output), then use their labels (in a triplet selector) to create triplets afterwards and compute the loss on the computed embeddings. This way, one embedding can be used to compute the loss for multiple triplets.
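The online procedure described above can be sketched as follows. This is a simplified illustration in plain Python (lists instead of tensors, no gradients), not the repository's code. The function names `all_triplets` and `online_triplet_loss`, the squared Euclidean distance, the margin value and the "use every valid triplet" selection strategy are all assumptions chosen for the example.

```python
from itertools import combinations


def squared_distance(a, b):
    # Squared Euclidean distance between two embedding vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def all_triplets(labels):
    """Return every (anchor, positive, negative) index triple the labels allow."""
    triplets = []
    for label in set(labels):
        same = [i for i, l in enumerate(labels) if l == label]
        other = [i for i, l in enumerate(labels) if l != label]
        # Each pair with the same label is an (anchor, positive) candidate;
        # every sample with a different label can serve as the negative.
        for anchor, positive in combinations(same, 2):
            for negative in other:
                triplets.append((anchor, positive, negative))
    return triplets


def online_triplet_loss(embeddings, labels, margin=1.0):
    """Hinge-style triplet loss averaged over triplets built from labels."""
    triplets = all_triplets(labels)
    if not triplets:
        return 0.0, 0
    losses = [
        max(
            0.0,
            squared_distance(embeddings[a], embeddings[p])
            - squared_distance(embeddings[a], embeddings[n])
            + margin,
        )
        for a, p, n in triplets
    ]
    return sum(losses) / len(losses), len(triplets)


# Embeddings are computed once per sample (here simply given as lists),
# and triplets are formed afterwards from the labels alone.
embeddings = [[0.0, 0.0], [0.1, 0.0], [3.0, 0.0], [3.1, 0.0]]
labels = [0, 0, 1, 1]
loss, n_triplets = online_triplet_loss(embeddings, labels)
print(loss, n_triplets)
```

In this toy batch, four embeddings yield four triplets, and each embedding appears in three of them, which is the reuse described above. A real selector might keep only hard or semi-hard negatives rather than every valid triplet.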

Trotts commented 3 years ago

@adambielski thank you for the swift response! That makes perfect sense: I forgot that images have to be embedded before triplets can be selected, in which case using the EmbeddingNet is obvious!