cvg / LightGlue

LightGlue: Local Feature Matching at Light Speed (ICCV 2023)
Apache License 2.0
3.15k stars 291 forks source link

Getting features in a video from a stereo camera #80

Closed abhishekmonogram closed 5 months ago

abhishekmonogram commented 8 months ago

Hello, I have been trying to use LightGlue to extract features from two consecutive frames obtained from a stereo camera (real-time camera stream). However, every time I try to extract features between two frames, I get `RuntimeError: step must be nonzero`. Can you please help me resolve this issue?

abhishekmonogram commented 8 months ago

As a follow-up to the above: I have observed that the number of features extracted between consecutive frames is too low. I am attaching an image for your reference. I think there should be more matches in the scene. Can someone tell me why this is the case? test

abhishekmonogram commented 8 months ago

@sarlinpe Can you please have a look at this issue and help me out?

sarlinpe commented 8 months ago

I get RuntimeError: step must be nonzero

I can't help without the full logs, reproduction code, and input data.

abhishekmonogram commented 8 months ago

I get RuntimeError: step must be nonzero

I can't help without the full logs, reproduction code, and input data.

Here is the code I used.

import pyzed.sl as sl
import cv2
import numpy as np

import sys
import viewer as gl
import pyzed.sl as sl
import argparse

from polygon_draw import PolygonDrawer

from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image, rbd
from lightglue import viz2d
import torch
from collections import deque
import matplotlib.pyplot as plt

torch.set_grad_enabled(False)

def parse_args(init):
    """Apply the command-line options stored in the global ``opt`` to ``init``.

    First selects the input source (SVO replay file, network stream, or --
    when neither applies -- whatever ``init`` already defaults to), then the
    camera resolution. Each decision is reported on stdout. ``init`` is
    modified in place and nothing is returned.
    """
    # ---- input source -----------------------------------------------------
    if len(opt.input_svo_file) > 0 and opt.input_svo_file.endswith(".svo"):
        svo_path = opt.input_svo_file
        init.set_from_svo_file(svo_path)
        print("[Sample] Using SVO File input: {0}".format(svo_path))
    elif len(opt.ip_address) > 0:
        ip_str = opt.ip_address
        dot_parts = ip_str.split('.')
        colon_parts = ip_str.split(':')
        # Only digits may remain once the '.' and ':' separators are removed.
        numeric_only = ip_str.replace(':', '').replace('.', '').isdigit()
        dotted_quad = numeric_only and len(dot_parts) == 4

        if dotted_quad and len(colon_parts) == 2:
            # "a.b.c.d:port"
            init.set_from_stream(colon_parts[0], int(colon_parts[1]))
            print("[Sample] Using Stream input, IP : ", ip_str)
        elif dotted_quad:
            # "a.b.c.d" without an explicit port
            init.set_from_stream(ip_str)
            print("[Sample] Using Stream input, IP : ", ip_str)
        else:
            print("Unvalid IP format. Using live stream")

    # ---- resolution -------------------------------------------------------
    # Substring matching, first hit wins. "SVGA" has to be tried before "VGA"
    # because the latter is contained in the former.
    for label in ("HD2K", "HD1200", "HD1080", "HD720", "SVGA", "VGA"):
        if label in opt.resolution:
            init.camera_resolution = getattr(sl.RESOLUTION, label)
            print("[Sample] Using Camera in resolution " + label)
            return

    if len(opt.resolution) > 0:
        print("[Sample] No valid resolution entered. Using default")
    else:
        print("[Sample] Using default resolution")

def _bgra_to_rgb_tensor(image_bgra):
    """Turn an (H, W, 4) uint8 BGRA array into a (3, H, W) float RGB tensor in [0, 1].

    LightGlue's extractors are fed images in this layout by the library's own
    ``load_image`` helper; passing raw 0-255 values degrades the detections.
    """
    # Indices 2, 1, 0 of the last axis are R, G, B; the alpha channel is dropped.
    # The reversed slice is non-contiguous, so this also copies the pixels out
    # of the buffer returned by the camera before the next grab reuses it.
    rgb = np.ascontiguousarray(image_bgra[..., 2::-1])
    return torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0


