Jinghe-mel / UFEN-SLAM


Hello, could you provide the training code for UFEN's binary descriptors? I've tried to replicate the training of the binary descriptors myself, but the results have been poor, so I am seeking your help. #4

Open bumblebee15138 opened 5 months ago

Jinghe-mel commented 5 months ago

Hi,

Adding a stronger noise term (N) in the image synthesis (applied separately to each image of a pair) enhances the descriptor's performance, so you should consider trying it. We are not going to release the full training code at this stage, but if you need further help with it, I can provide you with the "Matching loss" code in a few days.
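
For illustration, one way to apply the noise separately is sketched below (a hypothetical helper, assuming float images in [0, 1] and Gaussian noise; the exact noise model and strength in UFEN's synthesis may differ):

import torch

def add_independent_noise(img_a, img_b, sigma=0.05):
    # Draw an independent noise map for each image of the synthesized pair,
    # so the two views never share the same noise pattern.
    # sigma is an assumed strength; treat it as a hyperparameter.
    noisy_a = (img_a + sigma * torch.randn_like(img_a)).clamp(0.0, 1.0)
    noisy_b = (img_b + sigma * torch.randn_like(img_b)).clamp(0.0, 1.0)
    return noisy_a, noisy_b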

bumblebee15138 commented 5 months ago


Thank you for your response. I will try again following your suggestions. It would be great if you could provide the "Matching loss" as well.

bumblebee15138 commented 5 months ago


I apologize for bothering you again. I've been having trouble replicating good results when training the binary descriptor. Could you please provide the code for the "Matching loss" section? I would greatly appreciate it.

Jinghe-mel commented 5 months ago

Hi, sorry for the late update. I have uploaded the matching implementation code and the weights, for quick and easy use and comparison. I will also post the "Matching Loss" code for you this weekend.

bumblebee15138 commented 5 months ago


Thank you very much for your kind help. I sincerely hope that you can achieve even more brilliant academic achievements.

Jinghe-mel commented 5 months ago

No worries, you are very welcome. Please find the code below; I hope it will be helpful.

import torch

class STE(torch.autograd.Function):
    # Straight-Through Estimator: binarize with sign() in the forward pass,
    # and pass the gradient straight through (clipped to [-1, 1]) in the
    # backward pass, so the binarization remains trainable.
    @staticmethod
    def forward(ctx, input):
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_outputs):
        # Use the out-of-place clamp; clamping grad_outputs in place can
        # corrupt gradients still needed by other branches of the graph.
        return grad_outputs.clamp(-1, 1)

def Matching_loss(pts, t_pts, desc, t_desc):
    # Example matching loss for a single image pair.
    # pts, t_pts: the matched points on the paired images, shape (1, N, 2);
    #   N = number of points, 2 = pixel location, e.g. (485, 155).
    # desc, t_desc: the descriptor outputs of the models, shape (1, 256, 60, 80)
    #   for a (480, 640) input image (the descriptor map is at 1/8 resolution).

    def get_mask(kp0, kp1, dist_thresh):
        # Flag point pairs that lie within dist_thresh pixels in either image;
        # such pairs are too close to be treated as true negatives.
        batch_size, num_points, _ = kp0.size()
        dist_kp0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1)
        dist_kp1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1)
        min_dist = torch.min(dist_kp0, dist_kp1)
        dist_mask = min_dist <= dist_thresh
        dist_mask = dist_mask.repeat(1, 1, batch_size).reshape(batch_size * num_points, batch_size * num_points)
        return dist_mask

    def desc_obtain(pts, desc):
        # Sample the descriptor map at the given pixel locations.
        _, _, Hc, Wc = desc.shape
        # Clone so the normalization below does not modify the caller's pts
        # in place (transpose returns a view of the same storage).
        samp_pts = pts.squeeze().transpose(1, 0).clone()
        # Map pixel coordinates to the [-1, 1] range expected by grid_sample
        # (the descriptor map is 1/8 of the input resolution, hence the * 8).
        samp_pts[0, :] = (samp_pts[0, :] / (float(Wc * 8) / 2.)) - 1.
        samp_pts[1, :] = (samp_pts[1, :] / (float(Hc * 8) / 2.)) - 1.
        samp_pts = samp_pts.transpose(0, 1).contiguous()
        samp_pts = samp_pts.view(1, 1, -1, 2)
        samp_pts = samp_pts.float()
        tpts_desc = torch.nn.functional.grid_sample(desc, samp_pts, align_corners=True)

        # Reshape to (N, 256) and L2-normalize each descriptor.
        pts_desc = tpts_desc.squeeze().transpose(0, 1)
        pts_desc = torch.nn.functional.normalize(pts_desc, dim=1)
        return pts_desc

    ste_sign = STE.apply
    dist_mask = get_mask(pts, t_pts, 8)  # T = 8: flag detected points that are too close.

    pts_desc = desc_obtain(pts, desc)  # float descriptors of the points.
    t_pts_desc = desc_obtain(t_pts, t_desc)
    pts_desc_bin = ste_sign(pts_desc).type(torch.float)  # binarization (gradient kept via the STE).
    t_pts_desc_bin = ste_sign(t_pts_desc).type(torch.float)
    # For 256-d {-1, +1} vectors, dot = 256 - 2 * Hamming, so Hamming = 128 - dot / 2.
    b_dis = 128 - (pts_desc_bin @ t_pts_desc_bin.t()) / 2.0
    b_match_dis = torch.diag(b_dis)  # Hamming distances of the matched (positive) pairs.

    b_match_dis = b_match_dis.unsqueeze(dim=1)
    b_match_dis = torch.max(torch.zeros_like(b_match_dis), b_match_dis - 0.1 * 256)  # P = 0.1 * 256

    b_dis[dist_mask] = 256  # exclude too-close pairs from the hardest-negative search.
    # Hardest negative per point: the smallest distance over each row and each column.
    b_non_ids = torch.min(torch.min(b_dis, dim=1)[0], torch.min(b_dis, dim=0)[0])
    b_non_ids = torch.max(torch.zeros_like(b_non_ids), -b_non_ids + 0.5 * 256)  # Q = 0.5 * 256
    b_tri_loss = torch.square(b_match_dis) + torch.square(b_non_ids)
    blosses = torch.mean(b_tri_loss) * 0.0001  # alpha = 0.0001
    return blosses
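
For a quick sanity check, it can be called with random tensors shaped as in the comments above (the values here are made up):

# Hypothetical smoke test with random inputs for a (480, 640) image pair.
N = 100
xs = torch.randint(0, 640, (1, N, 1)).float()   # x in [0, 640)
ys = torch.randint(0, 480, (1, N, 1)).float()   # y in [0, 480)
pts = torch.cat([xs, ys], dim=-1)               # (1, N, 2) pixel locations
xs2 = torch.randint(0, 640, (1, N, 1)).float()
ys2 = torch.randint(0, 480, (1, N, 1)).float()
t_pts = torch.cat([xs2, ys2], dim=-1)
desc = torch.randn(1, 256, 60, 80, requires_grad=True)    # descriptor map, image A
t_desc = torch.randn(1, 256, 60, 80, requires_grad=True)  # descriptor map, image B

loss = Matching_loss(pts, t_pts, desc, t_desc)
loss.backward()  # gradients flow to desc/t_desc through the STE
print(loss.item())
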
bumblebee15138 commented 5 months ago

Hello, I would like to ask how many epochs you trained the network for to achieve the expected results? Thank you.

Jinghe-mel commented 5 months ago


The provided weights were trained for 20 epochs; performance typically has nearly converged after 10 epochs. If additional training or more epochs are necessary, you can incorporate a small similarity loss (e.g., an L2 loss) between the descriptor outputs of the new model and those of the original SuperPoint. This helps ensure that the new model remains close to the original SuperPoint.
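
For example, something like the sketch below (a hypothetical helper; it assumes you keep a frozen copy of the original SuperPoint, and the weight is an assumed hyperparameter to tune, not a value from the paper):

import torch.nn.functional as F

def similarity_loss(desc_new, desc_ref, weight=0.01):
    # desc_new: descriptor map from the model being trained, e.g. (1, 256, 60, 80).
    # desc_ref: descriptor map from a frozen copy of the original SuperPoint on
    #           the same image; detach() keeps the reference network fixed.
    # weight is an assumed scale factor; tune it so this term stays small.
    return weight * F.mse_loss(desc_new, desc_ref.detach())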