bfshi / scaling_on_scales

When do we not need larger vision models?
MIT License

How to add multiscale_forward functionality to my model #7

Open 1chenchen22 opened 4 months ago

1chenchen22 commented 4 months ago

How do I add the multiscale_forward functionality to my model? My code is written following the example you gave.

from s2wrapper import forward as multiscale_forward

model = DDAMNet(num_class=7, num_head=args.num_head)
for epoch in tqdm(range(1, args.epochs + 1)):
    iter_cnt = 0
    model.train()

    for (imgs, targets) in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)
        print("Model:", model)
        print("Input shape:", imgs.shape)

        # Extract multi-scale features
        multiscale_feature = multiscale_forward(model, imgs, scales=[1, 2], num_prefix_token=1)

        # Combine the multi-scale features with the other features
        combined_feature = torch.cat((imgs, multiscale_feature), dim=1)

        # Feed the combined features into the model
        out, feat, heads = model(combined_feature)

print("Model:", model) prints the layers of the model I defined, and the shape of imgs is [64, 3, 112, 112].

1chenchen22 commented 4 months ago

Writing the code as above results in this error:

Traceback (most recent call last):
  File "/home/sd2022012521/project/DDAMFN-main/10-s2wrapper.py", line 308, in <module>
    run_training()
  File "/home/sd2022012521/project/DDAMFN-main/10-s2wrapper.py", line 193, in run_training
    multiscale_feature = multiscale_forward(model, imgs, scales=[1, 2], num_prefix_token=1)
  File "/home/sd2022012521/.virtualenvs/DDAMFN-main/lib/python3.11/site-packages/s2wrapper/core.py", line 40, in forward
    outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
  File "/home/sd2022012521/.virtualenvs/DDAMFN-main/lib/python3.11/site-packages/s2wrapper/core.py", line 40, in <listcomp>
    outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]

TypeError: tuple indices must be integers or slices, not tuple
bfshi commented 4 months ago

Hi! For the model you use to extract multi-scale features:

model = DDAMNet(num_class=7, num_head=args.num_head)

What is the format of the output of this model? Is it a PyTorch tensor or something else like a tuple? For example, if I simply extract single-scale features:

single_scale_features = model(imgs)

what would single_scale_features look like?

1chenchen22 commented 4 months ago

Hello, the output of model = DDAMNet(num_class=7, num_head=args.num_head) is a tuple (out, x, head_out), where out is the final classification result, x is the feature tensor, and head_out is a list of the attention-head outputs. This seems to be why the error is raised at the line outs_multiscale = [model(x) for x in input_multiscale].
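A possible workaround, assuming the feature tensor x is what should actually go through multiscale_forward, would be to pass a small wrapper that returns only that tensor, roughly like this (num_prefix_token=0 and output_shape='bchw' because x is a B x C x H x W feature map with no prefix tokens):

def features_only(images):
    # DDAMNet returns (out, x, head_out); keep only the feature tensor x
    return model(images)[1]

multiscale_feature = multiscale_forward(features_only, imgs, scales=[1, 2],
                                        num_prefix_token=0, output_shape='bchw')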

1chenchen22 commented 4 months ago

Hi, I get the output with single_scale_features = model(imgs). single_scale_features is a tuple of several elements: the first and second elements are tensors, and the third is a list. Can I still use the multi-scale features from your article in my model? Thank you for your patience.

1chenchen22 commented 4 months ago


Thank you, I have successfully solved the above problem

1chenchen22 commented 4 months ago

@bfshi Hello, I tried to add the method from your article to my facial expression recognition model, but the results are not good: the recognition accuracy dropped from the original 90% to at most 74%, convergence is slower, and training for 40 epochs takes much longer. I first use the MFN part shown in the diagram (the MixedFeatureNet class) to extract the multi-scale features, and then feed them into a model like the one shown in the diagram.

[image: model architecture diagram]

Here is part of the training code

net = MixedFeatureNet31.MixedFeatureNet()
net = torch.load(os.path.join('./pretrained/', "MFN_msceleb.pth"))
model2 = nn.Sequential(*list(net.children())[:-4])
model2 = model2.to(device)

best_acc = 0
for epoch in tqdm(range(1, args.epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()
    for (imgs, targets) in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()
        imgs = imgs.to(device)
        targets = targets.to(device)
        multiscale_feature = multiscale_forward(model2, imgs, scales=[1, 2], num_prefix_token=0)
        print("multiscale_feature", multiscale_feature.shape)
        out, feat, heads = model(multiscale_feature)

Here is the code of the model. The number of channels from my original feature extraction was 512, but with multi-scale extraction the features from the two scales are concatenated along the channel dimension, so the channel count doubles and I changed it to 1024.


class DDAMNet(nn.Module):
    def __init__(self, num_class=7, num_head=2, pretrained=True):
        super(DDAMNet, self).__init__()
        self.num_head = num_head
        for i in range(int(num_head)):
            setattr(self, "cat_head%d" % (i), CoordAttHead())

        # self.Linear = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.Linear = Linear_block(1024, 1024, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.flatten = Flatten()
        self.fc = nn.Linear(1024, num_class)
        self.bn = nn.BatchNorm1d(num_class)

    def forward(self, multiscale_feature):
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self, "cat_head%d" % i)(multiscale_feature))
        head_out = heads

        y = heads[0]
        for i in range(1, self.num_head):
            y = torch.max(y, heads[i])
        print("yshape", y.shape)

        y = multiscale_feature * y
        y = self.Linear(y)
        y = self.flatten(y)
        out = self.fc(y)
        return out, multiscale_feature, head_out

I am a novice and don't understand this very well; I hope you can advise. Thank you very much.

bfshi commented 4 months ago

Hi,

In the code you showed above, it seems only model is being optimized and model2 is not being optimized? I wonder if that's the reason for the worse performance.
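If you do want to keep the two separate models, one option is to give the optimizer both parameter sets, roughly like this (just a sketch; swap in whatever optimizer and learning rate you are actually using):

optimizer = torch.optim.Adam(
    list(model.parameters()) + list(model2.parameters()),  # update both the attention module and the MFN backbone
    lr=args.lr,
)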

Could you explain why we are defining model2? If I understand correctly, in the original code, model (which is a DDAMNet) contains two parts: a feature extraction part (MFN) and an attention module part (DDA heads). In the new code, we divide the whole DDAMNet into two separate models, model2, which is the feature extraction module, and model, which is the attention module, in order to use multiscale_forward on only the feature extraction module (model2)?

If that is the case, I think we don't need to use two separate models. We can just use the original definition of DDAMNet (which contains both feature extraction and attention module), and put multiscale_forward inside the forward function of DDAMNet. Here's an example:

=======================
# Original DDAMNet

class DDAMNet(nn.Module):
...
    def forward(self, images):
        # feature extraction part
        ...
        features = ...  # here features is a tensor

        # attention module
        ...
        head_out = some_function(features)
        out = some_function(features)

========================
# Multiscale DDAMNet

class DDAMNetMS(nn.Module):
...
    def feature_extraction(self, images):
        # feature extraction part
        ...
        features = ...  # here features is a tensor
        return features

    def forward(self, images):
        # multi-scale feature extraction
        multiscale_features = multiscale_forward(self.feature_extraction, images, scales=[1, 2], num_prefix_token=0)

        # attention module
        ...
        head_out = some_function(multiscale_features)
        out = some_function(multiscale_features)

==========================

Here we wrap the feature extraction part into a function called feature_extraction, and we can use multiscale_forward in forward to call self.feature_extraction.
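The training loop can then stay the same as before, for example (assuming the forward still returns (out, features, head_out) like the original DDAMNet):

model = DDAMNetMS(num_class=7, num_head=args.num_head).to(device)
out, feat, heads = model(imgs)  # imgs: [64, 3, 112, 112]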

Let me know if this is helpful!

1chenchen22 commented 4 months ago

Hello, yes, your understanding is correct: DDAMNet contains two parts, the feature extraction part (MFN) and the attention module part (the DDA heads). I took this model directly from someone else's work, namely https://github.com/simon20010923/DDAMFN/tree/main

That should indeed be the reason you mentioned above. I modified the code as you suggested:

class DDAMNet(nn.Module):
    def __init__(self, num_class=7, num_head=2, pretrained=True):
        super(DDAMNet, self).__init__()
        net = MixedFeatureNet.MixedFeatureNet()
        if pretrained:
            net = torch.load(os.path.join('./pretrained/', "MFN_msceleb.pth"))

        self.feature_extraction = nn.Sequential(*list(net.children())[:-4])

        self.num_head = num_head
        for i in range(int(num_head)):
            setattr(self, "cat_head%d" % (i), CoordAttHead())

        # self.Linear = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.Linear = Linear_block(1024, 1024, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.flatten = Flatten()
        self.fc = nn.Linear(1024, num_class)
        # self.fc = nn.Linear(1024, num_class)
        self.bn = nn.BatchNorm1d(num_class)

    def forward(self, imgs):
        multiscale_feature = multiscale_forward(self.feature_extraction, imgs, scales=[1, 2], num_prefix_token=0)
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self, "cat_head%d" % i)(multiscale_feature))
        head_out = heads

        y = heads[0]
        for i in range(1, self.num_head):
            y = torch.max(y, heads[i])
        # print("yshape", y.shape)

        y = multiscale_feature * y
        y = self.Linear(y)
        y = self.flatten(y)
        out = self.fc(y)
        return out, multiscale_feature, head_out

The accuracy now reaches 89%, which is about the same as without the multi-scale feature extraction from your article. I am continuing to go over and modify the code to see whether changing the learning-rate strategy will help. Thanks for your answers and help!
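For example, one thing I might try is simply swapping in a different scheduler, something along these lines (just a sketch of one option, not what DDAMFN originally uses; optimizer and args.epochs are the ones from my training script):

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

for epoch in range(1, args.epochs + 1):
    ...  # training loop as above
    scheduler.step()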

bfshi commented 4 months ago

One more question: I noticed that the output of self.feature_extraction should be B x C x H x W. In that case you need to set output_shape='bchw' when calling multiscale_forward, otherwise it will raise an error. But I'm not sure why it didn't error here even though you didn't set it. Could you check what the shape of the tensor returned by self.feature_extraction is?
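If it is indeed B x C x H x W, the call inside forward would need to look something like this:

multiscale_feature = multiscale_forward(self.feature_extraction, imgs, scales=[1, 2],
                                        num_prefix_token=0, output_shape='bchw')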

1chenchen22 commented 4 months ago

Right, self.feature_extraction is that MixedFeatureNet() model with the last four layers removed; it is a Sequential object. If I feed imgs into this model for feature extraction, the final output is a tensor of shape B x C x H x W. And I modified the function in core.py in your s2wrapper package: the default used to be output_shape='bnc', and I changed it to this:

def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
            output_shape='bchw'):
1chenchen22 commented 3 months ago

I trained again; adding this didn't improve things either, the accuracy did not go up.

bfshi commented 3 months ago

What are the training loss, test loss, and accuracy before and after adding it? Also, are both the MFN and the Attention Module being trained, or is the MFN frozen and only the Attention Module being updated?

1chenchen22 commented 3 months ago

I put the MFN and the Attention parts together into the DDAMNet class and then train it, so both parts are being trained and updated:

model = DDAMNet(num_class=7, num_head=args.num_head)
for epoch in tqdm(range(1, args.epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for (imgs, targets) in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)

        out, feat, heads = model(imgs)
I have trained four times so far. The first time I did not include the MFN part in training and the results were poor, so I made the changes based on your suggestion. [screenshot: first training results]

Below is the third run. In the third run I added the s2wrapper from your article plus FeatUp from another paper (this one: https://github.com/mhamilton723/FeatUp), and I also used the weight file obtained from the second run. The resulting training-set accuracy is very high. [screenshot: third training results] Here is the second run: [screenshot: second training results]

So I suspected overfitting and wondered whether I should adjust the learning rate. I ran a fourth training with FeatUp removed. Here is today's training log; it is quite long. I wanted to plot it for you, but plotting keeps throwing errors on my side.

2024-04-01 11:50 [Epoch 1] Train acc 0.6613, loss 4.048 | Val acc 0.7832, loss 2.556 | Best_acc 0.7832
2024-04-01 11:53 [Epoch 2] Train acc 0.7790, loss 2.342 | Val acc 0.8067, loss 2.016 | Best_acc 0.8067
2024-04-01 11:55 [Epoch 3] Train acc 0.8179, loss 2.037 | Val acc 0.8299, loss 1.911 | Best_acc 0.8299
2024-04-01 11:57 [Epoch 4] Train acc 0.8247, loss 1.913 | Val acc 0.7891, loss 1.912 | Best_acc 0.8299
2024-04-01 12:00 [Epoch 5] Train acc 0.8393, loss 1.801 | Val acc 0.8302, loss 1.806 | Best_acc 0.8302
2024-04-01 12:02 [Epoch 6] Train acc 0.8275, loss 1.810 | Val acc 0.8201, loss 1.857 | Best_acc 0.8302
2024-04-01 12:04 [Epoch 7] Train acc 0.8551, loss 1.679 | Val acc 0.8374, loss 1.773 | Best_acc 0.8374
2024-04-01 12:06 [Epoch 8] Train acc 0.8662, loss 1.621 | Val acc 0.8458, loss 1.685 | Best_acc 0.8458
2024-04-01 12:09 [Epoch 9] Train acc 0.8897, loss 1.530 | Val acc 0.8422, loss 1.693 | Best_acc 0.8458
2024-04-01 12:11 [Epoch 10] Train acc 0.8855, loss 1.527 | Val acc 0.8449, loss 1.648 | Best_acc 0.8458
2024-04-01 12:13 [Epoch 11] Train acc 0.9010, loss 1.479 | Val acc 0.8354, loss 1.692 | Best_acc 0.8458
2024-04-01 12:16 [Epoch 12] Train acc 0.9096, loss 1.428 | Val acc 0.8550, loss 1.655 | Best_acc 0.8550
2024-04-01 12:18 [Epoch 13] Train acc 0.9188, loss 1.398 | Val acc 0.8497, loss 1.683 | Best_acc 0.8550
2024-04-01 12:20 [Epoch 14] Train acc 0.9200, loss 1.393 | Val acc 0.8462, loss 1.632 | Best_acc 0.8550
2024-04-01 12:22 [Epoch 15] Train acc 0.9311, loss 1.344 | Val acc 0.8585, loss 1.616 | Best_acc 0.8585
2024-04-01 12:25 [Epoch 16] Train acc 0.9372, loss 1.319 | Val acc 0.8638, loss 1.626 | Best_acc 0.8638
2024-04-01 12:27 [Epoch 17] Train acc 0.9481, loss 1.275 | Val acc 0.8634, loss 1.596 | Best_acc 0.8638
2024-04-01 12:29 [Epoch 18] Train acc 0.9495, loss 1.265 | Val acc 0.8664, loss 1.590 | Best_acc 0.8664
2024-04-01 12:32 [Epoch 19] Train acc 0.9566, loss 1.237 | Val acc 0.8722, loss 1.566 | Best_acc 0.8722
2024-04-01 12:34 [Epoch 20] Train acc 0.9616, loss 1.220 | Val acc 0.8693, loss 1.570 | Best_acc 0.8722
2024-04-01 12:36 [Epoch 21] Train acc 0.9598, loss 1.220 | Val acc 0.8680, loss 1.602 | Best_acc 0.8722
2024-04-01 12:39 [Epoch 22] Train acc 0.9641, loss 1.197 | Val acc 0.8677, loss 1.599 | Best_acc 0.8722
2024-04-01 12:41 [Epoch 23] Train acc 0.9683, loss 1.185 | Val acc 0.8709, loss 1.590 | Best_acc 0.8722
2024-04-01 12:43 [Epoch 24] Train acc 0.9655, loss 1.185 | Val acc 0.8641, loss 1.616 | Best_acc 0.8722
2024-04-01 12:45 [Epoch 25] Train acc 0.9695, loss 1.167 | Val acc 0.8742, loss 1.563 | Best_acc 0.8742
2024-04-01 12:48 [Epoch 26] Train acc 0.9732, loss 1.159 | Val acc 0.8699, loss 1.556 | Best_acc 0.8742
2024-04-01 12:50 [Epoch 27] Train acc 0.9709, loss 1.157 | Val acc 0.8709, loss 1.550 | Best_acc 0.8742
2024-04-01 12:52 [Epoch 28] Train acc 0.9720, loss 1.150 | Val acc 0.8628, loss 1.605 | Best_acc 0.8742
2024-04-01 12:55 [Epoch 29] Train acc 0.9767, loss 1.138 | Val acc 0.8683, loss 1.585 | Best_acc 0.8742
2024-04-01 12:57 [Epoch 30] Train acc 0.9811, loss 1.121 | Val acc 0.8550, loss 1.621 | Best_acc 0.8742
2024-04-01 12:59 [Epoch 31] Train acc 0.9808, loss 1.121 | Val acc 0.8748, loss 1.533 | Best_acc 0.8748
2024-04-01 13:01 [Epoch 32] Train acc 0.9821, loss 1.110 | Val acc 0.8677, loss 1.593 | Best_acc 0.8748
2024-04-01 13:04 [Epoch 33] Train acc 0.9809, loss 1.108 | Val acc 0.8748, loss 1.550 | Best_acc 0.8748
2024-04-01 13:06 [Epoch 34] Train acc 0.9844, loss 1.101 | Val acc 0.8745, loss 1.560 | Best_acc 0.8748
2024-04-01 13:08 [Epoch 35] Train acc 0.9844, loss 1.099 | Val acc 0.8827, loss 1.551 | Best_acc 0.8827
2024-04-01 13:11 [Epoch 36] Train acc 0.9839, loss 1.095 | Val acc 0.8605, loss 1.600 | Best_acc 0.8827
2024-04-01 13:13 [Epoch 37] Train acc 0.9844, loss 1.087 | Val acc 0.8791, loss 1.539 | Best_acc 0.8827
2024-04-01 13:15 [Epoch 38] Train acc 0.9870, loss 1.081 | Val acc 0.8761, loss 1.549 | Best_acc 0.8827
2024-04-01 13:18 [Epoch 39] Train acc 0.9852, loss 1.081 | Val acc 0.8774, loss 1.537 | Best_acc 0.8827
2024-04-01 13:20 [Epoch 40] Train acc 0.9866, loss 1.076 | Val acc 0.8771, loss 1.534 | Best_acc 0.8827

Compared with the original model (without the s2wrapper from your article): its training accuracy reaches at most 0.9805, Best_acc is 0.9016, and the loss is 0.425. In other words, the original loss is lower, training accuracy is about the same, and validation accuracy differs by roughly 1-2 points, which is acceptable, but the training time became much longer.

Have you seen this Zhihu article: https://zhuanlan.zhihu.com/p/689205084 ?

This model is fairly complex: the base architecture inside MFN is MobileFaceNet, MixConv operations are added in the network bottlenecks, and there is also a coordinate attention mechanism. Could it be that because this model is already so complex and large, and already contains attention mechanisms, the s2wrapper you propose is not a good fit for my task?

I'll keep looking into it and will try again when the server is free, though I'm not sure what else to try; maybe I'll go look for other ideas instead.

In any case, thank you for your answers and help over the past few days. Your paper is excellent; it just may not suit my task. Thanks for patiently reading all of this, haha.

bfshi commented 3 months ago

Thanks for your reply! One thing to confirm: you're saying that without s2wrapper the training loss is 0.425, and with it the training loss is 1.076? And with both s2wrapper and FeatUp the training loss is around 1.114? That seems quite strange, because the gap between 0.425 and 1.076 is large, and even if adding s2wrapper leads to overfitting, the training loss should in principle be at least as low. I wonder what the reason is?

1chenchen22 commented 3 months ago

Yes, that's right: without s2wrapper the training loss is 0.425, with it the training loss is 1.076, and with both s2wrapper and FeatUp the training loss is around 1.114.

[screenshot] Then I changed the learning-rate adjustment strategy, like this: [screenshot]

class AffinityLoss(nn.Module):
    def __init__(self, device, num_class=7, feat_dim=1024):
        # adjust the shape of the cluster-center matrix; it was originally 512
        super(AffinityLoss, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.device = device

        self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim).to(device))

    def forward(self, x, labels):
        x = self.gap(x).view(x.size(0), -1)

        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_class).long().to(self.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_class)
        mask = labels.eq(classes.expand(batch_size, self.num_class))

        dist = distmat * mask.float()
        dist = dist / self.centers.var(dim=0).sum()

        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss

That is, I added the affinity loss to make samples within the same class more compact (reducing the intra-class gap) and to enlarge the gap between different classes (increasing the inter-class gap).
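For context, a loss like this is combined with the classification loss during training roughly as follows (a simplified sketch; the names and the equal weighting are assumptions, not necessarily what DDAMFN uses):

criterion_cls = torch.nn.CrossEntropyLoss()
criterion_af = AffinityLoss(device, num_class=7, feat_dim=1024)
# note: criterion_af.centers is a learnable parameter, so it also needs to be passed to the optimizer

out, feat, heads = model(imgs)
loss = criterion_cls(out, targets) + criterion_af(feat, targets)
# the reported total loss now includes the affinity term, so it is not directly comparable to the CE-only loss
loss.backward()
optimizer.step()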

Training 40 epochs without this affinity loss: training acc 0.9876, validation Best_acc 0.9051, loss 0.477. Training 40 epochs with the affinity loss: training acc 0.9923, Best_acc 0.9061, loss 1.189. So would you say it is better to add it or not? I had assumed the affinity loss was good and effective because it gained a little accuracy, and I hadn't been paying attention to the loss.