rpautrat / SuperPoint

Efficient neural feature detector and descriptor
MIT License

Descriptor loss train #287

Open ericzzj1989 opened 1 year ago

ericzzj1989 commented 1 year ago

Hi, first of all, thank you for this great work, which has really helped me a lot! I want to train the SuperPoint model on my own data. The detection loss looks normal, but the descriptor loss oscillates and does not converge. The inputs are 256x256 images together with their homography-warped counterparts, and the model and loss function are the same as in your repo. The detector and descriptor losses over 300 epochs are shown below.

[Detector loss plot]

[Descriptor loss plot]

Can you give me some advice?

rpautrat commented 1 year ago

Hi, are you using exactly the code from this repo or did you plug in some parts of this repo into your own code? Are you using the same parameters as in the master branch?

ericzzj1989 commented 1 year ago

Thanks for your prompt reply. I use the paired data generation from your repo (exactly as in coco.py) together with the SuperPoint model and loss function parts, reimplemented in PyTorch. The data augmentation and model parameters in the config file are the same as yours.

rpautrat commented 1 year ago

I see, then it might be a bit tricky for me to help you, as it is a different codebase... It could be an implementation bug, or simply that your reimplementation needs different parameter tuning than this repo.

Note that there is also a PyTorch reimplementation of SuperPoint (partially based on this repo) that you might want to check out: https://github.com/eric-yyjau/pytorch-superpoint

ericzzj1989 commented 1 year ago

Thanks for your advice. I looked at the repo https://github.com/shaofengzeng/SuperPoint-Pytorch (based on yours) and compared it with your code; as far as I can tell it is the same code and shows the same issue. I also found that many issues mention this descriptor loss problem. I played with the parameters and with RGB input, but the descriptor loss kept oscillating and would not converge. I have no idea how to solve this issue.

rpautrat commented 1 year ago

Tuning the descriptor loss was quite tricky in my case, and training with a triplet loss is rather tricky in general.

One thing that usually helps in my experience of training with triplet losses is to pre-train the network with a "relaxed" definition of the negative samples. In SuperPoint, given one cell at position (h, w), the corresponding cell (h', w') in the other image is used as the positive anchor, while all the other cells are considered negatives. But the descriptor of, say, cell (h'+1, w') is also very close to that of (h, w), so forcing these two descriptors to be far apart confuses the network (at least at the beginning of training). So what you could do is ignore the neighboring cells of each positive cell in the descriptor loss, which is equivalent to making the negative samples less hard. Once the training has converged with this easier loss, you can fine-tune with the actual SuperPoint loss to get the best performance.
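To make the idea concrete, here is a rough sketch in PyTorch (not code from this repo; the tensor names, shapes, and default values are my own assumptions) of a SuperPoint-style hinge descriptor loss where cells lying within an ignore_radius of a positive anchor are simply dropped from the negative term:

    import torch

    def relaxed_descriptor_loss(desc1, desc2, warped_centers1, centers2,
                                pos_margin=1.0, neg_margin=0.2,
                                lambda_d=250.0, ignore_radius=16.0):
        # desc1, desc2:    (B, N, D) L2-normalized cell descriptors (N = Hc * Wc)
        # warped_centers1: (B, N, 2) cell centers of image 1 mapped by the homography
        # centers2:        (B, N, 2) cell centers of image 2
        dist = torch.cdist(warped_centers1, centers2)               # (B, N, N) pixel distances
        s = (dist <= 8.0).float()                                   # positive correspondences (8 px, as in the paper)
        near = ((dist <= ignore_radius).float() - s).clamp(min=0)   # too close to be fair negatives

        dot = torch.einsum('bid,bjd->bij', desc1, desc2)            # descriptor similarities
        pos = s * torch.clamp(pos_margin - dot, min=0.0)
        neg = (1.0 - s) * torch.clamp(dot - neg_margin, min=0.0)

        # Zero out the confusing near-positive negatives. A real implementation
        # should also mask cells whose warped center falls outside the second image.
        loss = (1.0 - near) * (lambda_d * pos + neg)
        return loss.mean()

Setting ignore_radius back to the cell size (so that near is always zero) recovers the usual formulation, which is how you could switch to the standard SuperPoint loss for fine-tuning.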

ericzzj1989 commented 1 year ago

Thanks for sharing your experience. Is there any reference code for the trick you mentioned? If so, could you please share it with me?

rpautrat commented 1 year ago

For SuperPoint, unfortunately not. I did not have to use this trick when I trained it.

But I had to use it for other works requiring a triplet loss; one example is here: https://github.com/mihaidusmanu/d2-net/blob/master/lib/loss.py. The loss is a bit different from the SuperPoint one though: it is a triplet loss with hardest-negative mining. But maybe you can get the idea and apply it to the SuperPoint loss. The safe_radius parameter is the one controlling how close to the positive anchor a negative can be. You should thus start training with a large safe_radius and ideally fine-tune with a lower one.
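Not the actual D2-Net code, but a small sketch of the safe_radius idea (all names and shapes below are my own assumptions): negatives are mined as the hardest candidate outside a safety zone around the positive anchor, and that zone can be kept large at the beginning of training and shrunk for fine-tuning.

    import torch

    def hardest_negative_similarity(dot, dist, s, safe_radius):
        # dot:  (B, N, N) descriptor similarities between the two images
        # dist: (B, N, N) pixel distances between warped cell centers and cell centers
        # s:    (B, N, N) binary matrix of positive correspondences
        # A cell may only serve as a negative if it is not the positive anchor
        # and lies farther than safe_radius pixels from it.
        forbidden = (s > 0) | (dist < safe_radius)
        neg = dot.masked_fill(forbidden, float('-inf'))
        return neg.max(dim=2).values    # (B, N): hardest allowed negative per anchor

    # Staged schedule as suggested above (pseudo-usage):
    # use a large safe_radius (e.g. 32 px) during pre-training, then a smaller one
    # (e.g. 8 px) for fine-tuning, with a triplet-style loss such as
    #     clamp(margin + hardest_neg - positive_similarity, min=0)
    # averaged over the anchors that actually have a positive correspondence.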

ericzzj1989 commented 1 year ago

Thanks very much for your help. Following https://github.com/rpautrat/SuperPoint/issues/164, I commented out the following lines: https://github.com/rpautrat/SuperPoint/blob/1742343e7a9731929be00f7a940023b694098d6c/superpoint/models/utils.py#L110

    dot_product_desc = tf.nn.relu(dot_product_desc)
    dot_product_desc = tf.reshape(tf.nn.l2_normalize(
        tf.reshape(dot_product_desc, [batch_size, Hc, Wc, Hc * Wc]),
        3), [batch_size, Hc, Wc, Hc, Wc])
    dot_product_desc = tf.reshape(tf.nn.l2_normalize(
        tf.reshape(dot_product_desc, [batch_size, Hc * Wc, Hc, Wc]),
        1), [batch_size, Hc, Wc, Hc, Wc])

The training and validation descriptor losses with a small amount of data (100 samples) are shown below.

[Training descriptor loss plot]

[Validation descriptor loss plot]

These lines (the ReLU and the two L2 normalizations) normalize dot_product_desc; is this operation necessary after taking the dot product of the descriptors? I'm not sure whether commenting them out is correct, and maybe the amount of data (100 samples) is simply too small? Could you please give me some advice?

rpautrat commented 1 year ago

I am not sure I fully understand your last question, but these lines with the L2 normalization are a trick to make the correspondences between points more discriminative (i.e. so that there is at most one correspondence rather than several similar candidates). The original SuperPoint did not have this trick, and the code should also work if you comment it out, but I personally observed empirically better results with it.
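In PyTorch, those lines would correspond roughly to the following (an untested sketch; the (batch, Hc, Wc, Hc, Wc) shape of the similarity volume is assumed from the code above):

    import torch.nn.functional as F

    def normalize_dot_product(dot_product_desc, Hc, Wc):
        # dot_product_desc: (B, Hc, Wc, Hc, Wc) raw similarities between all cell pairs
        b = dot_product_desc.shape[0]
        d = F.relu(dot_product_desc)
        # L2-normalize over the cells of the second image for each cell of the first...
        d = F.normalize(d.reshape(b, Hc, Wc, Hc * Wc), p=2, dim=3).reshape(b, Hc, Wc, Hc, Wc)
        # ...and over the cells of the first image for each cell of the second.
        d = F.normalize(d.reshape(b, Hc * Wc, Hc, Wc), p=2, dim=1).reshape(b, Hc, Wc, Hc, Wc)
        return d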

In the graphs you show, there is clear overfitting to the small training set, due to the small number of samples.

ericzzj1989 commented 1 year ago

Thanks very much for your help. With this L2 normalization trick, the descriptor loss could not converge, as shown in the graphs at the top of this issue. Regarding the overfitting, I will try to add more training data and observe the loss.