verlab / accelerated_features

Implementation of XFeat (CVPR 2024). Do you need robust and fast local feature extraction? You are in the right place!
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24
Apache License 2.0

revise self._unfold2d(x, ws=8) ? #52


longzeyilang commented 3 months ago

Hi, I trained on my own data (image size about 128×128) and changed the model as follows:

```python
    self.block1 = nn.Sequential(
                                    BasicLayer( 1,  8, stride=1),
                                    BasicLayer( 8, 24, stride=1),
                                    BasicLayer(24, 64, stride=1),
                                 )

    self.block2 = nn.Sequential(
                                    BasicLayer(64, 64, stride=2),
                                    BasicLayer(64, 64, stride=1),
                                    BasicLayer(64, 64, stride=1),
                                 )

    self.block3 = nn.Sequential(
                                    BasicLayer( 64, 128, stride=2),
                                    BasicLayer(128, 128, stride=1),
                                    BasicLayer(128, 128, stride=1),
                                    BasicLayer(128,  64, 1, padding=0),
                                 )

    self.block_fusion =  nn.Sequential(
                                    BasicLayer(64, 64, stride=1),
                                    BasicLayer(64, 64, stride=1),
                                    nn.Conv2d (64, 64, 1, padding=0)
                                 )

    self.heatmap_head = nn.Sequential(
                                    BasicLayer(64, 64, 1, padding=0),
                                    BasicLayer(64, 64, 1, padding=0),
                                    nn.Conv2d (64, 1, 1),
                                    nn.Sigmoid()
                                )

    self.keypoint_head = nn.Sequential(
                                    BasicLayer(4, 64, 1, padding=0),
                                    BasicLayer(64, 64, 1, padding=0),
                                    BasicLayer(64, 64, 1, padding=0),
                                    nn.Conv2d (64, 5, 1),
                                )

```

and changed forward as follows:

```python
def forward(self, x):
    """
        input:
            x -> torch.Tensor(B, C, H, W) grayscale or rgb images
        return:
            feats     -> torch.Tensor(B, 64, H/2, W/2) dense local features
            keypoints -> torch.Tensor(B,  5, H/2, W/2) keypoint logit map
            heatmap   -> torch.Tensor(B,  1, H/2, W/2) reliability map

    """
    #dont backprop through normalization
    with torch.no_grad():
        x = x.mean(dim=1, keepdim = True)
        x = self.norm(x)

    #main backbone
    x1 = self.block1(x)
    x2 = self.block2(x1)
    x3 = self.block3(x2)
    x4 = F.interpolate(x3, (x2.shape[-2], x2.shape[-1]), mode='bilinear')
    feats = self.block_fusion(x4 + x2)

    #heads
    heatmap = self.heatmap_head(feats)                        # Reliability map
    keypoints = self.keypoint_head(self._unfold2d(x, ws=2))   # Keypoint map logits
    return feats, keypoints, heatmap
```
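
For reference, a quick shape check of the modified network (a hypothetical snippet; it assumes the class above is named `XFeatModel` as in the original repo):

```python
# Hypothetical shape check of the modified model above (class name assumed)
import torch

model = XFeatModel().eval()
x = torch.randn(1, 1, 128, 128)
with torch.no_grad():
    feats, keypoints, heatmap = model(x)
print(feats.shape)      # torch.Size([1, 64, 64, 64]) -> (B, 64, H/2, W/2)
print(keypoints.shape)  # torch.Size([1, 5, 64, 64])  -> ws*ws + 1 = 5 bins per 2x2 cell
print(heatmap.shape)    # torch.Size([1, 1, 64, 64])
```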

I changed the `_unfold2d` ws to 2. How should I revise `keypoint_head`, and how should I revise losses.py? Thank you.
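
For context, `_unfold2d` packs each ws×ws block of pixels into channels at 1/ws resolution, so with ws=2 the keypoint head receives ws² = 4 input channels and should output ws² + 1 = 5 logits per cell (4 offset bins plus one non-keypoint bin), which matches the head above. A minimal standalone sketch of the unfold step (hypothetical helper; check `_unfold2d` in the repo's model code for the exact implementation):

```python
import torch

def unfold2d(x, ws=2):
    """Rearrange each ws x ws pixel block into channels at 1/ws resolution."""
    B, C, H, W = x.shape
    # (B, C, H, W) -> (B, C, H//ws, W//ws, ws, ws)
    x = x.unfold(2, ws, ws).unfold(3, ws, ws)
    # -> (B, C*ws*ws, H//ws, W//ws)
    return x.reshape(B, C, H // ws, W // ws, ws * ws) \
            .permute(0, 1, 4, 2, 3).reshape(B, -1, H // ws, W // ws)

x = torch.randn(1, 1, 128, 128)
print(unfold2d(x, ws=2).shape)  # torch.Size([1, 4, 64, 64]) -> matches BasicLayer(4, 64, ...)
```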

longzeyilang commented 3 months ago

```python
import torch
import torch.nn.functional as F

from modules.dataset.megadepth import megadepth_warper
from modules.training import utils
from third_party.alike_wrapper import extract_alike_kpts
from modules.model_small import UNFLOD_WS


def dual_softmax_loss(X, Y, temp = 0.2):
    if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
        raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')

    dist_mat = (X @ Y.t()) * temp
    conf_matrix12 = F.log_softmax(dist_mat, dim=1)
    conf_matrix21 = F.log_softmax(dist_mat.t(), dim=1)

    with torch.no_grad():
        conf12 = torch.exp(conf_matrix12).max(dim=-1)[0]
        conf21 = torch.exp(conf_matrix21).max(dim=-1)[0]
        conf = conf12 * conf21

    target = torch.arange(len(X), device = X.device)

    loss = F.nll_loss(conf_matrix12, target) + \
           F.nll_loss(conf_matrix21, target)

    return loss, conf

def smooth_l1_loss(input, target, beta=2.0, size_average=True):
    diff = torch.abs(input - target)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()

def fine_loss(f1, f2, pts1, pts2, fine_module, ws=7):
    '''
        Compute fine features and spatial loss
    '''
    C, H, W = f1.shape
    N = len(pts1)

    #Sort random offsets
    with torch.no_grad():
        a = -(ws//2)
        b = (ws//2)
        offset_gt = (a - b) * torch.rand(N, 2, device = f1.device) + b
        pts2_random = pts2 + offset_gt

    patches1 = utils.crop_patches(f1.unsqueeze(0), (pts1+0.5).long(), size=ws).view(C, N, ws * ws).permute(1, 2, 0)        #[N, ws*ws, C]
    patches2 = utils.crop_patches(f2.unsqueeze(0), (pts2_random+0.5).long(), size=ws).view(C, N, ws * ws).permute(1, 2, 0) #[N, ws*ws, C]

    #Apply transformer
    patches1, patches2 = fine_module(patches1, patches2)

    features = patches1.view(N, ws, ws, C)[:, ws//2, ws//2, :].view(N, 1, 1, C) # [N, 1, 1, C]
    patches2 = patches2.view(N, ws, ws, C)                                      # [N, w, w, C]

    #Dot product
    heatmap_match = (features * patches2).sum(-1)
    offset_coords = utils.subpix_softmax2d(heatmap_match)

    #Invert offset because center crop inverts it
    offset_gt = -offset_gt

    #MSE
    error = ((offset_coords - offset_gt)**2).sum(-1).mean()
    return error

def alike_distill_loss(kpts, img):
    C, H, W = kpts.shape
    kpts = kpts.permute(1,2,0)
    img = img.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255

    with torch.no_grad():
        alike_kpts = torch.tensor(extract_alike_kpts(img), device=kpts.device)
        # Default is non-keypoint (bin UNFLOD_WS*UNFLOD_WS, i.e. 4 when UNFLOD_WS = 2)
        labels = torch.ones((H, W), dtype = torch.long, device = kpts.device) * (UNFLOD_WS*UNFLOD_WS)
        offsets = (((alike_kpts/UNFLOD_WS) - (alike_kpts/UNFLOD_WS).long())*UNFLOD_WS).long()
        offsets = offsets[:, 0] + UNFLOD_WS*offsets[:, 1]  # Linear idx
        labels[(alike_kpts[:,1]/UNFLOD_WS).long(), (alike_kpts[:,0]/UNFLOD_WS).long()] = offsets

    kpts = kpts.view(-1, C)
    labels = labels.view(-1)

    mask = labels < UNFLOD_WS*UNFLOD_WS
    idxs_pos = mask.nonzero().flatten()
    idxs_neg = (~mask).nonzero().flatten()
    perm = torch.randperm(idxs_neg.size(0))[:len(idxs_pos)//32]
    idxs_neg = idxs_neg[perm]
    idxs = torch.cat([idxs_pos, idxs_neg])

    kpts = kpts[idxs]
    labels = labels[idxs]

    with torch.no_grad():
        predicted = kpts.max(dim=-1)[1]
        acc = (labels == predicted)
        acc = acc.sum() / len(acc)

    kpts = F.log_softmax(kpts, dim=-1)
    loss = F.nll_loss(kpts, labels, reduction = 'mean')

    return loss, acc

def keypoint_position_loss(kpts1, kpts2, pts1, pts2, softmax_temp = 1.0):
    '''
        Computes the coordinate classification loss, by re-interpreting the
        UNFLOD_WS**2 bins as a UNFLOD_WS x UNFLOD_WS grid and optimizing for correct offsets
    '''
    C, H, W = kpts1.shape
    kpts1 = kpts1.permute(1,2,0) * softmax_temp
    kpts2 = kpts2.permute(1,2,0) * softmax_temp

    with torch.no_grad():
        #Generate meshgrid
        x, y = torch.meshgrid(torch.arange(W, device=kpts1.device), torch.arange(H, device=kpts1.device), indexing='xy')
        xy = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1)
        xy *= UNFLOD_WS  # NOTE: was hard-coded to 8; must match the unfold window size

        #Generate collision map
        hashmap = torch.ones((H*UNFLOD_WS, W*UNFLOD_WS, 2), dtype = torch.long, device = kpts1.device) * -1
        hashmap[(pts1[:,1]).long(), (pts1[:,0]).long(), :] = (pts2).long()

        #Estimate offset of src kpts
        _, kpts1_offsets = kpts1.max(dim=-1)
        kpts1_offsets_x = kpts1_offsets  % UNFLOD_WS
        kpts1_offsets_y = kpts1_offsets // UNFLOD_WS
        kpts1_offsets_xy = torch.cat([kpts1_offsets_x.unsqueeze(-1),
                                      kpts1_offsets_y.unsqueeze(-1)], dim=-1)
        kpts1_coords = xy + kpts1_offsets_xy

        #find src -> tgt pts
        kpts1_coords = kpts1_coords.view(-1, 2)
        gt_12 = hashmap[kpts1_coords[:,1], kpts1_coords[:,0]]
        mask_valid = torch.all(gt_12 >= 0, dim=-1)
        gt_12 = gt_12[mask_valid]

        #find offset labels
        labels2 = (gt_12/UNFLOD_WS) - (gt_12/UNFLOD_WS).long()
        labels2 = (labels2 * UNFLOD_WS).long()
        labels2 = labels2[:, 0] + UNFLOD_WS*labels2[:, 1] #linear index

    kpts2_selected = kpts2[(gt_12[:, 1]/UNFLOD_WS).long(), (gt_12[:, 0]/UNFLOD_WS).long()]

    kpts1_selected = F.log_softmax(kpts1.view(-1,C)[mask_valid], dim=-1)
    kpts2_selected = F.log_softmax(kpts2_selected, dim=-1)

    #Here we enforce the softmax to keep the current max on src kpts
    with torch.no_grad():
        _, labels1 = kpts1_selected.max(dim=-1)

    predicted2 = kpts2_selected.max(dim=-1)[1]
    acc = (labels2 == predicted2)
    acc = acc.sum() / len(acc)

    loss = F.nll_loss(kpts1_selected, labels1, reduction = 'mean') + \
           F.nll_loss(kpts2_selected, labels2, reduction = 'mean')

    return loss, acc

def coordinate_classification_loss(coords1, pts1, pts2, conf):
    '''
        Computes the fine coordinate classification loss, by re-interpreting the
        UNFLOD_WS**2 bins as a UNFLOD_WS x UNFLOD_WS grid and optimizing for correct offsets after warp
    '''
    #Do not backprop through coordinate warps
    with torch.no_grad():
        coords1_detached = pts1 * UNFLOD_WS
        #find offset
        offsets1_detached = (coords1_detached/UNFLOD_WS) - (coords1_detached/UNFLOD_WS).long()
        offsets1_detached = (offsets1_detached * UNFLOD_WS).long()
        labels1 = offsets1_detached[:, 0] + UNFLOD_WS*offsets1_detached[:, 1]

    coords1_log = F.log_softmax(coords1, dim=-1)

    predicted = coords1.max(dim=-1)[1]
    acc = (labels1 == predicted)
    acc = acc[conf > 0.1]
    acc = acc.sum() / len(acc)

    loss = F.nll_loss(coords1_log, labels1, reduction = 'none')

    #Weight loss by confidence, giving more emphasis to reliable matches
    conf = conf / conf.sum()
    loss = (loss * conf).sum()

    return loss * 2., acc

def keypoint_loss(heatmap, target):
    #Compute L1 loss
    L1_loss = F.l1_loss(heatmap, target)
    return L1_loss * 3.0

def hard_triplet_loss(X, Y, margin = 0.5):
    if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
        raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')

    dist_mat = torch.cdist(X, Y, p=2.0)
    dist_pos = torch.diag(dist_mat)
    dist_neg = dist_mat + 100.*torch.eye(*dist_mat.size(), dtype = dist_mat.dtype,
            device = dist_mat.get_device() if dist_mat.is_cuda else torch.device("cpu"))
    #filter repeated patches on negative distances to avoid weird stuff on gradients
    dist_neg = dist_neg + dist_neg.le(0.01).float()*100.

    #Margin ranking loss
    hard_neg = torch.min(dist_neg, 1)[0]
    loss = torch.clamp(margin + dist_pos - hard_neg, min=0.)
    return loss.mean()
```

This is the loss file with UNFLOD_WS = 2, please check.
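
As a quick sanity check of the bin layout with UNFLOD_WS = 2 (a hypothetical snippet, not part of the repo): each 2×2 cell gets 4 offset bins (0..3) plus one non-keypoint bin (4), matching the 5-channel keypoint head above.

```python
import torch

UNFLOD_WS = 2
kpt = torch.tensor([[5.0, 3.0]])                        # one (x, y) keypoint in image coords
cell = (kpt / UNFLOD_WS).long()                         # 2x2 cell it falls in
offs = (((kpt / UNFLOD_WS) - (kpt / UNFLOD_WS).long()) * UNFLOD_WS).long()
label = offs[:, 0] + UNFLOD_WS * offs[:, 1]             # linear bin index, as in alike_distill_loss
print(cell, offs, label)                                # tensor([[2, 1]]) tensor([[1, 1]]) tensor([3])
```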

longzeyilang commented 3 months ago

@guipotje

guipotje commented 3 months ago

Hi @longzeyilang,

After a quick review, your updates seem correct in theory; the only problem I see is that a 2x2 patch (just 4 raw pixels and 5 bins, versus 64 pixels and 65 bins at the default ws=8) provides too little context for the keypoint head to be effective.

What kind of issues are you experiencing?