verlab / accelerated_features

Implementation of XFeat (CVPR 2024). Do you need robust and fast local feature extraction? You are in the right place!
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24
Apache License 2.0

Add support for exporting to ONNX. #5

Open acai66 opened 7 months ago

acai66 commented 7 months ago
  1. Support for exporting the xfeat and xfeat+matching models.
  2. Support for exporting with dynamic shapes.
  3. Support for specifying the opset version when exporting (see the export sketch below the table).
  4. An onnxruntime inference demo.
| Model | Inputs | Outputs | Note |
| --- | --- | --- | --- |
| xfeat.onnx | images | feats, keypoints, heatmaps | Extract image keypoints and features. |
| xfeat_dualscale.onnx | images | mkpts, feats, sc | Extract dual-scale image keypoints and features. |
| matching.onnx | mkpts0, feats0, sc0, mkpts1, feats1 | matches, batch_indexes | Match dual-scale keypoints and features. |
| xfeat_matching.onnx | images0, images1 | matches, batch_indexes | End-to-end: extract features from two sets of images and match them. |

onnx_models.zip
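
For reference, a minimal export sketch showing how dynamic shapes and the opset version can be passed to torch.onnx.export. This is not the PR's actual export script; the `wrapper` module and the `export_dualscale` helper are hypothetical stand-ins.

import torch

def export_dualscale(wrapper: torch.nn.Module, path='xfeat_dualscale.onnx', opset=17):
    # `wrapper` stands in for an export module whose forward() takes a batch of
    # NCHW float images and returns (mkpts, feats, sc), as in the table above.
    dummy = torch.randn(1, 3, 640, 640)
    torch.onnx.export(
        wrapper.eval(), dummy, path,
        input_names=['images'],
        output_names=['mkpts', 'feats', 'sc'],
        # Dynamic shapes: batch, height and width stay symbolic in the exported graph.
        dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}},
        # The opset version is chosen by the caller.
        opset_version=opset,
    )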

examples


xfeat_onnxruntime.py

import numpy as np
import onnxruntime as ort

def create_ort_session(model_path, trt_engine_cache_path='trt_engine_cache', trt_engine_cache_prefix='model'):
    tmp_ort_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    # print the input,output names and shapes
    for i in range(len(tmp_ort_session.get_inputs())):
        print(f"Input name: {tmp_ort_session.get_inputs()[i].name}, shape: {tmp_ort_session.get_inputs()[i].shape}")
    for i in range(len(tmp_ort_session.get_outputs())):
        print(f"Output name: {tmp_ort_session.get_outputs()[i].name}, shape: {tmp_ort_session.get_outputs()[i].shape}")

    providers = [
        # The TensorrtExecutionProvider is the fastest.
        ('TensorrtExecutionProvider', { 
            'device_id': 0,
            'trt_max_workspace_size': 4 * 1024 * 1024 * 1024,
            'trt_fp16_enable': True,
            'trt_engine_cache_enable': True,
            'trt_engine_cache_path': trt_engine_cache_path,
            'trt_engine_cache_prefix': trt_engine_cache_prefix,
            'trt_dump_subgraphs': False,
            'trt_timing_cache_enable': True,
            'trt_timing_cache_path': trt_engine_cache_path,
            #'trt_builder_optimization_level': 3,
        }),

        # The CUDAExecutionProvider is slower than PyTorch, 
        # possibly due to performance issues with large matrix multiplication "cossim = torch.bmm(feats1, feats2.permute(0,2,1))"
        # Reducing the top_k value when exporting to ONNX can decrease the matrix size.
        ('CUDAExecutionProvider', { 
            'device_id': 0,
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,
        }),
        ('CPUExecutionProvider',{ 
        })
    ]
    ort_session = ort.InferenceSession(model_path, providers=providers)

    return ort_session

class XFeat:

    def __init__(self, xfeat_model_path='./xfeat_dualscale.onnx', matcher_model_path='./matching.onnx'):
        self.xfeat_ort_session = create_ort_session(xfeat_model_path, trt_engine_cache_prefix='xfeat_dualscale')
        self.matcher_ort_session = create_ort_session(matcher_model_path, trt_engine_cache_prefix='matching')

        # warm up
        for i in range(5):
            image = np.zeros((640, 640, 3), dtype=np.float32)
            self.detectAndCompute(image)
            mkpts0 = np.zeros((1, 4800, 2), dtype=np.float32)
            feats0 = np.zeros((1, 4800, 64), dtype=np.float32)
            sc0 = np.zeros((1, 4800), dtype=np.float32)
            mkpts1 = np.zeros((1, 4800, 2), dtype=np.float32)
            feats1 = np.zeros((1, 4800, 64), dtype=np.float32)
            self.match(mkpts0, feats0, sc0, mkpts1, feats1)

    def detectAndCompute(self, image_data, mask=None):
        input_array = np.expand_dims(image_data.transpose((2, 0, 1)), axis=0).astype(np.float32)
        inputs = {
            self.xfeat_ort_session.get_inputs()[0].name: input_array,
        }
        mkpts0, feats0, sc = self.xfeat_ort_session.run(None, inputs)

        return {
            "keypoints": mkpts0,
            "descriptors": feats0,
            "sc": sc,
        }

    def match(self, mkpts0, feats0, sc0, mkpts1, feats1):
        inputs = {
            self.matcher_ort_session.get_inputs()[0].name: mkpts0,
            self.matcher_ort_session.get_inputs()[1].name: feats0,
            self.matcher_ort_session.get_inputs()[2].name: sc0,
            self.matcher_ort_session.get_inputs()[3].name: mkpts1,
            self.matcher_ort_session.get_inputs()[4].name: feats1,
        }
        matches, batch_indexes = self.matcher_ort_session.run(None, inputs)

        return matches, batch_indexes
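
A minimal usage sketch for the class above (ref.png and tgt.png are placeholder image paths; the model paths match the defaults in __init__):

import cv2
import numpy as np

xfeat = XFeat('./xfeat_dualscale.onnx', './matching.onnx')

# detectAndCompute expects an HWC float image; it adds the batch dimension itself.
img0 = cv2.imread('ref.png', cv2.IMREAD_COLOR).astype(np.float32)
img1 = cv2.imread('tgt.png', cv2.IMREAD_COLOR).astype(np.float32)
out0 = xfeat.detectAndCompute(img0)
out1 = xfeat.detectAndCompute(img1)

matches, batch_indexes = xfeat.match(out0['keypoints'], out0['descriptors'], out0['sc'],
                                     out1['keypoints'], out1['descriptors'])
# For batch element 0, columns 0:2 are points in img0 and columns 2:4 are the
# matched points in img1 (same layout as in the realtime_demo patch below).
pts0 = matches[batch_indexes == 0][..., :2]
pts1 = matches[batch_indexes == 0][..., 2:]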

realtime_demo.py patches

diffs

diff --git a/realtime_demo.py b/realtime_demo.py
index 6c867fd..8e21ae2 100644
--- a/realtime_demo.py
+++ b/realtime_demo.py
@@ -7,20 +7,18 @@

 import cv2
 import numpy as np
-import torch

 from time import time, sleep
 import argparse, sys, tqdm
 import threading

-from modules.xfeat import XFeat

 def argparser():
     parser = argparse.ArgumentParser(description="Configurations for the real-time matching demo.")
     parser.add_argument('--width', type=int, default=640, help='Width of the video capture stream.')
     parser.add_argument('--height', type=int, default=480, help='Height of the video capture stream.')
     parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.')
-    parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.')
+    parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat', 'XFeat_Ort'], default='XFeat', help='Local feature detection method to use.')
     parser.add_argument('--cam', type=int, default=0, help='Webcam device number.')
     return parser.parse_args()

@@ -52,6 +50,7 @@ class CVWrapper():
     def __init__(self, mtd):
         self.mtd = mtd
     def detectAndCompute(self, x, mask=None):
+        import torch
         return self.mtd.detectAndCompute(torch.tensor(x).permute(2,0,1).float()[None])[0]

 class Method:
@@ -65,7 +64,12 @@ def init_method(method, max_kpts):
     elif method == "SIFT":
         return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True))
     elif method == "XFeat":
+        from modules.xfeat import XFeat
         return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat())
+    elif method == "XFeat_Ort":
+        from xfeat_onnxruntime import XFeat
+        xfeat = XFeat()
+        return Method(descriptor=xfeat, matcher=xfeat)
     else:
         raise RuntimeError("Invalid Method.")

@@ -200,6 +204,12 @@ class MatchingDemo:
         if self.args.method in ['SIFT', 'ORB']:
             kp1, des1 = self.ref_precomp
             kp2, des2 = self.method.descriptor.detectAndCompute(current_frame, None)
+        elif self.args.method in ['XFeat_Ort']:
+            current = self.method.descriptor.detectAndCompute(current_frame)
+            kpts1, descs1, sc1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors'], self.ref_precomp['sc']
+            kpts2, descs2 = current['keypoints'], current['descriptors']
+            matches, batch_indexes = self.method.matcher.match(kpts1, descs1, sc1, kpts2, descs2)
+            points1, points2 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]
         else:
             current = self.method.descriptor.detectAndCompute(current_frame)
             kpts1, descs1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors']
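
With these patches applied, the ONNX Runtime backend can be selected via the new argparse choice, e.g. `python realtime_demo.py --method XFeat_Ort --cam 0`.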
IamShubhamGupto commented 6 months ago

I'm slightly busy with graduation this week; I will be back to work on this by the weekend.

acai66 commented 6 months ago

I'm slightly busy with graduation this week; I will be back to work on this by the weekend.

No problem, thanks for letting me know. Congratulations on your graduation!

acai66 commented 5 months ago

Here is sample Python code for running xfeat_matching.onnx with the TensorRT API. This example does not include exception handling, input/output validation, etc., and is for reference only.

xfeat_tensorrt_python.zip

update 2024-06-24:

import os
import sys
import time
from typing import Optional, List
from functools import reduce

import numpy as np
import cv2
import tensorrt as trt
import cupy as cp    # pip install cupy-cuda12x
from cuda import cuda, cudart    # pip install cuda-python
from packaging.version import Version

if Version(trt.__version__) >= Version('9.0.0'):
    # This is a simple ASCII-art progress monitor comparable to the C++ version in sample_progress_monitor.
    class SimpleProgressMonitor(trt.IProgressMonitor):
        def __init__(self):
            trt.IProgressMonitor.__init__(self)
            self._active_phases = {}
            self._step_result = True

        def phase_start(self, phase_name, parent_phase, num_steps):
            try:
                if parent_phase is not None:
                    nbIndents = 1 + self._active_phases[parent_phase]['nbIndents']
                else:
                    nbIndents = 0
                self._active_phases[phase_name] = { 'title': phase_name, 'steps': 0, 'num_steps': num_steps, 'nbIndents': nbIndents }
                self._redraw()
            except KeyboardInterrupt:
                # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete.
                self._step_result = False

        def phase_finish(self, phase_name):
            try:
                del self._active_phases[phase_name]
                self._redraw(blank_lines=1) # Clear the removed phase.
            except KeyboardInterrupt:
                self._step_result = False

        def step_complete(self, phase_name, step):
            try:
                self._active_phases[phase_name]['steps'] = step
                self._redraw()
                return self._step_result
            except KeyboardInterrupt:
                # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
                return False

        def _redraw(self, *, blank_lines=0):
            # The Python curses module is not widely available on Windows platforms.
            # Instead, this function uses raw terminal escape sequences. See the sample documentation for references.
            def clear_line():
                print('\x1B[2K', end='')
            def move_to_start_of_line():
                print('\x1B[0G', end='')
            def move_cursor_up(lines):
                print('\x1B[{}A'.format(lines), end='')

            def progress_bar(steps, num_steps):
                INNER_WIDTH = 10
                completed_bar_chars = int(INNER_WIDTH * steps / float(num_steps))
                return '[{}{}]'.format(
                    '=' * completed_bar_chars,
                    '-' * (INNER_WIDTH - completed_bar_chars))

            # Set max_cols to a default of 200 if not run in interactive mode.
            max_cols = os.get_terminal_size().columns if sys.stdout.isatty() else 200

            move_to_start_of_line()
            for phase in self._active_phases.values():
                phase_prefix = '{indent}{bar} {title}'.format(
                    indent = ' ' * phase['nbIndents'],
                    bar = progress_bar(phase['steps'], phase['num_steps']),
                    title = phase['title'])
                phase_suffix = '{steps}/{num_steps}'.format(**phase)
                allowable_prefix_chars = max_cols - len(phase_suffix) - 2
                if allowable_prefix_chars < len(phase_prefix):
                    phase_prefix = phase_prefix[0:allowable_prefix_chars-3] + '...'
                clear_line()
                print(phase_prefix, phase_suffix)
            for line in range(blank_lines):
                clear_line()
                print()
            move_cursor_up(len(self._active_phases) + blank_lines)
            sys.stdout.flush()

FP16_ENABLE = True
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))

def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res

class OutputAllocator(trt.IOutputAllocator):
    def __init__(self):
        trt.IOutputAllocator.__init__(self)
        self.buffers = {}
        self.shapes = {}

    def reallocate_output(self, tensor_name, memory, size, alignment):
        output_dtype = cp.dtype(cp.byte)
        output = cp.empty(size, output_dtype)
        ptr = output.data.ptr
        self.buffers[tensor_name] = output
        return ptr

    def notify_shape(self, tensor_name, shape):
        self.shapes[tensor_name] = tuple(shape)

def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""

    TRT_LOGGER = trt.Logger()

    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            0 if Version(trt.__version__) >= Version('9.0.0') else EXPLICIT_BATCH
        ) as network, builder.create_builder_config() as config, trt.OnnxParser(
            network, TRT_LOGGER
        ) as parser, trt.Runtime(
            TRT_LOGGER
        ) as runtime:
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print(
                    "ONNX file {} not found.".format(onnx_file_path)
                )
                return None
            print("Loading ONNX file from path {}...".format(onnx_file_path))
            with open(onnx_file_path, "rb") as model:
                print("Beginning ONNX file parsing")
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse the ONNX file.")
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None

            print("Completed parsing of ONNX file")

            if Version(trt.__version__) >= Version('9.0.0'):
                config.progress_monitor = SimpleProgressMonitor()
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30) # 4GB
            # config.max_workspace_size = 1 << 30  # 1GB tensorrt < 8.4
            # builder.max_batch_size = 1

            for i in range(network.num_inputs):
                input_node = network.get_input(i)
                print('input ', i, input_node.name, input_node.shape)
            assert network.num_inputs == 2, "xfeat_matching.onnx is expected to have exactly two input nodes."

            profile = builder.create_optimization_profile()
            profile.set_shape(network.get_input(0).name, (1, 3, 64, 64), (8, 3, 640, 640), (32, 3, 1280, 1280))
            profile.set_shape(network.get_input(1).name, (1, 3, 64, 64), (8, 3, 640, 640), (32, 3, 1280, 1280))
            config.add_optimization_profile(profile)

            if builder.platform_has_fast_fp16 and FP16_ENABLE:
                print("FP16 mode enabled")
                config.set_flag(trt.BuilderFlag.FP16)

            print("Building an engine from file {}; this may take a while...".format(onnx_file_path))
            plan = builder.build_serialized_network(network, config)
            engine = runtime.deserialize_cuda_engine(plan)
            print("Completed creating Engine")
            with open(engine_file_path, "wb") as f:
                f.write(plan)
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        try:
            with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                engine = runtime.deserialize_cuda_engine(f.read())
                if engine is None:
                    print("Deserialization of the engine from {} failed. Falling back to building the engine".format(engine_file_path))
                    return build_engine()
                return engine
        except Exception as e:
            print(e)
            print("Deserialization of the engine from {} failed. Falling back to building the engine".format(engine_file_path))
            return build_engine()
    else:
        return build_engine()

# Runs inference for a network with multiple dynamically shaped inputs/outputs.
# inputs_dict maps input tensor names to CuPy device arrays; outputs are gathered from the OutputAllocator.
def do_inference(context, engine, inputs_dict, output_allocator, stream):
    num_io = engine.num_io_tensors
    outputs = []
    for i in range(num_io):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_tensor_address(name, inputs_dict[name].data.ptr)
            context.set_input_shape(name, inputs_dict[name].shape)
    context.execute_async_v3(stream_handle=stream)
    cuda_call(cudart.cudaStreamSynchronize(stream))

    for i in range(num_io):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            output_shape = output_allocator.shapes[name]
            output_dtype = cp.dtype(trt.nptype(engine.get_tensor_dtype(name)))
            output_size = output_dtype.itemsize * reduce((lambda x, y: x * y), output_shape, 1)
            outputs.append(output_allocator.buffers[name][:output_size].view(output_dtype).reshape(output_shape))

    return outputs

class XFEAT:

    def __init__(self, modelPath, gpu_id=0) -> None:

        self.gpu_id = gpu_id

        # Init use gpu_id
        cudart.cudaSetDevice(self.gpu_id)

        engine_file_path = os.path.splitext(modelPath)[0] + ".trt"
        self.engine = get_engine(modelPath, engine_file_path)
        self.context = self.engine.create_execution_context()
        self.stream = cuda_call(cudart.cudaStreamCreate())

        self.output_allocator = OutputAllocator()
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                self.context.set_output_allocator(name, self.output_allocator)

        tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        for tensor_name in tensor_names:
            print(tensor_name, self.engine.get_tensor_shape(tensor_name))

    def inference_preprocessed(self, input_array_0, input_array_1):
        cudart.cudaSetDevice(self.gpu_id)

        inputs_dict = {
            "images0": input_array_0,
            "images1": input_array_1,
        }

        trt_outputs = do_inference(self.context, self.engine, inputs_dict=inputs_dict, output_allocator=self.output_allocator, stream=self.stream)
        matches = cp.asnumpy(trt_outputs[0])
        batch_indexes = cp.asnumpy(trt_outputs[1])

        return matches, batch_indexes

    def __call__(self, input_array_0, input_array_1):
        input_array_device_0 = cp.asarray(input_array_0, dtype=cp.float32)
        input_array_device_1 = cp.asarray(input_array_1, dtype=cp.float32)

        return self.inference_preprocessed(input_array_device_0, input_array_device_1)

    def __del__(self) -> None:
        if hasattr(self, 'stream'):
            if cudart is not None:
                cuda_call(cudart.cudaStreamDestroy(self.stream))

def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
    # Calculate the Homography matrix
    H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
    mask = mask.flatten()

    # Get corners of the first image (image1)
    h, w = img1.shape[:2]
    corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)

    # Warp corners to the second image (image2) space
    warped_corners = cv2.perspectiveTransform(corners_img1, H)

    # Draw the warped corners in image2
    img2_with_corners = img2.copy()
    for i in range(len(warped_corners)):
        start_point = tuple(warped_corners[i-1][0].astype(int))
        end_point = tuple(warped_corners[i][0].astype(int))
        cv2.line(img2_with_corners, start_point, end_point, (0, 255, 0), 4)  # Using solid green for corners

    # Prepare keypoints and matches for drawMatches function
    keypoints1 = [cv2.KeyPoint(p[0], p[1], 5) for p in ref_points]
    keypoints2 = [cv2.KeyPoint(p[0], p[1], 5) for p in dst_points]
    matches = [cv2.DMatch(i,i,0) for i in range(len(mask)) if mask[i]]

    # Draw inlier matches
    img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
                                  matchColor=(0, 255, 0), flags=2)

    return img_matches

if __name__ == "__main__":
    onnx_file_path = './models/xfeat_matching.onnx'
    input_image_0_path = './images/ref.png'
    input_image_1_path = './images/tgt.png'

    xfeat = XFEAT(onnx_file_path, 0)
    image_0 = cv2.imread(input_image_0_path, cv2.IMREAD_COLOR)
    image_1 = cv2.imread(input_image_1_path, cv2.IMREAD_COLOR)

    batch_size = 8  # Pseudo-batch the input images
    input_array_0 = np.expand_dims(image_0.transpose(2, 0, 1), axis=0).repeat(batch_size, axis=0)
    input_array_1 = np.expand_dims(image_1.transpose(2, 0, 1), axis=0).repeat(batch_size, axis=0)

    matches, batch_indexes = xfeat(input_array_0, input_array_1)
    mkpts_0, mkpts_1 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]

    img_matches = warp_corners_and_draw_matches(mkpts_0, mkpts_1, image_0, image_1)
    cv2.imshow("Matches", img_matches)
    cv2.waitKey(0)

    loop = 100
    start = time.time()
    for i in range(loop):
        matches, batch_indexes = xfeat(input_array_0, input_array_1)
    end = time.time()
    print("Time: ", (end - start) / loop / batch_zise)
    print("FPS: ", batch_zise * loop / (end - start))
guipotje commented 5 months ago

Hi @acai66 @IamShubhamGupto ,

Thank you very much for providing the ONNX examples in both C++ and Python. They look amazing and incredibly useful!

Are you still waiting for @IamShubhamGupto to review the merge? I am not sure if you both merged the work you did last month, so I am just checking in to see if I should start reviewing the PR.

Once again, thank you @acai66 and @IamShubhamGupto for the ONNX examples. I appreciate your effort to make XFeat deployment much better!

acai66 commented 5 months ago

Hi @acai66 @IamShubhamGupto ,

Thank you very much for providing the ONNX examples in both C++ and Python. They look amazing and incredibly useful!

Are you still waiting for @IamShubhamGupto to review the merge? I am not sure if you both merged the work you did last month, so I am just checking in to see if I should start reviewing the PR.

Once again, thank you @acai66 and @IamShubhamGupto for the ONNX examples. I appreciate your effort to make XFeat deployment much better!

Thank you for checking in. I am still waiting for @IamShubhamGupto's review. Unfortunately, I haven't received any response from @IamShubhamGupto regarding last month's work, which might be due to the busy graduation season.

IamShubhamGupto commented 5 months ago

Hi @acai66 @IamShubhamGupto , Thank you very much for providing the ONNX examples in both C++ and Python. They look amazing and incredibly useful! Are you still waiting for @IamShubhamGupto to review the merge? I am not sure if you both merged the work you did last month, so I am just checking in to see if I should start reviewing the PR. Once again, thank you @acai66 and @IamShubhamGupto for the ONNX examples. I appreciate your effort to make XFeat deployment much better!

Thank you for checking in. I am still waiting for @IamShubhamGupto's review. Unfortunately, I haven't received any response from @IamShubhamGupto regarding last month's work, which might be due to the busy graduation season.

Hey @acai66 @guipotje, sorry to keep both of you waiting. I did just graduate and was busy with a few conferences and competitions. As for the development on this branch, I believe you should go ahead and merge it. I will be back to contributing to this project some time later.

noahzn commented 4 months ago

@acai66 Hi, thank you for your contribution. I found that only detectAndComputeDense is supported. Could you also add support for detectAndCompute?

acai66 commented 4 months ago

@acai66 Hi, thank you for your contribution. I found that only detectAndComputeDense is supported. Could you also add support for detectAndCompute?

commit: https://github.com/verlab/accelerated_features/pull/5/commits/f4a55c14fddbf61ea1f6de83c8a30baacecdc88b

noahzn commented 4 months ago

@acai66 Thank you very much!

noahzn commented 4 months ago

@acai66 Have you run into a problem in detectAndCompute where the `#Select top-k features` step exceeds the maximum number of keypoints when using the TensorRT Execution Provider?

stschake commented 4 months ago

Hey @acai66,

thanks for all your work on the ONNX export of XFeat; it's been very handy. @guipotje recently added the LighterGlue addon matcher, so I spent some time making it available for ONNX export on top of your changes; see the branch here:

https://github.com/stschake/accelerated_features/tree/feature/lighterglue-onnx

The upstream code uses kornia, which isn't suitable for ONNX export, so I started from the LightGlue-ONNX implementation and modified it slightly to add things like keypoint normalization directly into the model, in the xfeat tradition.
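
For anyone curious, here is a rough sketch (my assumption, not code from the linked branch) of what in-model keypoint normalization typically looks like in LightGlue-style matchers: keypoints are shifted to the image center and scaled by the longer side so coordinates land roughly in [-1, 1] regardless of resolution.

import torch

def normalize_keypoints(kpts: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # kpts: (B, N, 2) pixel coordinates in (x, y) order.
    size = kpts.new_tensor([w, h])
    shift = size / 2
    scale = size.max() / 2
    return (kpts - shift) / scale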

yrik commented 3 months ago

Thanks, guys! Looking forward to the ONNX version.