fabio-sim / LightGlue-ONNX

ONNX-compatible LightGlue: Local Feature Matching at Light Speed. Supports TensorRT, OpenVINO
Apache License 2.0

The possibility of supporting batch input #80

Open noahzn opened 2 weeks ago

noahzn commented 2 weeks ago

Hi @fabio-sim, the repo currently only supports batch size = 1. Do you think it would be possible, when not enough keypoints are extracted, to pad with a random array so that all images have the same number of keypoints? For example, if the input is 2xNx2 and image 1 has N1 = 128 keypoints while image 2 has N2 = 125, could we stack three random rows as fake points so that we can run in batch mode?
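
For illustration, a minimal sketch of this padding idea (the pad_keypoints helper and shapes below are hypothetical, not part of the repo): random rows are appended to the smaller keypoint set so both sets can be stacked into one batch.

import torch

def pad_keypoints(kpts: torch.Tensor, target_n: int) -> torch.Tensor:
    # kpts: (N_i, 2); pad with random "fake" points up to (target_n, 2)
    missing = target_n - kpts.shape[0]
    if missing <= 0:
        return kpts[:target_n]
    filler = torch.rand(missing, 2)
    return torch.cat([kpts, filler], dim=0)

kpts0 = torch.rand(128, 2)  # N1 = 128
kpts1 = torch.rand(125, 2)  # N2 = 125
n = max(kpts0.shape[0], kpts1.shape[0])
batch = torch.stack([pad_keypoints(kpts0, n), pad_keypoints(kpts1, n)])  # (2, 128, 2)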

fabio-sim commented 2 weeks ago

Hello @noahzn, thanks for your interest again. I'll see what I can do.

noahzn commented 1 week ago

Thank you! I will be waiting for your thoughts.

fabio-sim commented 6 days ago

I've added batch input support in 9ebf215. Rather than padding with a random array, I've decided to go with another design choice instead; details here: https://fabio-sim.github.io/blog/accelerating-lightglue-inference-onnx-runtime-tensorrt/

noahzn commented 6 days ago

That's really amazing! I will take a careful look and give you feedback. Thanks a lot!

noahzn commented 5 days ago

Hi @fabio-sim, I noticed that you also modified this file, but you don't use it when exporting models. Can I use it if I want to export non-end2end models with batch input? My two image batches have different numbers of keypoints. For example, the keypoints of image 1 are always (B x 100 x 2), and image 2's are always (B x 200 x 2).

fabio-sim commented 5 days ago

Hi, that file is from the original implementation, so it's unrelated to the export.

For your use case, I recommend passing the left and right batches separately, like this (note: untested):

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ..config import Extractor
from ..ops import multi_head_attention_dispatch

torch.backends.cudnn.deterministic = True

class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, descriptor_dim: int, num_heads: int, gamma: float = 1.0) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = descriptor_dim // num_heads
        self.Wr = nn.Linear(M, head_dim // 2, bias=False)
        self.gamma = gamma
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines])
        return emb.repeat_interleave(2, dim=3).repeat(1, 1, 1, self.num_heads).unsqueeze(4)

class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )

class SelfBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor, encoding: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        qkv: torch.Tensor = self.Wqkv(x)
        qkv = qkv.reshape((b, n, self.embed_dim, 3))
        qk, v = qkv[..., :2], qkv[..., 2]
        qk = self.apply_cached_rotary_emb(encoding, qk)
        q, k = qk[..., 0], qk[..., 1]
        context = multi_head_attention_dispatch(q, k, v, self.num_heads)
        message = self.out_proj(context)
        return x + self.ffn(torch.concat([x, message], 2))

    def rotate_half(self, qk: torch.Tensor) -> torch.Tensor:
        b, n, _, _ = qk.shape
        qk = qk.reshape((b, n, self.num_heads, self.head_dim // 2, 2, 2))
        qk = torch.stack((-qk[..., 1, :], qk[..., 0, :]), dim=4)
        qk = qk.reshape((b, n, self.embed_dim, 2))
        return qk

    def apply_cached_rotary_emb(self, encoding: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        return qk * encoding[0] + self.rotate_half(qk) * encoding[1]

class CrossBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.to_qk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_out = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        b, _, _ = descriptors0.shape
        qk0, v0 = self.to_qk(descriptors0), self.to_v(descriptors0)
        qk1, v1 = self.to_qk(descriptors1), self.to_v(descriptors1)

        m0 = multi_head_attention_dispatch(qk0, qk1, v1, self.num_heads)
        m0 = self.to_out(m0)
        descriptors0 = descriptors0 + self.ffn(torch.concat([descriptors0, m0], 2))

        m1 = multi_head_attention_dispatch(qk1, qk0, v0, self.num_heads)
        m1 = self.to_out(m1)
        descriptors1 = descriptors1 + self.ffn(torch.concat([descriptors1, m1], 2))
        return descriptors0, descriptors1

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = SelfBlock(embed_dim, num_heads)
        self.cross_attn = CrossBlock(embed_dim, num_heads)

    def forward(
        self, descriptors0: torch.Tensor, descriptors1: torch.Tensor, encodings0: torch.Tensor, encodings1: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        descriptors0 = self.self_attn(descriptors0, encodings0)
        descriptors1 = self.self_attn(descriptors1, encodings1)
        return self.cross_attn(descriptors0, descriptors1)

def sigmoid_log_double_softmax(similarities: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(similarities, 2)
    scores1 = F.log_softmax(similarities, 1)
    scores = scores0 + scores1 + certainties
    return scores

class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.scale = dim**0.25
        self.final_proj = nn.Linear(dim, dim, bias=True)
        self.matchability = nn.Linear(dim, 1, bias=True)

    def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> torch.Tensor:
        """build assignment matrix from descriptors"""
        mdescriptors0 = self.final_proj(descriptors0) / self.scale
        mdescriptors1 = self.final_proj(descriptors1) / self.scale
        similarities = mdescriptors0 @ mdescriptors1.transpose(1, 2)
        z0 = self.matchability(descriptors0)
        z1 = self.matchability(descriptors1)
        scores = sigmoid_log_double_softmax(similarities, z0, z1)
        return scores

    def get_matchability(self, desc: torch.Tensor):
        return torch.sigmoid(self.matchability(desc)).squeeze(-1)

def filter_matches(scores: torch.Tensor, threshold: float):
    """obtain matches from a log assignment matrix [BxNxN]"""
    max0 = torch.topk(scores, k=1, dim=2, sorted=False)  # scores.max(2)
    max1 = torch.topk(scores, k=1, dim=1, sorted=False)  # scores.max(1)
    m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]

    indices = torch.arange(m0.shape[1], device=m0.device).expand_as(m0)
    mutual = indices == m1.gather(1, m0)
    mscores = max0.values[:, :, 0].exp()
    valid = mscores > threshold

    b_idx, m0_idx = torch.where(valid & mutual)
    m1_idx = m0[b_idx, m0_idx]
    matches = torch.concat([b_idx[:, None], m0_idx[:, None], m1_idx[:, None]], 1)
    mscores = mscores[b_idx, m0_idx]
    return matches, mscores

class LightGlue(nn.Module):
    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    def __init__(
        self,
        extractor: Extractor,
        descriptor_dim: int = 256,
        num_heads: int = 4,
        n_layers: int = 9,
        filter_threshold: float = 0.1,  # match threshold
        depth_confidence: float = -1,  # -1 is no early stopping, recommend: 0.95
        width_confidence: float = -1,  # -1 is no point pruning, recommend: 0.99
    ) -> None:
        super().__init__()

        self.descriptor_dim = descriptor_dim
        self.num_heads = num_heads
        self.n_layers = n_layers
        self.filter_threshold = filter_threshold
        self.depth_confidence = depth_confidence
        self.width_confidence = width_confidence

        if extractor.dim != self.descriptor_dim:
            self.input_proj = nn.Linear(extractor.dim, self.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        self.posenc = LearnableFourierPositionalEncoding(2, self.descriptor_dim, self.num_heads)

        d, h, n = self.descriptor_dim, self.num_heads, self.n_layers

        self.transformers = nn.ModuleList([TransformerLayer(d, h) for _ in range(n)])

        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])

        self.token_confidence = nn.ModuleList([TokenConfidence(d) for _ in range(n - 1)])
        self.register_buffer(
            "confidence_thresholds",
            torch.Tensor([self.confidence_threshold(i) for i in range(n)]),
        )

        state_dict = torch.hub.load_state_dict_from_url(self.url.format(self.version, extractor.value))

        # rename old state dict entries
        for i in range(n):
            pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=False)

    def forward(
        self,
        keypoints0: torch.Tensor,  # (2B, N, 2), normalized
        keypoints1: torch.Tensor,
        descriptors0: torch.Tensor,  # (2B, N, D)
        descriptors1: torch.Tensor,
    ):
        descriptors0 = self.input_proj(descriptors0)
        descriptors1 = self.input_proj(descriptors1)

        # positional embeddings
        encodings0 = self.posenc(keypoints0)  # (2, 2B, *, 64, 1)
        encodings1 = self.posenc(keypoints1)

        # GNN + final_proj + assignment
        for i in range(self.n_layers):
            # self+cross attention
            descriptors0, descriptors1 = self.transformers[i](descriptors0, descriptors1, encodings0, encodings1)

        scores = self.log_assignment[-1](descriptors0, descriptors1)  # (B, N, N)
        matches, mscores = filter_matches(scores, self.filter_threshold)
        return matches, mscores  # (M, 3), (M,)

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold"""
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
        return np.clip(threshold, 0, 1)

    def get_pruning_mask(
        self,
        confidences: torch.Tensor | None,
        scores: torch.Tensor,
        layer_index: int,
    ) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        confidences = torch.cat([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.depth_confidence

and then adjusting the Pipeline class to orchestrate SuperPoint(100) and SuperPoint(200) accordingly.
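
As a rough usage sketch (untested and purely illustrative: the extractor100/extractor200 names and the Pipeline wiring are assumptions, not the repo's actual API), the idea is to run one extractor per side and feed both batches, with keypoints normalized, into the matcher above. Shapes follow the (B x 100 x 2) / (B x 200 x 2) use case described earlier.

keypoints0, descriptors0 = extractor100(images_left)   # (B, 100, 2), (B, 100, D)
keypoints1, descriptors1 = extractor200(images_right)  # (B, 200, 2), (B, 200, D)
matches, mscores = lightglue(keypoints0, keypoints1, descriptors0, descriptors1)  # (M, 3), (M,)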

noahzn commented 3 days ago

Hi @fabio-sim, thank you very much! I'm now working on the code, but I ran into an error:

fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
AttributeError: module 'torch.library' has no attribute 'custom_op'

My torch version is >= 2.1.

fabio-sim commented 3 days ago

Oh, apologies, my mistake. torch.library.custom_op needs torch >= 2.4; I should've added a version check. I think it's fine if you comment it out.
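
For example, a minimal sketch of such a guard (assuming the CUSTOM_OP_NAME and multi_head_attention names from the traceback above) could fall back to the plain function on older torch versions:

if hasattr(torch.library, "custom_op"):  # torch >= 2.4
    fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
else:
    # older torch: use the un-fused Python implementation directly
    fused_multi_head_attention = multi_head_attention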

noahzn commented 1 day ago

@fabio-sim Thank you for your comments.

orig_image0 = cv2.imread(img0_path, cv2.IMREAD_COLOR)
orig_image1 = cv2.imread(img1_path, cv2.IMREAD_COLOR)
viz2d.plot_images(
    [orig_image0, orig_image1]
)

assert np.all(kpts0[2][matches[..., 1]] == kpts0[0][matches[..., 1]])
assert np.all(kpts1[2][matches[..., 2]] == kpts1[0][matches[..., 2]])
viz2d.plot_matches(kpts0[0][matches[..., 1]], kpts1[0][matches[...,2]], color="lime", lw=0.2)

viz2d.save_plot('aaa1.jpg', dpi=300)
viz2d.plt.show()
viz2d.plot_matches(kpts0[2][matches[..., 1]], kpts1[2][matches[..., 2]], color="lime", lw=0.2)
viz2d.save_plot('aaa2.jpg', dpi=300)
viz2d.plt.show()

I used the code above to visualize. I used batch size = 4; the first and the third image pairs are identical, and for the other two pairs I used random arrays. The asserts check that the outputs for the first and the third pairs are the same. However, when visualizing the results, several matches keep changing and are incorrect. Do you know the reason?

(Attached screenshots: myplot1, myplot2)

Update: The problem has been solved. I wasn't parsing the returned matches correctly. Now it works. Thanks a million for your help! I'm closing this ticket now.
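
For reference, filter_matches above returns matches as rows of (batch_index, keypoint_index_0, keypoint_index_1), so selecting only the matches that belong to one pair looks roughly like this (a minimal sketch; variable names follow the snippet above):

pair_idx = 0
sel = matches[..., 0] == pair_idx             # column 0 is the batch index
m_kpts0 = kpts0[pair_idx][matches[sel, 1]]    # matched keypoints in the left image of this pair
m_kpts1 = kpts1[pair_idx][matches[sel, 2]]    # matched keypoints in the right image of this pair
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)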

noahzn commented 21 hours ago

Hi, sorry, I still have a problem.

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    head_dim = d // num_heads
    q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))

I found that when the image pairs have different numbers of keypoints, multi_head_attention throws an error. For example, for the left images the shape is (2, 99, 64), and for the right images it is (2, 256, 64). Here 256 is the maximum number of keypoints I set, but only 99 keypoints are extracted on the left images. The multi_head_attention function then throws:

q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
RuntimeError: shape '[2, 99, 4, 16]' is invalid for input of size 32768

This is because for the right images the shape is [2, 256, 4, 16]: 2 x 256 x 4 x 16 = 32768. (Please note that my keypoint descriptors are 64-D instead of 256-D; it's a customized network.)

I tried to modify the function as follows; the error was gone, but the matching result is bad.

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    nk = k.shape[1]
    head_dim = d // num_heads
    q = q.reshape(b, n, num_heads, head_dim).transpose(1, 2)
    k = k.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
    v = v.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))

Could you help me with that? Thank you in advance!

Update: I have used the old implementation for CrossBlock and it works with different numbers of keypoints.

fabio-sim commented 5 hours ago

Ah yes, if you have a different number of keypoints per side, the sequence length of Q differs from that of K and V.

Use something like this instead:

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    _, n1, _ = k.shape
    head_dim = d // num_heads
    q = q.reshape((b, n, num_heads, head_dim)).transpose(1, 2)
    k, v = (t.reshape((b, n1, num_heads, head_dim)).transpose(1, 2) for t in (k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))
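
A quick sanity-check sketch of this cross-attention path under the shapes reported above (4 heads, 64-D descriptors), assuming the corrected multi_head_attention just defined is in scope:

import torch
import torch.nn.functional as F

q = torch.randn(2, 99, 64)    # left batch: 99 keypoints
k = torch.randn(2, 256, 64)   # right batch: 256 keypoints (the configured maximum)
v = torch.randn(2, 256, 64)
out = multi_head_attention(q, k, v, num_heads=4)
print(out.shape)  # torch.Size([2, 99, 64])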