rpautrat / SuperPoint

Efficient neural feature detector and descriptor
MIT License

Bugs in Hpatches evaluation framework #117

Closed: Limzui closed this issue 4 years ago

Limzui commented 4 years ago

Hi, firstly I would like to thank you for your extensive work on this repository.

I was training and evaluating my own SuperPoint using this repo as a baseline, and I noticed that the descriptor evaluation scores on HPatches viewpoint changes are extremely low. This is true for your own results as well (0.244 as per your README). I myself was getting homography correctness scores in the range of 0.02 to 0.2 for a distance threshold of 3. There is also a currently open issue on this: https://github.com/rpautrat/SuperPoint/issues/80. I dug a little deeper into why this might be happening and found a few bugs in your framework.

These problems are specific to evaluating detectors and descriptors on HPatches dataset. For all illustrations in this issue report, I will be using HPatches v_artisans/1.ppm as image 1 and v_artisans/5.ppm as image 2. All paths mentioned will be relative to SuperPoint/superpoint directory.

The first problem I found is that your adapted ground truth homography is incorrect, which I have root-caused as "Ignored crop". In the _adapt_homography_to_preprocessing() function in datasets/patches_dataset.py, you perform an upscale -> apply homography -> downscale on the ground truth homography to account for the resized images; however, you do not account for the central crop applied after the aspect-ratio-preserving resize. As a result, when you apply the adapted homography to the resized (and cropped) image 1, you are actually warping the left-most portion of the image instead of the central crop. Here are the in-depth visualizations I have made:

Original image1, image2, and ground truth homography: Originals

Pre-processing of image1: Image1 preprocess

Pre-processing of image2: Image2 preprocess

Adapted homography applied on pre-processed image1: Original adapted homography

Adapted homography applied on pre-processed image1 (taking left crop instead of central crop): Original adapted homography 2

This bug will snowball when performing evaluations using evaluations/detector_evaluation.py or evaluations/descriptor_evaluation.py through ../notebooks/descriptors_evaluation_on_hpatches.ipynb or ../notebooks/detector_repeatability_hpatches.ipynb. The function keep_shared_points() will incorrectly filter out points that are actually shared, while keeping points that are not shared, like the example below: keep_shared_points with bug
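To see why a wrong adapted homography corrupts this filtering, here is a simplified numpy stand-in for the "is this point shared?" test. It is not the repo's keep_shared_points(): the names warp_keypoints and keep_shared are hypothetical, and the sketch assumes keypoints are stored as (x, y) and that a point counts as shared when its warp lands inside image 2.

```python
import numpy as np

def warp_keypoints(keypoints, H):
    """Warp an (N, 2) array of (x, y) keypoints with the 3x3 homography H."""
    ones = np.ones((keypoints.shape[0], 1))
    warped = np.hstack([keypoints, ones]) @ H.T
    return warped[:, :2] / warped[:, 2:]

def keep_shared(keypoints, H, shape):
    """Keep the keypoints whose warp by H lands inside an image of shape (h, w)."""
    warped = warp_keypoints(keypoints, H)
    h, w = shape[:2]
    inside = ((warped[:, 0] >= 0) & (warped[:, 0] <= w - 1) &
              (warped[:, 1] >= 0) & (warped[:, 1] <= h - 1))
    return keypoints[inside]

# A point near the right edge, checked against a 480x640 image 2
point = np.array([[610.0, 10.0]])
# "Correct" homography in this toy example: includes a 40-pixel crop offset along x
H_correct = np.array([[1.0, 0.0, 40.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
# "Buggy" homography: the same mapping without the crop offset
H_buggy = np.eye(3)
```

With H_correct the point lands at x = 650 and is discarded, while H_buggy keeps it at x = 610, so a homography that is missing the crop offset changes which points are considered shared.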

Also, when computing pairwise distances for detector repeatability or descriptor homography estimation, keypoints/corners warped by the adapted homography do not land where they should, which introduces extremely large distance errors even for highly accurate predicted homographies, as in these screenshots: eyeball2, Corners and Mean Dist. The image on the left is the ground truth image 2, and the image on the right is image 1 warped with the homography predicted by cv2.findHomography(). Just by eyeballing these two, I would expect the distance error of the 4 corners to be below 20 pixels. However, when I print out the actual distance error, I get 278 pixels, which cannot be right. Furthermore, if you inspect the corner coordinates, the predicted corners make even more sense than the ground truth corners: there should be at least one negative coordinate, since the top right corner extends above the top of the frame.

I have fixed this bug by adding in a translation element into the _adapt_homography_to_preprocessing() function. Code is as below:

def _adapt_homography_to_preprocessing(zip_data):
    image = zip_data['image']
    H = tf.cast(zip_data['homography'], tf.float32)
    source_size = tf.cast(tf.shape(image)[:2], tf.float32)
    target_size = tf.cast(tf.convert_to_tensor(config['preprocessing']['resize']), tf.float32)
    # Aspect-ratio preserving scale used when resizing image 1
    s = tf.reduce_max(tf.divide(target_size, source_size))

    # True if the height determines the scale, i.e. the width gets cropped
    fit_height = tf.greater(tf.divide(target_size[0], source_size[0]), tf.divide(target_size[1], source_size[1]))

    # Offset of the central crop inside the resized image
    padding_y = tf.to_int32(((source_size[0] * s - target_size[0]) / tf.constant(2.0)))
    padding_x = tf.to_int32(((source_size[1] * s - target_size[1]) / tf.constant(2.0)))

    tx = tf.cond(fit_height, lambda: padding_x, lambda: tf.constant(0))
    ty = tf.cond(fit_height, lambda: tf.constant(0), lambda: padding_y)
    # Translation mapping crop coordinates back to resized-image coordinates
    translation = tf.stack([tf.constant(1), tf.constant(0), tx,
                            tf.constant(0), tf.constant(1), ty,
                            tf.constant(0), tf.constant(0), tf.constant(1)])
    translation = tf.to_float(tf.reshape(translation, [3, 3]))

    down_scale = tf.diag(tf.stack([1/s, 1/s, tf.constant(1.)]))
    up_scale = tf.diag(tf.stack([s, s, tf.constant(1.)]))
    # Applied right to left: undo the crop, undo the resize, apply H, resize again
    H = up_scale @ H @ down_scale @ translation
    return H

Result: Rectified adapted homography
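To make the fix easy to check outside TensorFlow, here is a minimal numpy sketch of the same matrix composition. The names adapt_homography and warp_point are hypothetical, not part of the repo. It assumes shapes are given as (height, width), that image 1 is center-cropped after an aspect-preserving resize, and that image 2 is resized by the same factor without cropping (as described for the second bug). Unlike the TF code, it keeps the crop offset as a float instead of truncating it to an integer.

```python
import numpy as np

def adapt_homography(H, src_shape, target_shape):
    """Adapt H (original image 1 -> original image 2) to preprocessed images.

    The result maps pixels of the resized + center-cropped image 1 to pixels
    of image 2 resized by the same scale factor (not cropped).
    """
    src = np.asarray(src_shape, dtype=float)     # (height, width) of original image 1
    tgt = np.asarray(target_shape, dtype=float)  # (height, width) after resize + crop
    s = np.max(tgt / src)                        # aspect-ratio preserving scale
    # Offset of the central crop inside the resized image; one of them is 0
    pad_y, pad_x = (src * s - tgt) / 2.0
    translation = np.array([[1.0, 0.0, pad_x],
                            [0.0, 1.0, pad_y],
                            [0.0, 0.0, 1.0]])
    down_scale = np.diag([1.0 / s, 1.0 / s, 1.0])
    up_scale = np.diag([s, s, 1.0])
    # Right to left: undo the crop, undo the resize, apply H, resize again
    return up_scale @ H @ down_scale @ translation

def warp_point(H, x, y):
    """Apply a homography to a single (x, y) point."""
    p = H @ np.array([x, y, 1.0])
    return p[:2] / p[2]
```

For a 600x1000 source resized to cover 240x320, s = 0.4 and the crop starts 40 pixels in along x, so pixel (100, 50) of the crop corresponds to (350, 125) in the original image 1; the adapted matrix sends it to 0.4 times H applied to that point.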

After fixing this bug, I noticed that the warped image1 still does not line up perfectly with image2: Superimposed

This is due to a second bug, which I have root-caused as "Scale mismatch". If you refer to the images above showing the pre-processing of image1 and image2, you will notice that they have different scales. This happens because the scale is computed separately for each image, and images from the same sequence (e.g. v_artisans) do not all share the same source resolution. Thus, an image (and therefore its coordinates) resized by a scale factor of ~0.63 cannot be directly compared to another one resized by a different scale factor of ~0.73.

In order to fix this, I had to edit the _get_data() function from datasets/patches_dataset.py to make sure the warped images are scaled by the same factor that the unwarped image was scaled by. For brevity, I have only included relevant sub-function definitions:

def _get_data(self, files, split_name, **config):
    # ... other helper functions excluded for brevity ...

    def _get_shapes(image):
        return tf.shape(image)[:2]

    def _get_scales(shapes):
        return tf.reduce_max(tf.cast(tf.divide(tf.convert_to_tensor(config['preprocessing']['resize'], dtype=tf.float32), tf.to_float(shapes)), tf.float32))

    images = tf.data.Dataset.from_tensor_slices(files['image_paths'])
    images = images.map(lambda path: tf.py_func(_read_image, [path], tf.uint8))
    homographies = tf.data.Dataset.from_tensor_slices(np.array(files['homography']))
    warped_images = tf.data.Dataset.from_tensor_slices(files['warped_image_paths'])
    warped_images = warped_images.map(lambda path: tf.py_func(_read_image,
                                                              [path],
                                                              tf.uint8))       
    if config['preprocessing']['resize']:
        homographies = tf.data.Dataset.zip({'image': images,
                                            'homography': homographies})
        homographies = homographies.map(_adapt_homography_to_preprocessing)
        # Resize each image 2 with the scale factor of its image 1
        shape1s = images.map(_get_shapes)
        scales = shape1s.map(_get_scales)
        warped_images = warped_images.map(lambda x: _preprocess_warped(x, scales))
    else:
        warped_images = warped_images.map(_preprocess)
    images = images.map(_preprocess)
    data = tf.data.Dataset.zip({'image': images, 'warped_image': warped_images,
                                'homography': homographies})

    return data

I understand that image2 will then no longer fit target_size, but that does not matter here. Once image2 has been scaled by the same factor as image1, the two superimpose perfectly: Correct scale resized image2
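The sizing rule behind this fix can be sketched in plain numpy. The helper names resize_scale and warped_image_size are hypothetical, and the source resolutions below are made up for illustration (they are not the actual v_artisans sizes); the idea is simply that image 2 is resized with image 1's scale factor and left uncropped.

```python
import numpy as np

def resize_scale(shape, target_shape):
    """Aspect-ratio preserving scale so that shape covers target_shape (both (h, w))."""
    return float(np.max(np.asarray(target_shape, float) / np.asarray(shape, float)))

def warped_image_size(image1_shape, image2_shape, target_shape):
    """(h, w) to resize image 2 to, reusing the scale factor of image 1.

    Image 2 is not cropped, so the result usually differs from target_shape.
    """
    s = resize_scale(image1_shape, target_shape)
    # Truncate to integers, like the tf.to_int32 cast in _preprocess_warped()
    return tuple(int(d * s) for d in image2_shape[:2])

# Hypothetical source resolutions for two images of the same sequence
shape1, shape2, target = (600, 800), (700, 900), (480, 640)
s1 = resize_scale(shape1, target)  # scale used for image 1
s2 = resize_scale(shape2, target)  # what a per-image scale would use for image 2
```

Here s1 = 0.8, while a per-image scale for image 2 would be 640 / 900 ≈ 0.711, so a point at x = 800 in the original image 2 would end up roughly 71 pixels away from where the adapted homography (built with s1) predicts. Resizing image 2 with s1 instead gives a (560, 720) image.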

After applying these 2 fixes, I was still getting large distance errors for descriptor evaluation. Since I was confident that the homography was now correct, I narrowed the problem down to the corner distance calculation in the compute_homography() function in evaluations/descriptor_evaluation.py.

Original code snippet:

corners = np.array([[0, 0, 1],
                    [0, shape[1] - 1, 1],
                    [shape[0] - 1, 0, 1],
                    [shape[0] - 1, shape[1] - 1, 1]])
real_warped_corners = np.dot(corners, np.transpose(real_H))
real_warped_corners = real_warped_corners[:, :2] / real_warped_corners[:, 2:]
warped_corners = np.dot(corners, np.transpose(H))
warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))

The problem is that the corners are defined with their axes swapped. Homogeneous coordinates should be in (x, y, 1) format, but since shape[0] is the height and shape[1] the width, the corners are defined as (y, x, 1). This causes the matrix multiplication with the homography to warp the coordinates as if the image were portrait rather than landscape (or vice versa). This is the visualization of real_warped_corners using this code: Corners

Just by swapping the axes in the definition of corners like so:

corners = np.array([[0, 0, 1],
                    [shape[1] - 1, 0, 1],
                    [0, shape[0] - 1, 1],
                    [shape[1] - 1, shape[0] - 1, 1]])

And visualizing the new real_warped_corners again, I get this: Rectified corners
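For reference, the corrected corner convention can be packaged as a small standalone numpy helper. This is a sketch rather than the repo's compute_homography(); the names image_corners and mean_corner_error are hypothetical, and shape is assumed to be (height, width) as in the snippet above.

```python
import numpy as np

def image_corners(shape):
    """Image corners as homogeneous (x, y, 1) rows for shape = (height, width)."""
    h, w = shape[:2]
    return np.array([[0, 0, 1],
                     [w - 1, 0, 1],
                     [0, h - 1, 1],
                     [w - 1, h - 1, 1]], dtype=float)

def mean_corner_error(real_H, H, shape):
    """Mean distance between corners warped by the ground truth and by the estimate."""
    corners = image_corners(shape)
    real_warped = corners @ real_H.T
    real_warped = real_warped[:, :2] / real_warped[:, 2:]
    warped = corners @ H.T
    warped = warped[:, :2] / warped[:, 2:]
    return float(np.mean(np.linalg.norm(real_warped - warped, axis=1)))

# An estimate that is off by a (3, 4) pixel translation
H_est = np.array([[1.0, 0.0, 3.0], [0.0, 1.0, 4.0], [0.0, 0.0, 1.0]])
```

For a 480x640 image, mean_corner_error(np.eye(3), H_est, (480, 640)) is 5 pixels, and the top-right corner is (639, 0), rather than the (0, 639) that the (y, x) ordering produces, which lies outside a 480-pixel-high image.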

Only after fixing these 3 bugs did I start to get more sensible evaluation results. Thank you for taking the time to read this lengthy issue report and, once again, thank you for the work you have done on this repository. Feel free to contact me if you have any questions about anything I have brought up. I will also provide a notebook showing how I produced all the visualization screenshots attached to this issue report, so that anyone can reproduce them: https://github.com/Limzui/Superpoint-bug-notebook

Vincentqyw commented 4 years ago

@Limzui Thanks for sharing, I am also curious about the low correctness. By the way, do you have any results after fixing these bugs, e.g. a comparison of keypoint repeatability and descriptor evaluation on HPatches?

rpautrat commented 4 years ago

Hi @Limzui,

Wow, thanks a lot for the detailed issue and the corresponding notebook! I will definitely have a look at this as soon as possible, but given the length of the issue, please give me some time to review it all :)

Limzui commented 4 years ago

Hi @Vincentqyw, sorry for the delayed reply, it took me some time to re-export and evaluate everything. Here are the before/after results. I am using the MagicLeap pre-trained checkpoint for all these experiments for comparability's sake.

Detector before vs after visualization: Before and After

Descriptor old visualization: Descriptor old

Descriptor new visualization: Descriptor new

Detector Viewpoint: Detector V

Detector Illumination: Detector I

Descriptor Viewpoint: Descriptor V

Descriptor Illumination: Descriptor I

All evaluations are done with @rpautrat's settings:

Detector: resize 240x320, NMS 4, keep 300 points, threshold 3 pixels, confidence 0.015

Descriptor: resize 480x640, NMS 8, keep 1000 points, threshold 3 pixels

zpfriedel commented 4 years ago

@Limzui For your second bug fix, I see a function called _preprocess_warped() that is not in the original implementation. Do you mind posting that here?

rpautrat commented 4 years ago

@Limzui, I had a look at your fixes, incorporated them into the framework and tested them, with results similar to yours. Thank you very much again for your contribution, the evaluation makes much more sense now!! You can find the new changes on branch superpoint_v1 (to be merged into master later).

@zpfriedel, you can find my version of _preprocess_warped() in this commit: eccbd5e0bfcec343ad2ae90ac7eca25e551b434f

Limzui commented 4 years ago

@zpfriedel my apologies for missing that out!

My implementation is as follows:

def _preprocess_warped(image, scale):
        # `scale` is the dataset of scale factors computed from image 1 in _get_data()
        image.set_shape([None, None, 3])
        image = tf.image.rgb_to_grayscale(image)
        iterator = scale.make_one_shot_iterator()
        next_scale = iterator.get_next()
        # Resize image 2 by image 1's scale factor (no crop to target_size)
        new_size = tf.to_int32(tf.to_float(tf.shape(image)[:2]) * next_scale)
        image = tf.image.resize_images(image, new_size,
                                       method=tf.image.ResizeMethod.BILINEAR)
        return tf.to_float(image)

Limzui commented 4 years ago

@rpautrat I took a look at the superpoint_v1 branch and noticed that you have updated the descriptor evaluation numbers but not the detector ones. I am not sure whether you have yet to update them or missed them. They need to be updated too, since the newly exported HPatches data affects those results as well (keep_shared_points, homographic warping). Just a heads up!

rpautrat commented 4 years ago

Yes, the new repeatability evaluation is coming too; I used the night to run all the experiments.

zpfriedel commented 4 years ago

@rpautrat A little off topic, but I was curious about what you think about applying the valid_mask to the predictions in the precision and recall metrics. I know they aren't something we should wholly rely on to judge the performance of the training, but maybe applying the mask to ignore any detections 'outside of the image' would make the metrics at least a tiny bit more reliable? Just a thought.
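A minimal numpy sketch of this suggestion, assuming binary keypoint maps and pixel-wise precision/recall (the repo's actual metric code may compute these differently; masked_precision_recall is a hypothetical name):

```python
import numpy as np

def masked_precision_recall(pred, labels, valid_mask):
    """Pixel-wise precision and recall of binary keypoint maps.

    Predictions outside valid_mask (e.g. outside the warped image) are
    discarded before counting; labels are used as-is.
    """
    pred = pred * valid_mask
    true_positives = np.sum(pred * labels)
    precision = true_positives / max(np.sum(pred), 1)
    recall = true_positives / max(np.sum(labels), 1)
    return float(precision), float(recall)

pred = np.array([[1, 0], [1, 1]])
labels = np.array([[1, 0], [0, 1]])
valid_mask = np.array([[1, 1], [0, 1]])  # bottom-left pixel lies outside the image
```

Without the mask, the detection in the invalid bottom-left pixel counts as a false positive and precision drops from 1.0 to 2/3; recall is unaffected.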

rpautrat commented 4 years ago

Yes, that's a good idea @zpfriedel! I will add it to the new branch.

zpfriedel commented 4 years ago

@rpautrat I was looking at your changes to homography_adaptation and noticed that I was still getting detections on the border when exporting labels. I modified the step function slightly and the border detections went away! The code is posted below, let me know what you think. All I really did was add a mask variable to crop the border instead of using the count variable. The other changes work really well, thank you!

def step(i, probs, counts, images):
        # Sample image patch
        H = sample_homography(shape, **config['homographies'])
        H_inv = invert_homography(H)
        warped = H_transform(image, H, interpolation='BILINEAR')
        # Region of the original image covered by the warped patch
        count = H_transform(tf.expand_dims(tf.ones(tf.shape(image)[:3]), -1),
                            H_inv, interpolation='NEAREST')[..., 0]
        # Valid region of the warped image (zero where the warp leaves the frame)
        mask = H_transform(tf.expand_dims(tf.ones(tf.shape(image)[:3]), -1),
                           H, interpolation='NEAREST')
        # Ignore the detections too close to the border to avoid artifacts
        if config['valid_border_margin']:
            kernel = cv.getStructuringElement(
                cv.MORPH_ELLIPSE, (config['valid_border_margin'] * 2,) * 2)
            with tf.device('/cpu:0'):
                mask = tf.nn.erosion2d(
                    mask, tf.to_float(tf.constant(kernel)[..., tf.newaxis]),
                    [1, 1, 1, 1], [1, 1, 1, 1], 'SAME')[..., 0] + 1.
        else:
            # Drop the channel axis so that mask matches the shape of prob
            mask = mask[..., 0]

        # Predict detection probabilities
        prob = net(warped)['prob']
        prob = prob * mask
        prob_proj = H_transform(tf.expand_dims(prob, -1), H_inv,
                                interpolation='BILINEAR')[..., 0]
        prob_proj = prob_proj * count
        # ... (rest of step() unchanged)

Vincentqyw commented 4 years ago

@Limzui, thanks for your answer, there is a big improvement on the viewpoint changes evaluation. Another question: after fixing these bugs, did you re-export the COCO pairs and re-train SuperPoint on the re-exported pairs, or just re-evaluate SuperPoint?

Limzui commented 4 years ago

Hi @Vincentqyw, there should be no need to re-export coco pairs as these bugs are specific only to HPatches evaluation (export/evaluation of HPatches). Thus no re-training is needed as well. Just need to re-export HPatches detectors/descriptors and re-evaluate using the notebooks. :)

Vincentqyw commented 4 years ago

Many thanks for your prompt response, @Limzui :)

zpfriedel commented 4 years ago

@rpautrat I was curious whether you were able to verify my previous comment in this thread?

rpautrat commented 4 years ago

I was a bit busy recently, but I still tried both the current version and yours. I actually couldn't find an example with detections on the border using the current version. Do you have a particular image where this behavior is visible?

In theory the two versions should produce equivalent results, and I verified this on simple test cases (they had exactly the same behavior). But with interpolation and approximations, you never know...

zpfriedel commented 4 years ago

@rpautrat It's strange that you get the exact same results and I don't. The first image uses my modification and the second uses your code. The detections within the image look exactly the same; there are just the border detections in the second one.

Figure_2 Figure_6

rpautrat commented 4 years ago

Thanks for the images @zpfriedel. It seems that those images have a thin black border on the sides, hence any line from the image crossing the border will generate a corner and thus a keypoint. So technically, the detections on the side are not wrong.

But I agree that it is not the kind of behavior we would like to observe. I now understand the difference between the current approach and yours: both remove detections on the border, but in different situations. So I will add your changes to the most recent branch, but also keep the current erosion on count, which filters out some border detections as well.
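To illustrate the border erosion discussed here, below is a rough numpy stand-in, not the repo's code. It uses a square window instead of the elliptical cv.getStructuringElement kernel and plain numpy instead of tf.nn.erosion2d, and, as a choice of this sketch, it treats pixels outside the frame as valid so that only regions not covered by the warped image cause erosion. The names erode_valid_mask and mask_border_detections are hypothetical.

```python
import numpy as np

def erode_valid_mask(mask, margin):
    """Zero out pixels within `margin` pixels (square window) of an invalid pixel.

    Pixels outside the frame are treated as valid (padded with 1), so only
    invalid regions inside the mask trigger erosion.
    """
    h, w = mask.shape
    padded = np.pad(mask, margin, mode="constant", constant_values=1)
    out = mask.copy()
    for dy in range(2 * margin + 1):
        for dx in range(2 * margin + 1):
            out = np.minimum(out, padded[dy:dy + h, dx:dx + w])
    return out

def mask_border_detections(prob, valid_mask, margin):
    """Suppress detection probabilities close to invalid regions of the warped image."""
    return prob * erode_valid_mask(valid_mask, margin)

# Toy 6x6 warped-image mask whose two left columns are outside the warp
mask = np.ones((6, 6))
mask[:, :2] = 0
```

With margin = 1, the first valid column next to the invalid region is also suppressed, so only the three right-most columns keep their detections.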

Thanks for your contribution!

zpfriedel commented 4 years ago

@rpautrat This occurred on every image I looked at (20+ images), even when there wasn't a thin black border. But you're right, using both methods should stop it altogether!

rpautrat commented 4 years ago

I am closing this issue now as every improvement has been added to the new branch (I will push the mask in homography adaptation with a new evaluation later).

Feel free to open a new issue in case you have other suggestions for improvement!