Davidc2525 opened this issue 5 months ago
Hey,
Given a list of triplets, you will need to define a dataset for them. You can refer to these two examples to get an idea of how to define custom datasets. Once that is done, you can set up a DataLoader for this dataset and use it in your training loop. You can use torch.nn.TripletMarginLoss as the loss function.
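A minimal sketch of such a dataset, assuming each triplet is an (anchor, positive, negative) tuple. `TripletDataset` and its `loader`/`transform` parameters are hypothetical names introduced here for illustration; in practice `loader` would open an image file (e.g. with PIL) and `transform` would turn it into a tensor. Random tensors stand in for real face crops so the snippet runs on its own.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TripletDataset(Dataset):
    """Hypothetical dataset wrapping a list of (anchor, positive, negative) triplets.

    Each element of a triplet can be anything `loader` knows how to turn into an
    image (e.g. a file path); by default elements are assumed to already be tensors.
    """

    def __init__(self, triplets, loader=None, transform=None):
        self.triplets = list(triplets)
        self.loader = loader if loader is not None else (lambda x: x)
        self.transform = transform

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        images = []
        for item in self.triplets[idx]:
            img = self.loader(item)
            if self.transform is not None:
                img = self.transform(img)
            images.append(img)
        # Returning a tuple lets the default collate_fn stack each position
        # (anchor / positive / negative) into its own batch tensor.
        return tuple(images)


# Toy data: 8 random 3x160x160 tensors standing in for aligned face crops.
torch.manual_seed(0)
imgs = [torch.rand(3, 160, 160) for _ in range(8)]
triplets = [(imgs[i], imgs[(i + 1) % 8], imgs[(i + 4) % 8]) for i in range(8)]

dataset = TripletDataset(triplets)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
loss_fn = torch.nn.TripletMarginLoss(margin=1.0)
```

Each batch from `loader` unpacks as `anchor, positive, negative = batch`, with each tensor shaped `(batch_size, 3, 160, 160)`.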
However, if you are using training.pass_epoch, note that it expects to unpack a batch of single images at a time (see this). You will have to modify it to unpack a batch of tuples of three images (anchor, positive, negative) and evaluate each image on the model individually before passing the resulting embeddings to the loss function.
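One way to do that, sketched as a standalone function rather than a patch to `training.pass_epoch` (whose exact code isn't reproduced here). `train_triplet_epoch` is a hypothetical name; the tiny `nn.Sequential` model and the `TensorDataset` of random tensors are placeholders for a real face-embedding network and the triplet dataset.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_triplet_epoch(model, loader, loss_fn, optimizer, device="cpu"):
    """Hypothetical training epoch where each batch is an (anchor, positive, negative) triple."""
    model.train()
    total_loss, n_batches = 0.0, 0
    for anchor, positive, negative in loader:
        anchor, positive, negative = (x.to(device) for x in (anchor, positive, negative))
        # Evaluate each image of the triplet on the model separately to get embeddings.
        emb_a = model(anchor)
        emb_p = model(positive)
        emb_n = model(negative)
        loss = loss_fn(emb_a, emb_p, emb_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1
    # Average loss over the batches seen in this epoch.
    return total_loss / max(n_batches, 1)


# Placeholder data and model: random 3x32x32 "images" and a linear embedding.
torch.manual_seed(0)
n, shape = 16, (3, 32, 32)
data = TensorDataset(torch.rand(n, *shape), torch.rand(n, *shape), torch.rand(n, *shape))
loader = DataLoader(data, batch_size=4, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))
loss_fn = nn.TripletMarginLoss(margin=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

avg_loss = train_triplet_epoch(model, loader, loss_fn, optimizer)
```

Running the model three times mirrors the description above. An alternative is to `torch.cat` the three tensors into one batch, do a single forward pass, and split the output with `torch.chunk`; this matters if the model uses batch normalization, since separate passes compute batch statistics over each third of the data independently.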
Hope that helps!
Any updates?
Can you post an example of how to fine-tune with triplet loss?