MCC-WH / SSP

Official implementation of paper "Structure Similarity Preservation Learning for Asymmetric Image Retrieval"

Issues about running "Ours_training.py" and "Ours_Asys_testing.py" #3

Open SeunghanYu opened 7 months ago

SeunghanYu commented 7 months ago

Hi, @MCC-WH.

I've been working with Ours_training.py and encountered a couple of issues that prevented the code from executing properly, necessitating some modifications on my end.

First Issue:

In the WarmupCos_Scheduler class, specifically in the step(self) method, I ran into an issue where the index into the lr_schedule list went out of range. To address this, I modified the method as follows:

def step(self):
    # Apply the learning rate scheduled for the current iteration
    # to every parameter group, then advance the iteration counter.
    current_lr = self.lr_schedule[self.iter]
    for param_group in self.optimizer.param_groups:
        param_group['lr'] = current_lr
    self.iter += 1
    return current_lr
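
For reference, schedulers of this shape typically precompute lr_schedule with one entry per training iteration: a linear warmup followed by a cosine decay. If step() is called more times than len(lr_schedule) (for example, when iters_per_epoch is computed inconsistently with the actual loader length), the lookup raises an IndexError, which matches the symptom above. A minimal sketch of that construction, with all names hypothetical (the repo's WarmupCos_Scheduler may differ):

import numpy as np

class WarmupCosSchedulerSketch:
    def __init__(self, optimizer, base_lr, warmup_epochs, total_epochs, iters_per_epoch, final_lr=0.0):
        warmup_iters = warmup_epochs * iters_per_epoch
        decay_iters = (total_epochs - warmup_epochs) * iters_per_epoch
        # Linear ramp from 0 to base_lr, then cosine decay down to final_lr.
        warmup = np.linspace(0.0, base_lr, warmup_iters)
        steps = np.arange(decay_iters)
        cosine = final_lr + 0.5 * (base_lr - final_lr) * (1.0 + np.cos(np.pi * steps / decay_iters))
        self.lr_schedule = np.concatenate((warmup, cosine))
        self.optimizer = optimizer
        self.iter = 0

    def step(self):
        # Raises IndexError once self.iter reaches len(self.lr_schedule).
        current_lr = self.lr_schedule[self.iter]
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = current_lr
        self.iter += 1
        return current_lr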

Second Issue:

In the main function, it appeared that anchor_features was being unpacked in the loop without ever being used:

for idx, (images, features, anchor_features) in enumerate(metric_logger.log_every(train_loader, print_freq, header)):

Therefore, I adjusted the code as follows:

for idx, (images, features) in enumerate(metric_logger.log_every(train_loader, print_freq, header)):
    ...
    distill = model(images, features)
    ...
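
A version-tolerant alternative is to unpack the batch positionally and ignore any extra entries; a small sketch reusing the names from the snippets above (not the repo's actual code):

for idx, batch in enumerate(metric_logger.log_every(train_loader, print_freq, header)):
    images, features = batch[0], batch[1]  # tolerates an optional third item such as anchor_features
    distill = model(images, features)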

After making these adjustments, the training proceeded without further issues.

However, upon completing training and moving on to Ours_Asys_testing.py, I encountered another problem when using the R101-DELG.pth file you shared on Google Drive.

Issue:

When attempting to load the state dictionary with:

db_net.load_state_dict(torch.load(os.path.join(get_data_root(), 'R101-DELG.pth')), strict=True)

I found that the loaded state_dict contained keys that do not exist in db_net. Unexpected in loaded state_dict: {'attention.att_conv1.weight', 'reduction.weight', 'reduction.bias'}
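
For anyone hitting the same mismatch: PyTorch's load_state_dict(strict=False) loads the overlapping keys and returns the leftover ones, which avoids deleting keys from the checkpoint by hand (db_net and get_data_root as in the snippet above):

import os
import torch

state = torch.load(os.path.join(get_data_root(), 'R101-DELG.pth'), map_location='cpu')
# With strict=False, every matching key is loaded and the mismatches are returned.
missing, unexpected = db_net.load_state_dict(state, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)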

After removing these keys, I ran the code again, but the resulting mAP was around 4 (very low).

Having gone through these modifications to run the provided code, I am currently seeing unsatisfactory results. I'm wondering if there's an updated version of the code available that addresses these issues. Additionally, I would like to know if there might be a problem with the provided R101-DELG.pth file.

Looking forward to your reply. Thank you in advance!

MCC-WH commented 7 months ago

I apologize for the errors; this must have been an oversight on my part when organizing the code. The correct DELG network structure definition can be found here: DELG.

MCC-WH commented 7 months ago

I'm currently in the middle of an important internship and won't have time to fix bugs in the code for a while. If you run into further issues, feel free to ask here and I'll get back to you when I see it. :)

SeunghanYu commented 7 months ago

Hello, @MCC-WH.

I modified SSP/networks/R101-DELG.py (based on Token/networks/RetrievalNet.py):

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from .pooling import GeM
from utils import resnet_block_dilation

eps_fea_norm = 1e-5
eps_l2_norm = 1e-10

class ArcFace(nn.Module):
    def __init__(self, in_features, out_features, s=64.0, m=0.50, eps=1e-6):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.eps = eps

        self.s = s
        self.m = m

        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.threshold = math.pi - self.m

    def forward(self, input, label):
        # Cosine similarity between L2-normalized features and class weights.
        cos_theta = F.linear(F.normalize(input, dim=-1), F.normalize(self.weight, dim=-1))
        theta = torch.acos(torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps))

        # One-hot mask selecting each sample's target class.
        one_hot = torch.zeros(cos_theta.size()).to(input.device)
        one_hot.scatter_(1, label.view(-1, 1), 1)

        # Apply the angular margin m only to the target class, and only where
        # theta + m stays within [0, pi] (i.e. theta <= pi - m).
        selected = torch.where(theta > self.threshold, torch.zeros_like(one_hot), one_hot).bool()

        output = torch.cos(torch.where(selected, theta + self.m, theta))
        output *= self.s
        return output

class ResNet(nn.Module):
    def __init__(self, name: str, train_backbone: bool, dilation_block5: bool):
        super(ResNet, self).__init__()
        net_in = getattr(torchvision.models, name)(pretrained=True)
        if name.startswith('resnet'):
            features = list(net_in.children())[:-2]
        else:
            raise ValueError('Unsupported or unknown architecture: {}!'.format(name))
        features = nn.Sequential(*features)
        self.outputdim_block5 = 2048
        self.outputdim_block4 = 1024
        self.block1 = features[:4]
        self.block2 = features[4]
        self.block3 = features[5]
        self.block4 = features[6]
        self.block5 = features[7]
        if dilation_block5:
            self.block5 = resnet_block_dilation(self.block5, dilation=2)
        if not train_backbone:
            for param in self.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

class ResNet_STAGE45(nn.Module):
    def __init__(self, name: str, train_backbone: bool, dilation_block5: bool):
        super(ResNet_STAGE45, self).__init__()
        net_in = getattr(torchvision.models, name)(pretrained=True)
        if name.startswith('resnet'):
            features = list(net_in.children())[:-2]
        else:
            raise ValueError('Unsupported or unknown architecture: {}!'.format(name))
        features = nn.Sequential(*features)
        self.outputdim_block5 = 2048
        self.outputdim_block4 = 1024
        self.block1 = features[:4]
        self.block2 = features[4]
        self.block3 = features[5]
        self.block4 = features[6]
        self.block5 = features[7]
        if dilation_block5:
            self.block5 = resnet_block_dilation(self.block5, dilation=2)
        if not train_backbone:
            for param in self.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x4 = self.block4(x)
        x5 = self.block5(x4)
        # Return both the stage-4 map (local/attention branch) and the stage-5 map (global branch).
        return x4, x5

class Spatial_Attention(nn.Module):
    def __init__(self, input_dim):
        super(Spatial_Attention, self).__init__()
        self.att_conv1 = nn.Conv2d(input_dim, 1, kernel_size=(1, 1), padding=0, stride=1, bias=False)
        self.att_act2 = nn.Softplus(beta=1, threshold=20)
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        att_score = self.att_act2(self.att_conv1(x))
        return att_score

class DELG(nn.Module):
    def __init__(self, outputdim=2048, reduction_dim=128, classifier_num=1024):
        super(DELG, self).__init__()

        self.backbone = ResNet_STAGE45(name='resnet101', train_backbone=True, dilation_block5=False)
        self.pooling = GeM(p=3.0)
        self.whiten = nn.Conv2d(self.backbone.outputdim_block5, outputdim, kernel_size=(1, 1), stride=1, padding=0, bias=True)
        self.outputdim = outputdim
        self.classifier_block5 = ArcFace(in_features=outputdim, out_features=classifier_num, s=math.sqrt(self.outputdim), m=0.2)
        self.classifier_block4 = ArcFace(in_features=reduction_dim, out_features=classifier_num, s=math.sqrt(reduction_dim), m=0.1)
        self.attention = Spatial_Attention(input_dim=1024)
        self.reduction = nn.Conv2d(self.backbone.outputdim_block4, reduction_dim, kernel_size=1, padding=0, stride=1, bias=True)

    def _init_input_proj(self, weight, bias):
        self.reduction.weight.data = weight
        self.reduction.bias.data = bias

    @torch.no_grad()
    def forward_test(self, x):
        x4, x5 = self.backbone(x)
        # Global branch: GeM pooling on block5, then 1x1 whitening and L2 normalization.
        global_feature = F.normalize(self.pooling(x5), p=2.0, dim=1)
        global_feature = self.whiten(global_feature).squeeze(-1).squeeze(-1)
        global_feature = F.normalize(global_feature, p=2.0, dim=-1)
        return global_feature

    def forward(self, x):
        # Identical to forward_test here; the local branch (attention/reduction on x4)
        # is defined in __init__ but not used in this forward pass.
        x4, x5 = self.backbone(x)
        global_feature = F.normalize(self.pooling(x5), p=2.0, dim=1)
        global_feature = self.whiten(global_feature).squeeze(-1).squeeze(-1)
        global_feature = F.normalize(global_feature, p=2.0, dim=-1)
        return global_feature
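
As a quick sanity check of the rewritten module, a shape/normalization test along these lines can be run before the full evaluation (a minimal sketch; it instantiates the classes above and downloads ImageNet weights on first use):

import torch

net = DELG()
net.eval()
with torch.no_grad():
    feats = net(torch.randn(2, 3, 224, 224))
print(feats.shape)         # expected: torch.Size([2, 2048])
print(feats.norm(dim=-1))  # expected: ~1.0 per row (L2-normalized)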

After making these modifications, I reran the entire process from the beginning.

Upon executing Ours_Asys_testing.py, I observed a significant drop in performance, as detailed below:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00,  7.85it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 4993/4993 [13:59<00:00,  5.95it/s]
>> Test Dataset: roxford5k *** Feature Type: GeM asys >>
>> whiten: mAP Medium: 4.79, Hard: 6.26
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:06<00:00, 11.51it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 6322/6322 [18:11<00:00,  5.79it/s]
>> Test Dataset: rparis6k *** Feature Type: GeM asys >>
>> whiten: mAP Medium: 2.55, Hard: 5.4

These results were obtained with timm==0.9.16. Are there known performance issues tied to specific versions of timm, i.e. could the timm version be causing this degradation?
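
When reporting this, it may help to log the exact framework versions alongside timm, for example:

import timm
import torch
import torchvision

print('torch', torch.__version__, '| torchvision', torchvision.__version__, '| timm', timm.__version__)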

Additionally, I would like to know if there are any plans to distribute the modified code after your internship concludes.

Looking forward to your reply. Thank you in advance!