ZYK100 / LLCM

[CVPR 2023] Diverse Embedding Expansion Network and Low-Light Cross-Modality Benchmark for Visible-Infrared Person Re-identification
https://github.com/ZYK100/LLCM/blob/main/Agreement/LLCM%20DATASET%20RELEASE%20AGREEMENT.pdf

DEE_module code #22

Closed: vk-rrr closed this issue 5 months ago

vk-rrr commented 8 months ago

import torch
import torch.nn as nn
import torch.nn.functional as F

# weights_init_kaiming comes from the repository's model code and is not shown in this issue.


class DEE_module(nn.Module):
    def __init__(self, channel, reduction=16):  # `reduction` is not used below
        super(DEE_module, self).__init__()

        # Expansion branch 1: three 3x3 convs with dilation 1, 2, 3 (padding == dilation keeps H and W)
        self.FC11 = nn.Conv2d(channel, channel // 4, kernel_size=3, stride=1, padding=1, bias=False, dilation=1)
        self.FC11.apply(weights_init_kaiming)
        self.FC12 = nn.Conv2d(channel, channel // 4, kernel_size=3, stride=1, padding=2, bias=False, dilation=2)
        self.FC12.apply(weights_init_kaiming)
        self.FC13 = nn.Conv2d(channel, channel // 4, kernel_size=3, stride=1, padding=3, bias=False, dilation=3)
        self.FC13.apply(weights_init_kaiming)
        # 1x1 conv projecting channel // 4 back to channel
        self.FC1 = nn.Conv2d(channel // 4, channel, kernel_size=1)
        self.FC1.apply(weights_init_kaiming)

        # Expansion branch 2: same layout, separate weights
        self.FC21 = nn.Conv2d(channel, channel // 4, kernel_size=3, stride=1, padding=1, bias=False, dilation=1)
        self.FC21.apply(weights_init_kaiming)
        self.FC22 = nn.Conv2d(channel, channel // 4, kernel_size=3, stride=1, padding=2, bias=False, dilation=2)
        self.FC22.apply(weights_init_kaiming)
        self.FC23 = nn.Conv2d(channel, channel // 4, kernel_size=3, stride=1, padding=3, bias=False, dilation=3)
        self.FC23.apply(weights_init_kaiming)
        self.FC2 = nn.Conv2d(channel // 4, channel, kernel_size=1)
        self.FC2.apply(weights_init_kaiming)

        self.dropout = nn.Dropout(p=0.01)

    def forward(self, x):
        # Expansion branch 1: average the three dilated convs, apply ReLU, project back to `channel`
        x1 = (self.FC11(x) + self.FC12(x) + self.FC13(x)) / 3
        x1 = self.FC1(F.relu(x1))
        # Expansion branch 2: same computation with its own weights
        x2 = (self.FC21(x) + self.FC22(x) + self.FC23(x)) / 3
        x2 = self.FC2(F.relu(x2))
        # Stack the original x and the two expanded embeddings along the batch dimension (dim 0)
        out = torch.cat((x, x1, x2), 0)
        out = self.dropout(out)
        return out

Does this part of the code compute only two branches? The paper describes three branches.
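For reference, a small sketch of why each expansion branch can simply average FC11, FC12 and FC13 (or FC21, FC22 and FC23): with kernel_size=3 and padding equal to dilation, every dilated conv keeps the spatial size unchanged, while the span of input pixels each kernel covers grows to 3, 5 and 7. The helper uses the output-size formula given in the PyTorch `nn.Conv2d` documentation; the function names and the feature-map height of 18 are hypothetical, chosen only for illustration.

```python
def conv_out_size(size, kernel_size, padding, dilation, stride=1):
    """Spatial output size of nn.Conv2d along one axis (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def effective_kernel(kernel_size, dilation):
    """Number of input pixels spanned by one dilated kernel along one axis."""
    return dilation * (kernel_size - 1) + 1


H = 18  # hypothetical feature-map height, for illustration only
for d in (1, 2, 3):
    # padding == dilation, mirroring FC11/FC12/FC13 in the quoted module
    print(d, effective_kernel(3, d), conv_out_size(H, kernel_size=3, padding=d, dilation=d))
```

Because all three outputs share the input's height and width (and all have channel // 4 channels), they can be added element-wise before dividing by 3.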

ZYK100 commented 8 months ago

The original branch counts as one of the branches too.
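To make the answer concrete: in the quoted `forward`, the unmodified input `x` is concatenated with the two expanded embeddings `x1` and `x2` along dim 0, so the output holds three branches and three times the batch size. The sketch below is a hypothetical, simplified re-creation (`TinyDEE`, using PyTorch's default initialisation instead of `weights_init_kaiming`) meant only to show the shapes; it is not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDEE(nn.Module):
    """Hypothetical, simplified re-creation of the quoted DEE_module, for illustration only."""

    def __init__(self, channel):
        super().__init__()

        def dilated_convs():
            # Three 3x3 convs with dilation 1, 2, 3; padding == dilation keeps H and W unchanged.
            return nn.ModuleList(
                [nn.Conv2d(channel, channel // 4, kernel_size=3, padding=d, dilation=d, bias=False)
                 for d in (1, 2, 3)]
            )

        self.branch1, self.branch2 = dilated_convs(), dilated_convs()
        self.proj1 = nn.Conv2d(channel // 4, channel, kernel_size=1)
        self.proj2 = nn.Conv2d(channel // 4, channel, kernel_size=1)
        self.dropout = nn.Dropout(p=0.01)

    def forward(self, x):
        x1 = self.proj1(F.relu(sum(conv(x) for conv in self.branch1) / 3))
        x2 = self.proj2(F.relu(sum(conv(x) for conv in self.branch2) / 3))
        # The original x is the first "branch"; stacking on dim 0 triples the batch.
        return self.dropout(torch.cat((x, x1, x2), dim=0))


torch.manual_seed(0)
dee = TinyDEE(channel=64).eval()  # eval() makes dropout the identity
x = torch.randn(2, 64, 18, 9)
out = dee(x)
original, expanded1, expanded2 = out.chunk(3, dim=0)
print(out.shape)
print(torch.equal(original, x))
```

With a batch of 2, the output has 6 samples: the first 2 are exactly the input (in eval mode), and the remaining 4 are the two expanded embeddings.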

vk-rrr commented 8 months ago

OK, thanks for the explanation.

vk-rrr commented 8 months ago

import torch.nn as nn

# thermal_module, visible_module, base_resnet, DEE_module and MFA_block are defined elsewhere in the repository.


class embed_net(nn.Module):
    def __init__(self, class_num, dataset, arch='resnet50'):
        super(embed_net, self).__init__()

        self.thermal_module = thermal_module(arch=arch)
        self.visible_module = visible_module(arch=arch)
        self.base_resnet = base_resnet(arch=arch)

        self.dataset = dataset
        if self.dataset == 'regdb':
            pool_dim = 1024
            self.DEE = DEE_module(512)
            self.MFA1 = MFA_block(256, 64, 0)
            self.MFA2 = MFA_block(512, 256, 1)
        else:
            pool_dim = 2048
            self.DEE = DEE_module(1024)
            self.MFA1 = MFA_block(256, 64, 0)
            self.MFA2 = MFA_block(512, 256, 1)
            self.MFA3 = MFA_block(1024, 512, 1)

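Hello, the code here has only 3 MFA modules, but Figure 3 of the paper shows 4 MFA modules and five stages in total (Stage 0 to Stage 4). How should this be understood?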
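One way to read the quoted constructor, assuming (not confirmed in this thread) that the first argument of `MFA_block` and `DEE_module` is the channel width of the backbone feature it receives: compare those numbers with the output widths of a standard torchvision ResNet-50, which are 64 after the conv1 stem and 256, 512, 1024 and 2048 after layer1 to layer4. The table and the `stages_matching` helper are hypothetical and illustrative only; how these layers correspond to the paper's Stage 0 to Stage 4 labels is also an assumption.

```python
# Output channel widths of a standard torchvision ResNet-50.
RESNET50_STAGE_CHANNELS = {
    "conv1 stem": 64,
    "layer1": 256,
    "layer2": 512,
    "layer3": 1024,
    "layer4": 2048,
}

# First constructor argument of each module in the quoted embed_net.
QUOTED_MODULES = {
    "regdb": {"MFA1": 256, "MFA2": 512, "DEE": 512},
    "other": {"MFA1": 256, "MFA2": 512, "MFA3": 1024, "DEE": 1024},
}


def stages_matching(channels):
    """Return the ResNet-50 stages whose output width equals `channels` (assumed meaning)."""
    return [name for name, width in RESNET50_STAGE_CHANNELS.items() if width == channels]


for dataset, modules in QUOTED_MODULES.items():
    for name, channels in modules.items():
        print(dataset, name, channels, stages_matching(channels))
```

Under that assumption, the non-RegDB branch would place MFA1 to MFA3 after layer1 to layer3 and the DEE module after layer3, with no MFA module after layer4; only the authors can confirm whether this matches Figure 3.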