pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Triplet Image Dataset #1042

Closed dakshjotwani closed 5 years ago

dakshjotwani commented 5 years ago

Hi,

I am currently working on semi-supervised learning for face recognition. I needed a dataset that loads triplets for training: three images, where the first two belong to the same class and the third belongs to a different class (the FaceNet paper has more details). I have implemented this dataset for my own use, and was wondering whether it would be a nice feature to have as part of the library itself.

Thanks,

Daksh

fmassa commented 5 years ago

Hi,

I believe triplet (or n-uple) datasets could potentially be handled generically, with the pairwise/triplet/etc. combinations generated on the fly, making the dataset n ** 3 in size (with some potential pruning of the undesired cases).

Thoughts?
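
One hedged reading of "generated on the fly", as a minimal sketch: decode each index arithmetically into a triple of base indices, so the n ** 3 dataset is never materialized and dataset[idx] stays deterministic. The class name and the absence of pruning here are assumptions, not an existing torchvision API:

from torch.utils.data import Dataset

class NTupleDataset(Dataset):
    # Hypothetical sketch: index arithmetic makes the dataset n ** 3 in size
    # without materializing any triplet list, so dataset[idx] is always the
    # same triplet across calls.
    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base) ** 3

    def __getitem__(self, idx):
        n = len(self.base)
        i, j, k = idx // (n * n), (idx // n) % n, idx % n
        return self.base[i], self.base[j], self.base[k]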

dakshjotwani commented 5 years ago

That sounds like a good idea!

We will, however, need to come up with some sort of strategy to select valid k-tuples. The triplet I described has a specific format: the first two entries belong to the same class, while the last entry belongs to a different class. One of the ways we can solve this problem is by allowing the user to pass in a filter function as an argument, which we can then use to select valid k-tuples. How does that sound?

The other problem is that we will have to go through n ** k potential tuples from the dataset to construct the k-tuple dataset, which is quite large and could take forever to run. To avoid this, my current approach to generate triplets is to pick them randomly from the dataset:

import numpy as np

# class names, and a map from each class name to that class's image paths
classes = ['class_1', 'class_2', ..., 'class_m']
class_imgs = {'class_1': ['/path/1', '/path/2', ...], ...}
num_triplets = 10000

samples = []
for _ in range(num_triplets):
    # anchor/positive come from c_1; the negative comes from a different class c_2
    c_1, c_2 = np.random.choice(classes, size=2, replace=False)
    img_1, img_2 = np.random.choice(class_imgs[c_1], size=2, replace=False)
    img_3 = np.random.choice(class_imgs[c_2])
    samples.append((img_1, img_2, img_3))

This approach solved my problem, but may not be suitable for the generic case that you proposed. What do you think?

fmassa commented 5 years ago

@dakshjotwani I'd probably want the dataset triplets to be deterministic, at least for torchvision.

So randomly sampling the triplets for a given idx is not something I was expecting to do, because it breaks the property that dataset[idx] always gives you the same thing.

fmassa commented 5 years ago

Another option would be to make it an IterableDataset, instead of a standard dataset. In this case, I would be ok having the __iter__ implement what you proposed.
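
A minimal sketch of that IterableDataset route, assuming a class-to-paths map like the one above (the class name is hypothetical):

import random
from torch.utils.data import IterableDataset

class RandomTripletDataset(IterableDataset):
    # Hypothetical sketch: each pass yields freshly sampled
    # (anchor, positive, negative) paths, so there is no deterministic
    # dataset[idx] to preserve.
    def __init__(self, class_imgs, num_samples):
        self.class_imgs = class_imgs        # dict: class name -> image paths
        self.classes = list(class_imgs)
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            c_pos, c_neg = random.sample(self.classes, 2)
            anchor, positive = random.sample(self.class_imgs[c_pos], 2)
            negative = random.choice(self.class_imgs[c_neg])
            yield anchor, positive, negative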

dakshjotwani commented 5 years ago

@fmassa In my current implementation, dataset[idx] == samples[idx], so it will be deterministic for each object instance. However, two instances of the triplet dataset acting on the same directory will most likely not be equal at the same index (dataset1[idx] != dataset2[idx]). Would this be fine for a VisionDataset or would this still be an IterableDataset?

I'm not sure if it's practical to go through all possible k-tuples, since that would require going through n ** k potential tuples. Maybe I do not see what you mean by 'generated on the fly'?

fmassa commented 5 years ago

@dakshjotwani that's what I mentioned wrt it not being deterministic: two different runs will have different values, and are indeed different datasets.

About the second point: if we pass a groups list to the constructor of TripletDataset(dataset, groups), where groups marks each element as True or False and every triplet is sampled so that its pattern is (True, True, False), then all we need to do in the constructor is go over that list. This is still n ** k, but depending on the structure of groups, it might be possible to do it with an equation instead of generating the triplet list in memory.
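
As a hedged example of the equation route, assuming per-class image counts are known (the helper is hypothetical):

def num_triplets(class_counts):
    # Hypothetical helper: closed-form count of (anchor, positive, negative)
    # triplets, instead of materializing the list; class_counts maps
    # class -> number of images in that class.
    total = sum(class_counts.values())
    return sum(n * (n - 1) * (total - n) for n in class_counts.values())

# e.g. num_triplets({'a': 3, 'b': 2}) == 3*2*2 + 2*1*3 == 18 ordered triplets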

But this might just be a better fit for the IterableDataset as I mentioned, and it would definitely simplify a few things in the implementation

dakshjotwani commented 5 years ago

Alright, I agree that this dataset should be an IterableDataset, which would also make the implementation much easier. I have some final questions before I get started:

  1. Should I start with TripletDataset or implement KTupleDataset?
  2. If KTupleDataset, what kind of args should I ask for to specify the format? Currently I was thinking of something like KTupleDataset(path, format), where format is a dictionary that looks like this (see the sketch after this list):

    format = {
        'class_0': {
            'num_samples': 3,
            'replacement': False,
        },
        'class_1': {
            'num_samples': 2,
            'replacement': True,
        },
        ...
    }

    The __iter__ method of the KTupleDataset above will return tuples of the form (c0, c0, c0, c1, c1, ...).

  3. Since it will be an IterableDataset, I'm assuming that num_samples will not have to be specified in the constructor? The __iter__ method would just return the next randomly generated triplet from the dataset directory.
  4. Currently there is no PyTorch equivalent of np.random.choice. Would it be okay to use the NumPy method in my implementation?
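
For concreteness, a hedged sketch of how one draw under the proposed format could work; sample_ktuple and the slot-group interpretation of format are assumptions:

import random

def sample_ktuple(fmt, class_imgs):
    # Hypothetical sketch: each entry in fmt is a slot group; pick a distinct
    # random class per group, then draw num_samples images from it, with or
    # without replacement.
    chosen = random.sample(list(class_imgs), len(fmt))
    out = []
    for cls, spec in zip(chosen, fmt.values()):
        paths = class_imgs[cls]
        if spec['replacement']:
            out += random.choices(paths, k=spec['num_samples'])
        else:
            out += random.sample(paths, spec['num_samples'])
    return tuple(out)
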
fmassa commented 5 years ago

@dakshjotwani I think that it might make sense to start with TripletDataset. I'm not clear on what the semantics of a more general dataset should be, while the TripletDataset is more standardized, I'd say.

3 - I think it might be good to specify a num_samples in the constructor, so that we know when to raise a StopIteration in __iter__.

4 - There is a PR in https://github.com/pytorch/pytorch/pull/18624 that adds support for choice. But I'd say that you should just start with using torch.multinomial for now, and do the indexing yourself.
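
For example, a hedged sketch of such a stand-in; torch_choice is a hypothetical helper, not an existing torch function:

import torch

def torch_choice(items, k, replacement=False):
    # Hypothetical np.random.choice stand-in built on torch.multinomial:
    # sample k indices (uniform weights), then do the indexing ourselves.
    # Without replacement, k must not exceed len(items).
    weights = torch.ones(len(items))
    idx = torch.multinomial(weights, k, replacement=replacement)
    return [items[i] for i in idx.tolist()]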

Also, ccing @SsnL , as he was the one who originally designed IterableDataset, and might have some more insights on what would be good to have.

dakshjotwani commented 5 years ago

Sounds good. I look forward to what @SsnL has to say about our TripletDataset outline.

ssnl commented 5 years ago

This is a bit of an interesting use of IterableDataset. This would make the dataset unusable with any sampler, but it probably won't be too big of an issue. There is a lot of subtlety with using an IterableDataset with num_workers > 0, but if you only want to sample with replacement, then only taking care of the num_samples should be enough. That said, if this is going to be put in a general purpose library, e.g., torchvision, please be careful to make it work correctly. Feel free to tag me on the PR for review.

rwightman commented 5 years ago

Use cases for generic tuple datasets likely exist, but re the motivation for the original task: I've never used an n-tuple dataset (n ** k is not pleasant) with a siamese or triplet loss fn. I've always used hard or semi-hard mining and a more typical dataset that returns one element at a time, together with a custom sampler to make sure each batch has the desired examples per identity.

I think most methods for identity recognition, face or otherwise, have moved since the FaceNet era to online hard or semi-hard example mining for the triplets and pairs. Or they don't use them at all, if they're using one of the softmax + geometric losses like ArcFace, CosFace, ring loss, etc.

Edit: This is a great blog post on the topic: https://omoindrot.github.io/triplet-loss ... also, I forgot my paper history: FaceNet was actually the paper that introduced online selection and recommended not building a dataset that does offline triplet selection.

fmassa commented 5 years ago

@rwightman yes, I agree that hard-negative mining is something that we should add support for.

But performing hard negative mining on a large dataset on-the-fly is very expensive computation-wise. So what is generally done is that the features are computed once every few epochs, and the hard-mining is done on those features.

This means that we can re-create the dataset once every few epochs or so, by updating the similarity matrix that is provided to it.

This would still kind of match with the approach in https://github.com/pytorch/vision/pull/1061 (but some changes would need to be done).

rwightman commented 5 years ago

@fmassa basing hard negative mining on features calculated every few epochs is still 'offline'. By online I mean you select the pairings that you sum into your loss after you run the examples through the network on each batch and have a set of embeddings; it's less expensive than offline hard example mining. This is often called 'batch hard', or 'batch semi-hard' if you don't just pick the hardest (which can sometimes collapse if you have too many really hard examples).

This paper compares all the methods: https://arxiv.org/abs/1703.07737 ... it compares offline triplet selection, triplet with OHM (the offline hard mining you just mentioned), batch-hard, and batch-all.

At some point in training you end up with no or very few hard examples if you're using a hard margin (not so with a soft margin), so the number of examples in each batch that you sum into the loss decreases, but that's often an indicator of when to stop training.

I wouldn't recommend trying to fit the batch-hard/batch-all cases into the other n-tuple scenarios. The dataset for online batch schemes remains pretty normal, like a typical classification setup; only the sampler, which builds batches with K examples of each of P classes, and the loss function, which computes pairwise distances and builds the tuples within each batch, differ.
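
As a hedged illustration of that sampler, a minimal sketch; the constructor and internals here are assumptions, not an existing torchvision API (the DataLoader batch_size would be set to p * k):

import random
from torch.utils.data import Sampler

class PKSampler(Sampler):
    # Hypothetical sketch: yields indices so that each batch holds K examples
    # from each of P randomly chosen classes (batch size = P * K).
    def __init__(self, labels, p, k, num_batches):
        self.by_class = {}
        for idx, label in enumerate(labels):
            self.by_class.setdefault(label, []).append(idx)
        self.p, self.k, self.num_batches = p, k, num_batches

    def __iter__(self):
        # only classes with at least K examples can fill their slots
        eligible = [c for c, idxs in self.by_class.items() if len(idxs) >= self.k]
        for _ in range(self.num_batches):
            for c in random.sample(eligible, self.p):
                yield from random.sample(self.by_class[c], self.k)

    def __len__(self):
        return self.num_batches * self.p * self.k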

TensorFlow has an example loss fn: https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py#L160

dakshjotwani commented 5 years ago

Wouldn't this be something that would be dealt with within the training loop, after the forward pass has been taken care of? Maybe we could have a method that selects only the hard/semi-hard triplets to pass through the loss function?

rwightman commented 5 years ago

@dakshjotwani yeah, you need a non-trivial loss fn like that TF one I linked above. But note that it only takes a list of embeddings and a list of labels; it does not build any triplets.

I went through this exercise myself some time ago: I was about to set out building a triplet dataset, then realized one wasn't needed, and spent the time working on the loss fn/mining strategies and sampler instead.

So that's why I'm bringing this up: do you actually want a triplet dataset, or do you want a triplet training scheme that works?

dakshjotwani commented 5 years ago

Interesting. I want to make sure I'm understanding this correctly, so I'm listing some potential contributions we can make:

  1. Instead of building a dataset that returns triplets, we should build valid triplets after going through a forward pass.
  2. To make sure we get a sufficient number of valid triplets in each batch, every batch will need a balanced number of identities/groups. For example, when training a face recognition model with batch size b and d identities per batch, there must be at least floor(b / d) images for each identity. [Not sure if PyTorch has such a sampler. I think my code from TripletDataset #1061 can be modified to implement this.]
  3. Finally, once we go through the forward pass, we generate all possible valid triplets from that batch, and then employ a suitable mining strategy to speed up training. [This could be a good addition to PyTorch and/or torchvision, where we provide multiple triplet mining strategies; see the sketch after this list. I can work on it.]
  4. Pass the mined triplets through the loss function. [PyTorch already has TripletMarginLoss, which should do the job.]
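
To make step 3 concrete, a minimal sketch of "batch hard" selection folded into a loss helper, assuming L2 distances; the function name is mine, and this is just one mining strategy among several (see the paper linked above):

import torch

def batch_hard_triplet_loss(embeddings, labels, margin=0.2):
    # Hypothetical sketch of batch-hard mining: for each anchor, pick its
    # hardest (farthest) positive and hardest (closest) negative within
    # the batch, then apply the margin.
    dist = torch.cdist(embeddings, embeddings)           # (b, b) pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)    # True where labels match
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    # farthest same-class example (excluding the anchor itself)
    hardest_pos = (dist * (same & ~eye)).max(dim=1).values
    # closest other-class example: mask out same-class entries with +inf
    hardest_neg = dist.masked_fill(same, float('inf')).min(dim=1).values

    return torch.relu(hardest_pos - hardest_neg + margin).mean()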

@rwightman Is this the approach you want to take? Please correct me if I made any errors. @fmassa What do you think? Do you think this is a better approach to the problem? If we decide to change our approach, which modules would I have to work with? (Since it wouldn't be a dataset anymore)

In the meantime, I shall read more on this and get back with a gist as a proof of concept by tomorrow. It will probably help me understand the online mining approach better.

rwightman commented 5 years ago

@dakshjotwani Here is a gist of some bits and pieces I've used successfully in the past. It's certainly not at a level suitable for inclusion in torchvision, but feel free to use it for ideas: https://gist.github.com/rwightman/fff86a015efddcba8b3c8008167ea705

The sampler there assumes the dataset has a map with a specific name in it; there would be better ways of making this interface more generic. Batch size needs to be set to p * k. In the loss I included a pretty standard hard-negative selection and my own hacky sampling selection. I wanted to revisit my sampling approach someday but haven't had time. The idea was to sample hard negatives (and positives) more frequently than the easy ones, but not to always choose the hard ones, as that can cause collapse on datasets with lots of very hard examples or noise in the labels. I'm sure someone smarter than I am has come up with a better approach, but my hack did work well on a dataset where the hard-negative approach was causing the loss to collapse to the margin.

Overall approach: I'd probably spend more time with all of these ideas, implement several different methods, and see what works and what doesn't on a real dataset before committing to one and finishing a PR.

dakshjotwani commented 5 years ago

Thanks for the gist! It'll surely help me understand this approach better.

The TripletDataset approach (loading data as triplets and discarding easy triplets after a forward pass) is what I used to train ResNet50 for face recognition. I used VGGFace2 as my dataset and was getting 91% accuracy on my validation set (VGGFace2 validation, which has different identities), which is why I felt it might be a good addition to torchvision.

I shall work with your approach on the same dataset and margin (margin=0.2) and get back to you with results.

fmassa commented 5 years ago

@dakshjotwani I've asked around, and it seems that the online within-batch approach for triplet loss might be better. So I'll hold off on merging the TripletDataset for now, and wait until you get back with the results from your experimentation.

dakshjotwani commented 5 years ago

@rwightman This approach did work better than my previous one. On VGGFace2's test set, TripletDataset + semi-hard mining resulted in 90.42% accuracy, while the PKSampler + semi-hard mining approach resulted in 93.05% accuracy (I'm still training it to see if it can do better). This result is definitely outside the margin of error, and it also converged much faster than my previous approach.

I attempted to refine your PKSampler gist a bit and rewrote https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/model/triplet_loss.py in PyTorch for my experimentation. The sampler and loss code I used for my experiment is here: https://github.com/dakshjotwani/pytorch-triplet-tools.

@fmassa I do believe, however, that there is still a need for some sort of TupleDataset. For testing/validation of my model, I iterated over all possible tuples (s0, s1) in the test set and labelled each tuple 0 when label[s0] != label[s1] and 1 when label[s0] == label[s1]. Since I used this on my test set, n was not too large, and hence n ** 2 was manageable.

I think it would be quite nice to have some of these features (online triplet mining, PKSampler, TupleDataset) to make metric learning more convenient in PyTorch/torchvision. What do you think?

fmassa commented 5 years ago

@dakshjotwani very interesting results! Thanks @rwightman for the heads up!

do believe, however, that there is still a need for some sort of TupleDataset

The dataset that you mention can be built on-the-fly, without materializing anything, right? Just compute on-the-fly the label for the tuple, and compute the index of s0 and s1 via idx // len(ds_orig) and idx % len(ds_orig).
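
A minimal sketch of that on-the-fly construction, assuming a map-style base dataset and a parallel labels list (the class name is hypothetical):

from torch.utils.data import Dataset

class PairDataset(Dataset):
    # Hypothetical sketch of the pair dataset described above: the pair
    # index decomposes as (idx // n, idx % n), so nothing is materialized.
    def __init__(self, base, labels):
        self.base, self.labels = base, labels

    def __len__(self):
        return len(self.base) ** 2

    def __getitem__(self, idx):
        n = len(self.base)
        i, j = idx // n, idx % n
        target = int(self.labels[i] == self.labels[j])  # 1: same class, 0: different
        return self.base[i], self.base[j], target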

I think it would be quite nice to have some of these features (online triplet mining, PKSampler, TupleDataset) to make metric-learning more convenient using Pytorch/torchvision

I agree, and I think I'd start with the triplet mining and PKSampler. I think this deserves a whole new folder in references/ in torchvision, with a full training / evaluation pipeline, which should still be fairly small.

What about the following:

Could you send a PR with a simple training / evaluation script that trains on VGGFace? This would mean, among other things:

  • add a VGGFace dataset (with tests)

Once those are there and merged, we will need to find a better place for adding those losses and samplers in the core torchvision.

Thoughts?

dakshjotwani commented 5 years ago

The dataset that you mention can be built on-the-fly, without materializing anything, right? Just compute on-the-fly the label for the tuple, and compute the index of s0 and s1 via idx // len(ds_orig) and idx % len(ds_orig).

@fmassa Yes it can. Currently, instead of doing that, I'm using itertools.combinations(range(n), k) to avoid swapped tuples ((a, b) and (b, a) are the same) and tuples with repeated entries ((a, a) for example).

Edit: I have since made TupleDataset an IterableDataset that produces all tuple combinations and also works with multiple workers (it no longer relies on itertools).
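
A hedged sketch of what such a worker-aware IterableDataset could look like; the arithmetic pair decomposition and round-robin sharding are my assumptions about the implementation, not the actual code:

from torch.utils.data import IterableDataset, get_worker_info

class TupleDataset(IterableDataset):
    # Hypothetical sketch: enumerate every unordered pair (i, j), i < j,
    # from a flat counter (no itertools), sharding the stream round-robin
    # across DataLoader workers via get_worker_info().
    def __init__(self, base, labels):
        self.base, self.labels = base, labels

    def __iter__(self):
        info = get_worker_info()
        wid, nworkers = (info.id, info.num_workers) if info else (0, 1)
        n = len(self.base)
        for idx in range(wid, n * n, nworkers):
            i, j = idx // n, idx % n
            if i >= j:                 # keep each unordered pair exactly once
                continue
            target = int(self.labels[i] == self.labels[j])
            yield self.base[i], self.base[j], target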

  • add a VGGFace dataset (with tests)

I could, but this would be quite complicated. Unlike the other datasets, the VGGFace2 download cannot be automated (with wget, for example), since you need to create an account and sign an EULA to download the dataset. It's also really large (36 GB).

Do you think we can work with something simpler, like MNIST for this PR? The blog post which @rwightman shared uses MNIST to demonstrate metric learning using triplet loss.

I feel like the VGGFace2 dataset deserves a separate PR, which would also include data preprocessing (cropping/aligning faces) and cleaning of the dataset. I can take it up right after finishing the triplet mining strategy and PKSampler PR.

Here's what I think the PR should include:

Does this sound good?

rwightman commented 5 years ago

@dakshjotwani awesome, glad you got it working

I've seen fixed triplets used for validation, as you said. I think at least one of the canonical datasets has validation triplets predefined by some common impl, but I forget which.

Another common validation technique you'll see when the task is retrieval or re-id is the calculation of mAP and CMC metrics on a separate or holdout set of index + query images. Two PyTorch-based repos with this setup are https://github.com/Cysu/open-reid/tree/master/reid and https://github.com/layumi/Person_reID_baseline_pytorch

For your training, if you're trying to squeeze out a few more %: triplet loss is often sensitive to hyper-parameters, so trying different optimizers and LRs can have a big impact. The size of your embedding can make a difference. Also try with and without L2-normalizing your embeddings before feeding them into the loss, and with/without clamping (relu) the output to positive. Try soft margin on and off, or different margins if the hard margin is working better.

dakshjotwani commented 5 years ago

I've seen fixed triplets used for validation, as you said. I think at least one of the canonical datasets has validation triplets predefined by some common impl, but I forget which.

Another common validation technique you'll see when the task is retrieval or re-id is the calculation of mAP and CMC metrics on a separate or holdout set of index + query images. Two PyTorch-based repos with this setup are https://github.com/Cysu/open-reid/tree/master/reid and https://github.com/layumi/Person_reID_baseline_pytorch

Interesting. I'm actually working on another problem which is quite similar to the problems solved in the repos you mentioned. Thanks for the links! I'll delve deeper into them right away.

Currently my validation approach is to compute the distances between all tuples (k=2) in the test dataset and find the distance threshold that classifies the tuples (same or different face) with the highest accuracy. This is the approach described in the FaceNet paper. I hope that's an appropriate method to validate the model for face verification for now?
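
A minimal sketch of that threshold search, assuming 1-D tensors of pair distances and 0/1 targets (the helper name is hypothetical):

import torch

def best_threshold(distances, targets):
    # Hypothetical sketch of the FaceNet-style verification metric: sweep
    # candidate distance thresholds and keep the one with the highest
    # same/different pair accuracy.
    lo, hi = distances.min().item(), distances.max().item()
    best_acc, best_t = 0.0, lo
    for t in torch.linspace(lo, hi, steps=100):
        pred = (distances < t).long()          # 1 predicts "same identity"
        acc = (pred == targets).float().mean().item()
        if acc > best_acc:
            best_acc, best_t = acc, t.item()
    return best_t, best_acc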

For your training, if you're trying to squeeze out a few more %: triplet loss is often sensitive to hyper-parameters, so trying different optimizers and LRs can have a big impact. The size of your embedding can make a difference. Also try with and without L2-normalizing your embeddings before feeding them into the loss, and with/without clamping (relu) the output to positive. Try soft margin on and off, or different margins if the hard margin is working better.

I'll keep that in mind and keep you posted! Thanks!

dakshjotwani commented 5 years ago

Closing this for now. We can reopen this if we want to implement TripletDataset in the future for validation purposes.