facebookresearch / silk

SiLK (Simple Learned Keypoint) is a self-supervised deep learning keypoint model.
GNU General Public License v3.0

How to convert sparse SiLK to TorchScript? #16

Closed. Pepper-FlavoredChewingGum closed this issue 1 year ago.

Pepper-FlavoredChewingGum commented 1 year ago

Thank you for your awesome work.

I noticed you provided an example of exporting the dense SiLK model to TorchScript. However, when I changed the setup to output sparse keypoints and sparse descriptors (as in silk-inference.py), I ran into two issues.

First, because there are many "if" statements in the code, torch.jit.trace emits numerous warnings, such as "TracerWarning: Converting a tensor to a Python boolean may result in an inaccurate trace." Is it safe to ignore these warnings? How do they affect the exported TorchScript model?

Second, the model's output has to be transformed with the `from_feature_coords_to_image_coords(...)` function to get the actual keypoints in image coordinates. However, once the model is exported as a script file, there appears to be no way to do this anymore. What do I need to do so that the exported script file outputs keypoints in image coordinates directly? I want to use this model in other environments.

Thank you so much for your recommendations, and I eagerly await your response.

gleize commented 1 year ago

Hi @Pepper-FlavoredChewingGum,

> First, because there are many "if" statements in the code, torch.jit.trace emits numerous warnings, such as "TracerWarning: Converting a tensor to a Python boolean may result in an inaccurate trace." Is it safe to ignore these warnings? How do they affect the exported TorchScript model?

Those warnings can essentially be disregarded, but there is a catch. It involves three parameters:

  1. The input channel size.
  2. The keypoint score threshold value.
  3. The top-k value.

Parameter 1 isn't a problem, since it won't change. Parameters 2 and 3 (set here), however, will be frozen as constants in the TorchScript model, i.e. it's not possible to run the TorchScript model with different threshold and top-k values.

Any change to those parameters requires rebuilding the TorchScript model. In most cases that's acceptable.
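A minimal toy sketch (not SiLK code) of what the tracer does with such parameters. `THRESHOLD`, `TOP_K` and `detect` are hypothetical stand-ins: the data-dependent `if` forces a tensor-to-bool conversion, which is what triggers a TracerWarning, and only the branch taken on the example input is recorded. The threshold and top-k, being plain Python numbers, become constants of the trace, so input sizes stay dynamic but the output always has `TOP_K` entries.

```python
import warnings

import torch

# Hypothetical stand-ins for the keypoint score threshold and the top-k value.
THRESHOLD = 0.5
TOP_K = 4


def detect(scores: torch.Tensor) -> torch.Tensor:
    flat = scores.flatten()
    # Data-dependent Python branch: the tracer must turn a tensor into a
    # Python bool here, which emits a TracerWarning. Only the branch taken
    # on the example input ends up in the trace.
    if (flat > THRESHOLD).any():
        flat = torch.where(flat > THRESHOLD, flat, torch.zeros_like(flat))
    # THRESHOLD and TOP_K are plain Python numbers, so the trace stores them
    # as constants; the traced function cannot be given other values later.
    return torch.topk(flat, TOP_K).values


example = torch.linspace(0.0, 1.0, 64).reshape(8, 8)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(detect, example)

n_tracer_warnings = sum(issubclass(w.category, torch.jit.TracerWarning) for w in caught)
print("TracerWarnings recorded:", n_tracer_warnings)

# The input size is not frozen, but the number of returned scores is.
print(traced(torch.rand(16, 16)).shape)  # torch.Size([4])
```

Using a different threshold or top-k in this sketch means editing the constants and calling `torch.jit.trace` again, which is the same rebuild step described above.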

> Second, the model's output has to be transformed with the `from_feature_coords_to_image_coords(...)` function to get the actual keypoints in image coordinates. However, once the model is exported as a script file, there appears to be no way to do this anymore. What do I need to do so that the exported script file outputs keypoints in image coordinates directly? I want to use this model in other environments.

That's actually a bug. Thanks for pointing it out; a fix will be out soon.
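For background on why such a conversion exists at all: the output grid of a convolutional backbone does not, in general, line up one-to-one with image pixels (for example, unpadded convolutions shrink and shift it). The sketch below is a generic receptive-field calculation, not SiLK's implementation of `from_feature_coords_to_image_coords`, and the layer lists are hypothetical. Each layer with kernel `k`, stride `s` and padding `p` maps output index `i` to the centre of its receptive field in its input, `i * s + (k - 1) / 2 - p`, so the composed mapping is affine and can be applied with plain tensor arithmetic inside a traced function.

```python
from typing import List, Tuple

# A layer is described as (kernel_size, stride, padding). Hypothetical values.
Layer = Tuple[int, int, int]


def feature_to_image_coord(x: float, layers: List[Layer]) -> float:
    """Map a coordinate on the final feature map back to the input image.

    Walks the layers from last to first; each one maps output index i to the
    centre of its receptive field in its input: i * stride + (k - 1) / 2 - padding.
    """
    for kernel, stride, padding in reversed(layers):
        x = x * stride + (kernel - 1) / 2 - padding
    return x


def mapping_as_affine(layers: List[Layer]) -> Tuple[float, float]:
    """Return (scale, offset) such that x_image = scale * x_feature + offset."""
    offset = feature_to_image_coord(0.0, layers)
    scale = feature_to_image_coord(1.0, layers) - offset
    return scale, offset


# Four unpadded 3x3 convolutions: feature index 0 lands on image pixel 4.
print(mapping_as_affine([(3, 1, 0)] * 4))  # (1.0, 4.0)
# Two unpadded 3x3 convolutions followed by a 2x2 stride-2 pooling layer.
print(mapping_as_affine([(3, 1, 0), (3, 1, 0), (2, 2, 0)]))  # (2.0, 2.5)
```

Because the result is just a scale and an offset, a wrapper around the model can apply it to the output positions before tracing, and the conversion is then recorded in the TorchScript graph along with everything else.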

In the meantime, you can try this:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from common import get_model, load_images
from silk.backbones.silk.silk import from_feature_coords_to_image_coords

IMAGE_0_PATH = "/datasets01/hpatches/01042022/v_adam/1.ppm"
IMAGE_1_PATH = "/datasets01/hpatches/01042022/v_adam/2.ppm"
OUTPUT_MODEL = "script_model.pt"


def test_on_image_pair(model, script_model, images):
    # run both the eager model and the TorchScript model
    positions_0, descriptors_0, _ = model(images)
    positions_1, descriptors_1, _ = script_model(images)

    # check result consistency
    assert len(positions_0) == len(positions_1) == 2
    assert torch.allclose(positions_0[0], positions_1[0])
    assert torch.allclose(positions_0[1], positions_1[1])

    assert len(descriptors_0) == len(descriptors_1) == 2
    assert torch.allclose(descriptors_0[0], descriptors_1[0])
    assert torch.allclose(descriptors_0[1], descriptors_1[1])


def model_with_corrected_positions(model):
    def fn(images):
        results = model(images)
        assert type(results) is tuple
        # IMPORTANT: only works when positions are the first output
        positions = results[0]
        # convert from feature-map coordinates to image coordinates
        positions = from_feature_coords_to_image_coords(model, positions)
        return (positions,) + results[1:]

    return fn


def main():
    # load images
    images = load_images(IMAGE_0_PATH, IMAGE_1_PATH)

    # load model
    model = get_model(default_outputs=("sparse_positions", "sparse_descriptors", "probability"))
    # wrap it so that the coordinate conversion is recorded in the trace
    model = model_with_corrected_positions(model)

    # trace model to TorchScript
    script_model = torch.jit.trace(model, images)

    # save model to disk
    torch.jit.save(script_model, OUTPUT_MODEL)
    # load model from disk (to test it below)
    script_model = torch.jit.load(OUTPUT_MODEL)

    # test on same-size images
    test_on_image_pair(model, script_model, images)

    # test on downsampled images (to check shapes are not frozen during tracing)
    downsampled_images = torch.nn.functional.interpolate(images, scale_factor=0.5)
    test_on_image_pair(model, script_model, downsampled_images)

    print(f'torch script model "{OUTPUT_MODEL}" created and tested')
    print("done")


if __name__ == "__main__":
    main()
```

Pepper-FlavoredChewingGum commented 1 year ago

Thank you very much for your prompt response, which was quite helpful.

I've run into a new issue. I want to use SiLK for SLAM, which requires real-time performance. However, setting top_k to a large value makes the computational cost extremely high. As a result, I can only set top_k to 1000, which degrades matching accuracy, and this makes the SLAM pipeline prone to failure.

Do you have any suggestions for this? I look forward to your response.
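One plausible reason top_k dominates the cost: if matching is done by brute-force comparison of every descriptor pair (an assumption; the thread does not say which matcher is used), the work grows quadratically with the number of keypoints. The sketch below only does that arithmetic. `brute_force_matching_cost` is a hypothetical helper, and `descriptor_dim=128` is an assumed value, so substitute your model's actual descriptor size.

```python
from typing import Tuple


def brute_force_matching_cost(num_keypoints: int, descriptor_dim: int = 128) -> Tuple[int, int]:
    """Rough cost of brute-force matching between two images.

    Assumes both images keep `num_keypoints` keypoints (the top_k value) and
    every descriptor in image 0 is compared with every descriptor in image 1.
    Returns (multiply-adds, bytes of a float32 similarity matrix).
    """
    pairs = num_keypoints * num_keypoints
    macs = pairs * descriptor_dim  # one dot product of length descriptor_dim per pair
    matrix_bytes = pairs * 4       # 4 bytes per float32 similarity score
    return macs, matrix_bytes


for top_k in (1000, 2000, 5000, 10000):
    macs, nbytes = brute_force_matching_cost(top_k)
    print(f"top_k={top_k:>6}: {macs / 1e9:6.2f} GMACs, {nbytes / 1e6:7.1f} MB")
```

Under these assumptions, going from 1000 to 10000 keypoints multiplies both the arithmetic and the similarity-matrix memory by 100, not by 10.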

gleize commented 1 year ago

Hi @Pepper-FlavoredChewingGum,