zju3dv / EfficientLoFTR


inference time not as fast as expected #7

Closed: fabricecarles closed this issue 2 months ago

fabricecarles commented 3 months ago

Hi, thanks for open-sourcing the code and model weights. As I said in a previous post, I would like to use EfficientLoFTR in a comparative benchmark for our study.

I found strange results in my benchmark. On my GeForce RTX 2070 Mobile GPU, with an input size of 1x1x256x256, Efficient LoFTR inference takes about 26 ms. This is better than LoFTR, which runs at ~40 ms at this resolution, but very close to topicFMfast.

Since topicFMfast is not in your benchmark, I would like to know whether I made a mistake when using your code.

Here is my inference code:

import time
import cv2
import numpy as np
import pytorch_lightning as pl
import argparse
import pprint
import torch
import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
from kornia_moons.viz import draw_LAF_matches

from loguru import logger as loguru_logger

from src.config.default import get_cfg_defaults
from src.utils.profiler import build_profiler

from src.lightning.data import MultiSceneDataModule
from src.lightning.lightning_loftr import PL_LoFTR

def parse_args():
    # init a custom parser which will be added into the pl.Trainer parser
    # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--data_cfg_path', type=str, default="configs/data/megadepth_test_1500.py", help='data config path')
    parser.add_argument(
        '--main_cfg_path', type=str, default="configs/loftr/eloftr_optimized.py", help='main config path')
    parser.add_argument(
        '--ckpt_path', type=str, default="weights/eloftr_outdoor.ckpt", help='path to the checkpoint')
    parser.add_argument(
        '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
    parser.add_argument(
        '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
    parser.add_argument(
        '--batch_size', type=int, default=1, help='batch_size per gpu')
    parser.add_argument(
        '--num_workers', type=int, default=2)
    parser.add_argument(
        '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
    parser.add_argument(
        '--pixel_thr', type=float, default=None, help='modify the RANSAC threshold.')
    parser.add_argument(
        '--ransac', type=str, default=None, help='modify the RANSAC method')
    parser.add_argument(
        '--scannetX', type=int, default=832, help='ScanNet resize X')
    parser.add_argument(
        '--scannetY', type=int, default=832, help='ScanNet resize Y')
    parser.add_argument(
        '--megasize', type=int, default=1152, help='MegaDepth resize')
    parser.add_argument(
        '--npe', action='store_true', default=False, help='')
    parser.add_argument(
        '--fp32', action='store_true', default=False, help='')
    parser.add_argument(
        '--ransac_times', type=int, default=None, help='repeat ransac multiple times for more robust evaluation')
    parser.add_argument(
        '--rmbd', type=int, default=None, help='remove border matches')
    parser.add_argument(
        '--deter', action='store_true', default=False, help='use deterministic mode for testing')

    parser = pl.Trainer.add_argparse_args(parser)
    return parser.parse_args()

def inplace_relu(m):
    classname = m.__class__.__name__
    if classname.find('ReLU') != -1:
        m.inplace=True

if __name__ == '__main__':
    # parse arguments
    args = parse_args()
    pprint.pprint(vars(args))

    # init default-cfg and merge it with the main- and data-cfg        
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    if args.deter:
        torch.backends.cudnn.deterministic = True
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility

    # tune when testing
    if args.thr is not None:
        config.LOFTR.MATCH_COARSE.THR = args.thr

    if args.scannetX is not None and args.scannetY is not None:
        config.DATASET.SCAN_IMG_RESIZEX = args.scannetX
        config.DATASET.SCAN_IMG_RESIZEY = args.scannetY
    if args.megasize is not None:
        config.DATASET.MGDPT_IMG_RESIZE = args.megasize

    if args.npe:
        if config.LOFTR.COARSE.ROPE:
            assert config.DATASET.NPE_NAME is not None
        if config.DATASET.NPE_NAME is not None:
            if config.DATASET.NPE_NAME == 'megadepth':
                config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.MGDPT_IMG_RESIZE, config.DATASET.MGDPT_IMG_RESIZE] # [832, 832, 1152, 1152]
            elif config.DATASET.NPE_NAME == 'scannet':
                config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.SCAN_IMG_RESIZEX, config.DATASET.SCAN_IMG_RESIZEX] # [832, 832, 640, 640]
    else:
        config.LOFTR.COARSE.NPE = [832, 832, 832, 832]

    if args.ransac_times is not None:
        config.LOFTR.EVAL_TIMES = args.ransac_times

    if args.rmbd is not None:
        config.LOFTR.MATCH_COARSE.BORDER_RM = args.rmbd

    if args.pixel_thr is not None:
        config.TRAINER.RANSAC_PIXEL_THR = args.pixel_thr

    if args.ransac is not None:
        config.TRAINER.POSE_ESTIMATION_METHOD = args.ransac
        if args.ransac == 'LO-RANSAC' and config.TRAINER.RANSAC_PIXEL_THR == 0.5:
            config.TRAINER.RANSAC_PIXEL_THR = 2.0

    if args.fp32:
        config.LOFTR.FP16 = False

    loguru_logger.info(f"Args and config initialized!")

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
    loguru_logger.info(f"LoFTR-lightning initialized!")
    model.matcher = model.matcher.eval().cuda()
    # model.matcher = torch.compile(model.matcher)

    print('start inference')
    # Load example images
    img0_pth = "assets/01.BMP"
    img1_pth = "assets/02.BMP"
    img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)
    img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)
    size = 256
    img0_raw = cv2.resize(img0_raw, (size, size))  # input size should be divisible by 8
    img1_raw = cv2.resize(img1_raw, (size, size))
    img0 = torch.from_numpy(img0_raw)[None][None].cuda() / 255.
    img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255.
    data_dict = {'image0': img0, 'image1': img1, 'pair_names': ('01', '02'), 'dataset_name' : 'scan4all'}
    print('image 0 size', img0.shape)
    print('image 1 size', img1.shape)
    # inference (with warmup)
    num_inferences = 105
    times = np.zeros(num_inferences)
    with torch.no_grad():
        with torch.autocast(enabled=config.LOFTR.FP16, device_type='cuda', dtype=torch.float16):
            for i in range(num_inferences):
                torch.cuda.current_stream().synchronize()
                t0 = time.time()
                model.matcher(data_dict)
                torch.cuda.current_stream().synchronize()
                t1 = time.time()
                current_time = (t1 - t0) *1000
                print(f"inference pytorch {current_time :.1f} [ms]")
                times[i] = current_time
    print('times ', times)
    print(f"average inference time = {times[5:].mean() :.1f} [ms] std {times[5:].std() :.1f} for {num_inferences - 5} samples")
    print('data_dict.keys()', data_dict.keys())     
    print('mconf', data_dict['mconf'].shape)
    print('data_dict', data_dict['mkpts0_f'].shape)
    print('data_dict', data_dict['mkpts1_f'].shape)
    # print('mconf', data_dict['mconf'])
    mkpts0 = data_dict['mkpts0_f']
    mkpts1 = data_dict['mkpts1_f']
    mconf = data_dict['mconf']
    mkpts0 = mkpts0.cpu().numpy()
    mkpts1 = mkpts1.cpu().numpy()
    # inliers filtering
    mconf = mconf.unsqueeze(1)
    mconf = mconf.cpu().numpy()
    mconf = mconf > 0.2
    print("mconf", mconf.shape)
    # plot matchs

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    draw_LAF_matches(
        KF.laf_from_center_scale_ori(
            torch.from_numpy(mkpts0).view(1, -1, 2),
            torch.ones(mkpts0.shape[0]).view(1, -1, 1, 1),
            torch.ones(mkpts0.shape[0]).view(1, -1, 1),
        ),
        KF.laf_from_center_scale_ori(
            torch.from_numpy(mkpts1).view(1, -1, 2),
            torch.ones(mkpts1.shape[0]).view(1, -1, 1, 1),
            torch.ones(mkpts1.shape[0]).view(1, -1, 1),
        ),
        torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
        K.tensor_to_image(img0),
        K.tensor_to_image(img1),
        mconf,
        draw_dict={"inlier_color": (0.2, 1, 0.2), "tentative_color": None, "feature_color": (0.2, 0.5, 1), "vertical": False},
        ax=ax
    )
    plt.savefig(f"assets/output_filtered_by_confidence_size{size}_num-match{len(mconf)}_{(t1 - t0) *1000 :0.1f}_ms.png")
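A note for readers reproducing these numbers: GPU kernels are launched asynchronously, so wall-clock timing is only meaningful if the device is synchronized before and after the call, and the first few iterations (lazy initialization, allocator warm-up, any cuDNN autotuning) should be excluded, as the loop above does. The following is a stdlib-only sketch of that pattern as a reusable helper; `benchmark` and its parameters are hypothetical and not part of EfficientLoFTR.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object],
              n_warmup: int = 5,
              n_iters: int = 100,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Time fn() and return mean/std in milliseconds, excluding warm-up runs.

    `sync` should block until all queued device work has finished (for
    example torch.cuda.synchronize); without it, asynchronous GPU work
    would only be timed up to kernel launch.
    """
    sync = sync or (lambda: None)
    for _ in range(n_warmup):  # warm-up runs are executed but not recorded
        fn()
    sync()
    times_ms = []
    for _ in range(n_iters):
        sync()  # make sure no earlier work leaks into this measurement
        t0 = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        sync()  # wait for the work launched by fn() to complete
        times_ms.append((time.perf_counter() - t0) * 1000.0)
    return {
        "mean_ms": statistics.fmean(times_ms),
        "std_ms": statistics.pstdev(times_ms),  # population std, like numpy's default
        "n": len(times_ms),
    }
```

With PyTorch, a call such as `benchmark(lambda: model.matcher(data_dict), sync=torch.cuda.synchronize)` would mirror the script above. `time.perf_counter` is used instead of `time.time` because it is monotonic and has higher resolution.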

Here is my environment setup:

conda env create -f environment.yaml
conda activate eloftr
pip install torch==2.0.0+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt 
pip install kornia_moons
python inference.py

Did I miss something that would make your code run faster?

Best regards

wyf2020 commented 3 months ago

Thank you for sharing your results. Here are some suggestions at first glance.

  1. Please apply reparameterization before inference, as done in this line: https://github.com/zju3dv/EfficientLoFTR/blob/9cb9ca002a264a92fbf8a39ffb4f09dcfadaeda1/src/lightning/lightning_loftr.py#L223. It has a significant impact on inference speed.
  2. We only compare with TopicFM, because TopicFM+ uses a much higher number of OpenCV RANSAC iterations in its code (10k vs. the standard 1k used by the other baselines), which greatly improves AUC but also substantially slows down RANSAC. Evaluating inference speed without considering accuracy isn't meaningful.

    MegaDepth pose estimation, AUC@5° / 10° / 20°
    (10k) = 10k OpenCV RANSAC iterations instead of 1k

    Method            AUC@5   AUC@10   AUC@20
    LoFTR             52.8    69.2     81.2
    TopicFM           54.1    70.1     81.6
    TopicFM+          52.2    68.8     81.1
    Ours              56.4    72.2     83.5
    Ours (Opt.)       55.4    71.4     82.9
    TopicFM+ (10k)    58.2    72.8     83.2
    Ours (10k)        59.3    74.1     84.6
  3. If you run our timing scripts on an RTX 3090, you will get exactly the timings reported in our paper: 34 ms (Full) and 27 ms (Opt.).
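For context on the AUC numbers in the table: in SuperGlue/LoFTR-style evaluation, each image pair gets a pose error (commonly the maximum of the angular rotation and translation errors, in degrees), and AUC@t is the area under the cumulative recall-vs-error curve up to t, divided by t. The sketch below is a stdlib-only version modeled on the `pose_auc` helper from that line of evaluation code; it is an illustration under that assumption, not EfficientLoFTR's own implementation.

```python
from bisect import bisect_left
from typing import List, Sequence


def pose_auc(errors: Sequence[float], thresholds: Sequence[float]) -> List[float]:
    """Normalized area under the recall-vs-error curve up to each threshold."""
    errs = sorted(errors)
    n = len(errs)
    # Cumulative recall curve: point i says "fraction ys[i] of pairs have error <= xs[i]",
    # starting from the origin (0, 0).
    xs = [0.0] + errs
    ys = [0.0] + [(i + 1) / n for i in range(n)]
    aucs = []
    for t in thresholds:
        k = bisect_left(xs, t)  # number of curve points with error < t
        # Keep the curve below t and close it with a flat segment ending at x = t.
        px = xs[:k] + [t]
        py = ys[:k] + [ys[k - 1]]
        # Trapezoidal rule over the piecewise-linear curve.
        area = sum((px[j + 1] - px[j]) * (py[j + 1] + py[j]) / 2 for j in range(len(px) - 1))
        aucs.append(area / t)
    return aucs
```

For example, `pose_auc(errors, [5, 10, 20])` returns the three values reported per row, as fractions rather than percentages. A larger RANSAC iteration budget lowers the pose errors of hard pairs, which is why the (10k) rows score higher.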

We will later provide a Jupyter notebook demo showing how to use our model, so please stay tuned!

fabricecarles commented 3 months ago

Thanks for your advice. With self.matcher = reparameter(self.matcher), inference time improved a little.
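For readers wondering what this call does: assuming `reparameter` performs RepVGG-style structural reparameterization (the EfficientLoFTR backbone is RepVGG-based), each training-time block of parallel branches (3x3 conv + BN, 1x1 conv + BN, identity + BN) is collapsed into one 3x3 convolution with a bias. This is possible because every branch is linear at inference time, and a single convolution runs faster than three branches plus a sum. The NumPy sketch below checks that equivalence on a single-channel toy block. All function names are hypothetical and this is not the repository's code; real blocks are multi-channel, where the identity kernel has a 1 at the center only for matching input/output channels.

```python
import numpy as np


def conv2d_same(x, k, bias=0.0):
    """Single-channel 2-D cross-correlation ("conv" in DL terms) with zero 'same' padding."""
    r = k.shape[0] // 2
    xp = np.pad(x, r)
    out = np.empty(x.shape, dtype=np.float64)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = np.sum(xp[i:i + k.shape[0], j:j + k.shape[1]] * k)
    return out + bias


def bn(z, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm using fixed running statistics."""
    return gamma * (z - mean) / np.sqrt(var + eps) + beta


def fold_bn(k, gamma, beta, mean, var, eps=1e-5):
    """Fold BN into a bias-free conv: bn(conv(x, k)) == conv(x, k * s) + (beta - mean * s)."""
    s = gamma / np.sqrt(var + eps)
    return k * s, beta - mean * s


def pad_to_3x3(k):
    """A 1x1 kernel equals a 3x3 kernel that is zero everywhere except the center."""
    return k if k.shape == (3, 3) else np.pad(k, 1)


def multi_branch(x, k3, k1, bn3, bn1, bn_id):
    """Training-time block: 3x3 conv+BN, 1x1 conv+BN and identity+BN branches, summed."""
    return bn(conv2d_same(x, k3), *bn3) + bn(conv2d_same(x, k1), *bn1) + bn(x, *bn_id)


def reparameterize(k3, k1, bn3, bn1, bn_id):
    """Collapse the three branches into one 3x3 kernel and one bias (deploy-time block)."""
    k_id = np.zeros((3, 3))
    k_id[1, 1] = 1.0  # identity written as a 3x3 convolution
    branches = [fold_bn(k3, *bn3), fold_bn(k1, *bn1), fold_bn(k_id, *bn_id)]
    kernel = sum(pad_to_3x3(k) for k, _ in branches)
    bias = sum(b for _, b in branches)
    return kernel, bias


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 8))
    k3, k1 = rng.standard_normal((3, 3)), rng.standard_normal((1, 1))
    # BN parameters as (gamma, beta, running_mean, running_var)
    bn3, bn1, bn_id = (1.2, 0.1, 0.3, 0.8), (0.9, -0.2, -0.1, 1.5), (1.1, 0.05, 0.2, 0.7)
    kernel, bias = reparameterize(k3, k1, bn3, bn1, bn_id)
    y_train = multi_branch(x, k3, k1, bn3, bn1, bn_id)
    y_deploy = conv2d_same(x, kernel, bias)
    print("max abs difference:", np.abs(y_train - y_deploy).max())
```

The printed difference should be at floating-point round-off level. Any speedup comes only from running fewer, larger kernels, which is consistent with the modest gain reported above on a small 256x256 input.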

fabricecarles commented 3 months ago

Your README mentions plans to "Add options of flash-attention and torch.compiler for better performance". Are any other performance improvements expected?

wyf2020 commented 3 months ago

Sorry for the late reply. Yes, there's also FP16 inference. We have already modified some of the code and added a Jupyter notebook to demonstrate how to use FP16 inference (on modern GPUs) to accelerate our model. This will provide even faster speeds than mixed precision with almost no loss in accuracy.
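To make the FP16 vs. mixed-precision distinction concrete: the script in the original post uses `torch.autocast`, which runs matmul-heavy ops in float16 but keeps numerically sensitive ops (e.g. softmax, exp) in float32. Pure FP16 inference instead casts the weights and inputs themselves to half precision (in PyTorch, `model.half()` and `tensor.half()`), which avoids the per-op casts autocast inserts. The NumPy sketch below needs no GPU and illustrates the trade-off on a hypothetical toy MLP: rounding error stays small for well-scaled activations, while float16's narrow range (max 65504) can overflow. It says nothing about EfficientLoFTR's actual accuracy.

```python
import numpy as np

# Largest finite float16 value; anything larger becomes inf in pure FP16.
FP16_MAX = float(np.finfo(np.float16).max)


def mlp(x, w1, w2, dtype):
    """Toy two-layer ReLU MLP with inputs, weights and activations all cast to `dtype`."""
    x, w1, w2 = (a.astype(dtype) for a in (x, w1, w2))
    h = np.maximum(x @ w1, 0)
    return h @ w2


def relative_error(ref, approx):
    """Frobenius-norm relative error, computed in float64."""
    ref = ref.astype(np.float64)
    diff = approx.astype(np.float64) - ref
    return float(np.linalg.norm(diff) / np.linalg.norm(ref))


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 64))
w1 = rng.standard_normal((64, 64)) / 8.0  # 1/sqrt(fan_in) scaling keeps activations O(1)
w2 = rng.standard_normal((64, 32)) / 8.0
y32 = mlp(x, w1, w2, np.float32)
y16 = mlp(x, w1, w2, np.float16)
print("FP16 vs FP32 relative error:", relative_error(y32, y16))

# The main extra hazard is range: e**12 (about 1.6e5) exceeds FP16_MAX and overflows,
# which is why autocast-style mixed precision keeps ops like softmax in float32.
with np.errstate(over="ignore"):
    print("float16 exp(12):", np.exp(np.float16(12.0)))
```

In a real model, this is why pure-FP16 code paths usually rely on numerically stable formulations (for example, subtracting the row maximum before a softmax) rather than on float32 fallbacks.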