def main():
    """Stream ZED frames, match consecutive left images with LightGlue, save plots.

    Opens the camera configured through ``parse_args`` (which reads the global
    ``opt``), shows the point cloud in the OpenGL viewer and, for every pair of
    consecutive left frames, runs SuperPoint + LightGlue and writes a match
    visualisation to ``frames/test_<n>.png``. Pressing 's' in the viewer saves
    the current point cloud to ``Pointcloud.ply``. The viewer and the camera
    are released when the loop ends, including when it ends with an exception.
    """
    import os  # local import: only used to create the output directory

    print("Running Depth Sensing sample ... Press 'Esc' to quit\nPress 's' to save the point cloud")

    init = sl.InitParameters(depth_mode=sl.DEPTH_MODE.ULTRA,
                             coordinate_units=sl.UNIT.METER,
                             coordinate_system=sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP)
    parse_args(init)
    zed = sl.Camera()
    status = zed.open(init)
    if status != sl.ERROR_CODE.SUCCESS:
        print(repr(status))
        exit()

    camera_model = zed.get_camera_information().camera_model
    res = zed.get_camera_information().camera_configuration.resolution

    # Create OpenGL viewer
    viewer = gl.GLViewer()
    viewer.init(1, sys.argv, camera_model, res)

    point_cloud = sl.Mat(res.width, res.height, sl.MAT_TYPE.F32_C4, sl.MEM.CPU)
    image_zed = sl.Mat(res.width, res.height, sl.MAT_TYPE.U8_C4)

    # Load the models once, up front, instead of lazily inside the frame loop.
    extractor = SuperPoint(max_num_keypoints=4096).eval().to(opt.device)
    # NOTE(review): filter_threshold=0.9 keeps only very confident matches and is
    # much stricter than the library default; lower it if too few matches survive.
    matcher = LightGlue(features="superpoint", depth_confidence=-1, width_confidence=-1,
                        filter_threshold=0.9).eval().to(opt.device)

    # Sliding window over the last S frames. Each entry is (image, features) so
    # that a frame's features are extracted once and reused for the next pair.
    S = 2
    history = deque(maxlen=S)

    os.makedirs("frames", exist_ok=True)  # save_plot does not create the directory

    new_frame_counter = 0  # number of match plots written so far
    curr_frame_count = 0   # number of frames successfully grabbed so far

    try:
        # The original loop condition also tested an undefined `viewer_rgb`,
        # which raised NameError before the first frame was grabbed.
        while viewer.is_available():
            if zed.grab() != sl.ERROR_CODE.SUCCESS:
                continue
            curr_frame_count += 1

            zed.retrieve_image(image_zed, sl.VIEW.LEFT)
            # get_data() exposes the image as a numpy array (4 channels, see image_zed)
            frame = _bgra_to_rgb_tensor(image_zed.get_data())
            feats = extractor.extract(frame.to(opt.device))
            history.append((frame, feats))

            zed.retrieve_measure(point_cloud, sl.MEASURE.XYZRGBA, sl.MEM.CPU, res)
            viewer.updateData(point_cloud)

            if len(history) == S:
                (frame0, batched0), (frame1, batched1) = history[0], history[1]

                matches01 = matcher({"image0": batched0, "image1": batched1})

                # remove batch dimension (the batched dicts stay in `history`)
                feats0, feats1, matches01 = [
                    rbd(x) for x in [batched0, batched1, matches01]
                ]

                kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
                m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

                # Frames are already RGB floats in [0, 1]; only the axis order changes.
                viz2d.plot_images([frame0.permute(1, 2, 0).numpy(),
                                   frame1.permute(1, 2, 0).numpy()])
                viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
                viz2d.add_text(0, f'Stop after {matches01["stop"]} layers')
                viz2d.save_plot(f"frames/test_{new_frame_counter}.png")
                # A new figure is created per pair; close it or memory grows without bound.
                plt.close("all")

                new_frame_counter += 1
                print(f'{curr_frame_count=}')
                print(f'{new_frame_counter=}')
                print(f'{matches.shape=}')

            if viewer.save_data == True:
                point_cloud_to_save = sl.Mat()
                zed.retrieve_measure(point_cloud_to_save, sl.MEASURE.XYZRGBA, sl.MEM.CPU)
                err = point_cloud_to_save.write('Pointcloud.ply')
                if err == sl.ERROR_CODE.SUCCESS:
                    print("Current .ply file saving succeed")
                else:
                    print("Current .ply file failed")
                viewer.save_data = False
    finally:
        # Release the viewer and the camera even if matching or plotting raised.
        viewer.exit()
        zed.close()

if __name__ == "__main__":
    # Script entry point: parse the command line into the module-level `opt`
    # (read by parse_args() and main()), validate it, then start streaming.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_svo_file', type=str, help='Path to an .svo file, if you want to replay it',default = '')
    parser.add_argument('--ip_address', type=str, help='IP Adress, in format a.b.c.d:port or a.b.c.d, if you have a streaming setup', default = '')
    parser.add_argument('--resolution', type=str, help='Resolution, can be either HD2K, HD1200, HD1080, HD720, SVGA or VGA', default = '')
    parser.add_argument('--device', type=str, help='GPU(cuda) or CPU(cpu)', default = 'cuda')
    opt = parser.parse_args()
    # An SVO replay file and a network stream are mutually exclusive inputs.
    # (The original read a garbled `opt.iogl_viewer.p_address` here, which raised
    # AttributeError whenever --input_svo_file was supplied.)
    if len(opt.input_svo_file)>0 and len(opt.ip_address)>0:
        print("Specify only input_svo_file or ip_address, or none to use wired camera, not both. Exit program")
        exit()
    main()

@sarlinpe This takes data directly from the ZED stereo camera and tries to get matches between consecutive frames. The complete set of changes to the repo can be found here

Error log :

Traceback (most recent call last):
  File "/home/abhishek/LightGlue/zed_live_tracking.py", line 201, in <module>
    main()
  File "/home/abhishek/LightGlue/zed_live_tracking.py", line 141, in main
    matches01 = matcher({"image0": feats0, "image1": feats1})
  File "/home/abhishek/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/abhishek/LightGlue/lightglue/lightglue.py", line 441, in forward
    return self._forward(data)
  File "/home/abhishek/LightGlue/lightglue/lightglue.py", line 492, in _forward
    desc0, desc1 = self.transformers[i](
  File "/home/abhishek/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/abhishek/LightGlue/lightglue/lightglue.py", line 242, in forward
    desc0 = self.self_attn(desc0, encoding0)
  File "/home/abhishek/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/abhishek/LightGlue/lightglue/lightglue.py", line 161, in forward
    context = self.inner_attn(q, k, v, mask=mask)
  File "/home/abhishek/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/abhishek/LightGlue/lightglue/lightglue.py", line 111, in forward
    v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
RuntimeError: step must be nonzero

sarlinpe commented 7 months ago

It should be fixed by https://github.com/cvg/LightGlue/pull/92.