Shiaoming / ALIKED

ALIKED: A Lighter Keypoint and Descriptor Extraction Network via Deformable Transformation
https://arxiv.org/pdf/2304.03608.pdf
BSD 3-Clause "New" or "Revised" License

About reliable loss #11

Open senz1024 opened 1 year ago

senz1024 commented 1 year ago

Hello,

I read your excellent paper and have a question about reliable loss in Section V.D.

According to the definition (Eq. (12)), the reliability seems to be a vector. In order to treat it as a scalar (as in ALIKE), can I interpret it as $r(\mathbf{P_A}, I_B) = \mathbf{P_B} \cdot \text{softmax}(\text{sim}(\mathbf{d_A}, \mathbf{D_B})/t_{rel})$, using the $\mathbf{P_B}$ mentioned in Section V.C?
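
For concreteness, here is a minimal sketch of the scalar interpretation I have in mind (all names, shapes, and the temperature value are my assumptions, not the paper's notation):

```python
import torch

# Hypothetical sketch of the proposed scalar reliability (names are assumptions):
# d_A: (dim,)    descriptor of one keypoint in image A
# D_B: (N, dim)  L2-normalized candidate descriptors in image B
# P_B: (N,)      matching probabilities over image B (Section V.C)
def scalar_reliability(d_A: torch.Tensor, D_B: torch.Tensor,
                       P_B: torch.Tensor, t_rel: float = 0.1) -> torch.Tensor:
    sim = D_B @ d_A                           # cosine similarities, shape (N,)
    soft = torch.softmax(sim / t_rel, dim=0)  # the vector from Eq. (12)
    return P_B @ soft                         # dot product -> scalar
```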

Shiaoming commented 1 year ago

You are welcome to try any idea you want. In our early testing, this strategy seemed to have little impact on the results.

senz1024 commented 1 year ago

Thank you for the quick response and testing.

Do you mean $r(\mathbf{P_A},I_B)$ is originally a vector? If so, how do I calculate Eq.(13) with vector $r(\mathbf{P_A},I_B)$?

I'm sorry if I have misunderstood something.

HatsuneMikuuuu commented 10 months ago

Hi, excuse me, have you worked this out? I'm also confused by it.

PierreCarceller commented 4 months ago

@senz1024 Have you finally figured it out? I'm asking myself exactly the same question as you...

@Shiaoming Do you have any more details to provide?

senz1024 commented 4 months ago

I still don't understand.

Shiaoming commented 4 months ago

> Thank you for the quick response and testing.
>
> Do you mean $r(\mathbf{P_A}, I_B)$ is originally a vector? If so, how do I calculate Eq. (13) with vector $r(\mathbf{P_A}, I_B)$?
>
> I'm sorry if I have misunderstood something.

Sorry for the late reply.

There is something misleading and unclear in Eq. (12). $\text{softmax}(\text{sim}(\mathbf{d_A}, \mathbf{D_B})/t_{rel})$ is indeed a vector, and the value of that vector at the corresponding location (the matching point) in $I_B$ is regarded as the reliability.
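
A minimal sketch of this reading (the variable names, shapes, and temperature are assumed for illustration):

```python
import torch

# Sketch of the clarified definition (names are assumptions):
# sim:       (N,) similarities sim(d_A, D_B) for one keypoint of I_A
# match_idx: index of the ground-truth matching point in I_B
def reliability(sim: torch.Tensor, match_idx: int,
                t_rel: float = 0.1) -> torch.Tensor:
    soft = torch.softmax(sim / t_rel, dim=0)  # Eq. (12): a vector over I_B
    return soft[match_idx]                    # scalar value at the matching point
```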

PierreCarceller commented 4 months ago

Thank you for your reply. I had come to the same conclusion. But it's good to have your confirmation!

@senz1024 below is an implementation of `ReliableLoss` based on my understanding.

```python
import torch


class ReliableLoss(object):
    """Reliability loss: encourages high keypoint scores only where the
    descriptor matches reliably across the two images (Eq. (12)-(13))."""

    def __init__(self, reliability_threshold: float):
        # Softmax temperature t_rel from Eq. (12)
        self.reliability_threshold = reliability_threshold

    def __call__(self, pred0: dict, pred1: dict, correspondences: dict):
        b, c, h, w = pred0['scores_map'].shape

        loss_mean = 0
        count = 0

        for idx in range(b):

            if correspondences[idx]['correspondence0'] is None:
                continue

            correspondences_idx_0 = correspondences[idx]['correspondence0']
            correspondences_idx_1 = correspondences[idx]['correspondence1']

            all_descriptors0 = pred0['descriptors'][idx]
            all_descriptors1 = pred1['descriptors'][idx]

            # Detach: the reliability supervises the score map, not the descriptors.
            descriptors0 = all_descriptors0[correspondences_idx_0].detach()
            descriptors1 = all_descriptors1[correspondences_idx_1].detach()

            # FIXME Maybe use all descriptors ??
            # matching_similarity01 = (descriptors0 @ all_descriptors1.t()) / self.reliability_threshold
            matching_similarity01 = (descriptors0 @ descriptors1.t()) / self.reliability_threshold
            matching_similarity01 = torch.softmax(matching_similarity01, dim=1)

            # FIXME Maybe use all descriptors ??
            # matching_similarity10 = (descriptors1 @ all_descriptors0.t()) / self.reliability_threshold
            matching_similarity10 = (descriptors1 @ descriptors0.t()) / self.reliability_threshold
            matching_similarity10 = torch.softmax(matching_similarity10, dim=1)

            # Matched pairs are row-aligned, so the diagonal holds the softmax
            # value at each ground-truth matching point, i.e. the reliability.
            C01 = torch.diagonal(matching_similarity01)
            C10 = torch.diagonal(matching_similarity10)

            # C01 = matching_similarity01[torch.arange(matching_similarity01.shape[0]), correspondences_idx_1]
            # C10 = matching_similarity10[torch.arange(matching_similarity10.shape[0]), correspondences_idx_0]

            scores0 = pred0['scores'][idx][correspondences_idx_0]
            scores1 = pred1['scores'][idx][correspondences_idx_1]

            # Score-weighted mean of (1 - reliability), cf. Eq. (13)
            l0 = ((1 - C01) * scores0).sum() / (scores0.sum() + 1e-6)
            l1 = ((1 - C10) * scores1).sum() / (scores1.sum() + 1e-6)

            loss_mean += (l0 + l1) / 2
            count += 1

        if count != 0:
            loss_mean /= count
        else:
            loss_mean = pred0['scores_map'].new_tensor(0)
        assert not torch.isnan(loss_mean), f"ReliableLoss is nan count : {count}, loss_mean : {loss_mean}"
        return loss_mean
```
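
And a toy usage sketch; the dict layout and shapes below are my assumptions about what the class expects, not the repo's actual API:

```python
import torch
import torch.nn.functional as F

# Fake inputs matching the assumed layout: one batch item, n matched keypoints
b, n, dim, h, w = 1, 8, 128, 32, 32

def fake_pred():
    return {'scores_map': torch.rand(b, 1, h, w),
            'descriptors': [F.normalize(torch.randn(n, dim), dim=1)],
            'scores': [torch.rand(n, requires_grad=True)]}

pred0, pred1 = fake_pred(), fake_pred()
correspondences = [{'correspondence0': torch.arange(n),
                    'correspondence1': torch.arange(n)}]

loss = ReliableLoss(reliability_threshold=0.1)(pred0, pred1, correspondences)
loss.backward()  # gradients flow into the scores only (descriptors are detached)
```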

@Shiaoming Maybe you can confirm that my understanding of your work is correct?

senz1024 commented 3 months ago

@Shiaoming Thank you very much for the clarification. It is now clear.

@PierreCarceller This is very helpful, thank you very much!