Pongpisit-Thanasutives / Variations-of-SFANet-for-Crowd-Counting

The official implementation of "Encoder-Decoder Based Convolutional Neural Networks with Multi-Scale-Aware Modules for Crowd Counting"
https://ieeexplore.ieee.org/document/9413286
GNU General Public License v3.0

Could not reproduce the results in the paper #19

Open knightyxp opened 3 years ago

knightyxp commented 3 years ago

The published code does not have the CANet branch, which is inconsistent with what the paper reports. In my experiments the SFANet baseline reaches an MAE of 59. However, when I add ASPP (i.e., use the M-SFANet model from the author's code), the SHA MAE is only 61, which is ridiculous. The results in the paper cannot be reproduced (I do not know whether the ICPR reviewers are aware of this).
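The MAE figures traded in this thread (59, 61, 62.41 on SHA) are per-image count errors over the test set. As a minimal sketch, assuming the standard crowd-counting definitions rather than anything taken from this repository's test script, they can be computed as follows. By the usual convention in these papers, "MSE" actually denotes the root of the mean squared count error.

```python
import math
from typing import Sequence, Tuple


def counting_errors(pred_counts: Sequence[float],
                    gt_counts: Sequence[float]) -> Tuple[float, float]:
    """Return (MAE, MSE) over a test set of images.

    Follows the usual crowd-counting convention: MAE is the mean absolute
    count error, and "MSE" is the *root* of the mean squared count error.
    """
    if len(pred_counts) != len(gt_counts) or not pred_counts:
        raise ValueError("need equally sized, non-empty count lists")
    errors = [p - g for p, g in zip(pred_counts, gt_counts)]
    n = len(errors)
    mae = sum(abs(e) for e in errors) / n
    mse = math.sqrt(sum(e * e for e in errors) / n)
    return mae, mse


# Hypothetical per-image counts (e.g. sums of predicted density maps).
mae, mse = counting_errors([10.0, 20.0, 30.0], [12.0, 18.0, 33.0])
print(f"MAE={mae:.2f} MSE={mse:.2f}")  # MAE=2.33 MSE=2.38
```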

Pongpisit-Thanasutives commented 3 years ago

@knightyxp In the ScalePyramidModule class, defined in M-SFANet.py, there is self.can = ContextualModule(512, 512) as the CANet branch. If the CAN module is removed, the reported performance is 62.41 (MAE) on SHA and 7.40 (MAE) on SHB. Please make sure to include the module in your forward pass as well.
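Judging by the key names in the state_dict error posted later in this thread, the CAN branch is registered as spm.can next to the ASPP branch spm.assp. As a hypothetical sanity check (these helpers are not part of the repository), one can group a model's or checkpoint's keys by submodule and confirm that CAN weights are present. This only shows that the module is registered; whether forward() actually calls self.can still has to be checked in the code.

```python
from collections import Counter
from typing import Iterable


def submodule_param_counts(keys: Iterable[str], depth: int = 2) -> Counter:
    """Count state_dict entries per submodule prefix, e.g. 'spm.can'."""
    return Counter(".".join(k.split(".")[:depth]) for k in keys)


def has_can_branch(keys: Iterable[str]) -> bool:
    """True if any entry belongs to the CAN branch registered as spm.can."""
    return any(k.startswith("spm.can.") for k in keys)


# Key names copied from the state_dict error in this thread.
keys = [
    "vgg.conv1_1.conv.weight",
    "spm.assp.aspp1.atrous_conv.weight",
    "spm.can.scales.0.1.weight",
    "spm.can.bottleneck.weight",
]
print(submodule_param_counts(keys))
print(has_can_branch(keys))  # True
```

With PyTorch, the same check can be run on model.state_dict().keys() for a freshly built model.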

knightyxp commented 3 years ago

Does this mean that just adding CAN to SFANet is not suitable?

knightyxp commented 3 years ago

What's more, in your experiments M-SFANet w/o CAN (which here means just SFANet) has an MAE of 62.41 on SHA. However, SFANet on SHA achieves 59 in my experiments (60 as reported), so I do not know how you got this result.

[Screenshot 2021-05-21, 12:49:12 AM]
Pongpisit-Thanasutives commented 3 years ago

@knightyxp I see. M-SFANet w/o CAN means removing self.can from ScalePyramidModule. Could you share your training code? Otherwise it is hard to spot the differences in implementation.

knightyxp commented 3 years ago

It is the same as train.py in SFANet.

Pongpisit-Thanasutives commented 3 years ago

Thank you. I have looked at your code and spotted some inconsistencies in your implementation:

(1) I did not use the SSM loss.
(2) I did not use Adam. In the SHA experiment, I used LookaheadAdam(model.parameters(), lr=5e-4) (see ./models).
(3) Please train for up to 1000 epochs, not 500, because the M-SFANet model is more complex in terms of the number of parameters.
(4) The reproduced weights of M_SFANet (SHA), with MAE=59.69 and MSE=95.64, are provided via the Google Drive link, so you can check the saved epoch and the optimizer's state_dict.

P.S. I cannot check your preprocessing code, which is also important.
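Item (2) refers to a LookaheadAdam optimizer shipped in the repository's ./models. To illustrate what Lookahead does (Zhang et al., 2019: take k steps with an inner optimizer, then move the slow weights a fraction alpha toward the fast weights and restart from there), here is a minimal pure-Python sketch wrapping Adam. It is not the repository's class: k=5 and alpha=0.5 are assumed defaults, and the toy quadratic with lr=0.01 is chosen only so the demo converges quickly (the SHA run used lr=5e-4).

```python
import math
from typing import List


class Adam:
    """Plain Adam on a list of scalar parameters (illustration only)."""

    def __init__(self, params: List[float], lr: float = 5e-4,
                 betas=(0.9, 0.999), eps: float = 1e-8):
        self.params = params  # updated in place
        self.lr, self.eps = lr, eps
        self.b1, self.b2 = betas
        self.m = [0.0] * len(params)
        self.v = [0.0] * len(params)
        self.t = 0

    def step(self, grads: List[float]) -> None:
        self.t += 1
        for i, g in enumerate(grads):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            self.params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


class Lookahead:
    """Lookahead wrapper: k fast steps, then pull slow weights toward them.

    Assumed hyperparameters k=5, alpha=0.5; not taken from the repository.
    """

    def __init__(self, inner: Adam, k: int = 5, alpha: float = 0.5):
        self.inner, self.k, self.alpha = inner, k, alpha
        self.slow = list(inner.params)  # slow weights start at the initial point
        self.counter = 0

    def step(self, grads: List[float]) -> None:
        self.inner.step(grads)  # fast-weight update
        self.counter += 1
        if self.counter % self.k == 0:
            for i, fast in enumerate(self.inner.params):
                self.slow[i] += self.alpha * (fast - self.slow[i])
                self.inner.params[i] = self.slow[i]  # reset fast weights to slow


# Toy objective f(x) = sum((x_i - 3)^2); its gradient is 2 * (x_i - 3).
params = [0.0, 0.0]
opt = Lookahead(Adam(params, lr=0.01))
for _ in range(3000):
    opt.step([2.0 * (x - 3.0) for x in params])
print(params)  # both entries end up close to 3.0
```

The inner Adam state is kept across synchronizations; only the weights are reset to the slow copy.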

knightyxp commented 3 years ago

I cannot load your pretrained weights (.pth). I get this error:

Traceback (most recent call last):
  File "test.py", line 40, in <module>
    model.load_state_dict(torch.load(model_path), device)
  File "/opt/conda/envs/torch17/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
    Missing key(s) in state_dict: "vgg.conv1_1.conv.weight", "vgg.conv1_1.conv.bias", "vgg.conv1_1.bn.weight", "vgg.conv1_1.bn.bias", "vgg.conv1_1.bn.running_mean", "vgg.conv1_1.bn.running_var", "vgg.conv1_2.conv.weight", "vgg.conv1_2.conv.bias", "vgg.conv1_2.bn.weight", "vgg.conv1_2.bn.bias", "vgg.conv1_2.bn.running_mean", "vgg.conv1_2.bn.running_var", "vgg.conv2_1.conv.weight", "vgg.conv2_1.conv.bias", "vgg.conv2_1.bn.weight", "vgg.conv2_1.bn.bias", "vgg.conv2_1.bn.running_mean", "vgg.conv2_1.bn.running_var", "vgg.conv2_2.conv.weight", "vgg.conv2_2.conv.bias", "vgg.conv2_2.bn.weight", "vgg.conv2_2.bn.bias", "vgg.conv2_2.bn.running_mean", "vgg.conv2_2.bn.running_var", "vgg.conv3_1.conv.weight", "vgg.conv3_1.conv.bias", "vgg.conv3_1.bn.weight", "vgg.conv3_1.bn.bias", "vgg.conv3_1.bn.running_mean", "vgg.conv3_1.bn.running_var", "vgg.conv3_2.conv.weight", "vgg.conv3_2.conv.bias", "vgg.conv3_2.bn.weight", "vgg.conv3_2.bn.bias", "vgg.conv3_2.bn.running_mean", "vgg.conv3_2.bn.running_var", "vgg.conv3_3.conv.weight", "vgg.conv3_3.conv.bias", "vgg.conv3_3.bn.weight", "vgg.conv3_3.bn.bias", "vgg.conv3_3.bn.running_mean", "vgg.conv3_3.bn.running_var", "vgg.conv4_1.conv.weight", "vgg.conv4_1.conv.bias", "vgg.conv4_1.bn.weight", "vgg.conv4_1.bn.bias", "vgg.conv4_1.bn.running_mean", "vgg.conv4_1.bn.running_var", "vgg.conv4_2.conv.weight", "vgg.conv4_2.conv.bias", "vgg.conv4_2.bn.weight", "vgg.conv4_2.bn.bias", "vgg.conv4_2.bn.running_mean", "vgg.conv4_2.bn.running_var", "vgg.conv4_3.conv.weight", "vgg.conv4_3.conv.bias", "vgg.conv4_3.bn.weight", "vgg.conv4_3.bn.bias", "vgg.conv4_3.bn.running_mean", 
"vgg.conv4_3.bn.running_var", "vgg.conv5_1.conv.weight", "vgg.conv5_1.conv.bias", "vgg.conv5_1.bn.weight", "vgg.conv5_1.bn.bias", "vgg.conv5_1.bn.running_mean", "vgg.conv5_1.bn.running_var", "vgg.conv5_2.conv.weight", "vgg.conv5_2.conv.bias", "vgg.conv5_2.bn.weight", "vgg.conv5_2.bn.bias", "vgg.conv5_2.bn.running_mean", "vgg.conv5_2.bn.running_var", "vgg.conv5_3.conv.weight", "vgg.conv5_3.conv.bias", "vgg.conv5_3.bn.weight", "vgg.conv5_3.bn.bias", "vgg.conv5_3.bn.running_mean", "vgg.conv5_3.bn.running_var", "spm.assp.aspp1.atrous_conv.weight", "spm.assp.aspp1.bn.weight", "spm.assp.aspp1.bn.bias", "spm.assp.aspp1.bn.running_mean", "spm.assp.aspp1.bn.running_var", "spm.assp.aspp2.atrous_conv.weight", "spm.assp.aspp2.bn.weight", "spm.assp.aspp2.bn.bias", "spm.assp.aspp2.bn.running_mean", "spm.assp.aspp2.bn.running_var", "spm.assp.aspp3.atrous_conv.weight", "spm.assp.aspp3.bn.weight", "spm.assp.aspp3.bn.bias", "spm.assp.aspp3.bn.running_mean", "spm.assp.aspp3.bn.running_var", "spm.assp.aspp4.atrous_conv.weight", "spm.assp.aspp4.bn.weight", "spm.assp.aspp4.bn.bias", "spm.assp.aspp4.bn.running_mean", "spm.assp.aspp4.bn.running_var", "spm.assp.global_avg_pool.1.weight", "spm.assp.global_avg_pool.2.weight", "spm.assp.global_avg_pool.2.bias", "spm.assp.global_avg_pool.2.running_mean", "spm.assp.global_avg_pool.2.running_var", "spm.assp.conv1.weight", "spm.assp.bn1.weight", "spm.assp.bn1.bias", "spm.assp.bn1.running_mean", "spm.assp.bn1.running_var", "spm.can.scales.0.1.weight", "spm.can.scales.1.1.weight", "spm.can.scales.2.1.weight", "spm.can.scales.3.1.weight", "spm.can.bottleneck.weight", "spm.can.bottleneck.bias", "spm.can.weight_net.weight", "spm.can.weight_net.bias", "amp.conv1.conv.weight", "amp.conv1.conv.bias", "amp.conv1.bn.weight", "amp.conv1.bn.bias", "amp.conv1.bn.running_mean", "amp.conv1.bn.running_var", "amp.conv2.conv.weight", "amp.conv2.conv.bias", "amp.conv2.bn.weight", "amp.conv2.bn.bias", "amp.conv2.bn.running_mean", "amp.conv2.bn.running_var", 
"amp.conv3.conv.weight", "amp.conv3.conv.bias", "amp.conv3.bn.weight", "amp.conv3.bn.bias", "amp.conv3.bn.running_mean", "amp.conv3.bn.running_var", "amp.conv4.conv.weight", "amp.conv4.conv.bias", "amp.conv4.bn.weight", "amp.conv4.bn.bias", "amp.conv4.bn.running_mean", "amp.conv4.bn.running_var", "amp.conv5.conv.weight", "amp.conv5.conv.bias", "amp.conv5.bn.weight", "amp.conv5.bn.bias", "amp.conv5.bn.running_mean", "amp.conv5.bn.running_var", "amp.conv6.conv.weight", "amp.conv6.conv.bias", "amp.conv6.bn.weight", "amp.conv6.bn.bias", "amp.conv6.bn.running_mean", "amp.conv6.bn.running_var", "amp.conv7.conv.weight", "amp.conv7.conv.bias", "amp.conv7.bn.weight", "amp.conv7.bn.bias", "amp.conv7.bn.running_mean", "amp.conv7.bn.running_var", "dmp.conv1.conv.weight", "dmp.conv1.conv.bias", "dmp.conv1.bn.weight", "dmp.conv1.bn.bias", "dmp.conv1.bn.running_mean", "dmp.conv1.bn.running_var", "dmp.conv2.conv.weight", "dmp.conv2.conv.bias", "dmp.conv2.bn.weight", "dmp.conv2.bn.bias", "dmp.conv2.bn.running_mean", "dmp.conv2.bn.running_var", "dmp.conv3.conv.weight", "dmp.conv3.conv.bias", "dmp.conv3.bn.weight", "dmp.conv3.bn.bias", "dmp.conv3.bn.running_mean", "dmp.conv3.bn.running_var", "dmp.conv4.conv.weight", "dmp.conv4.conv.bias", "dmp.conv4.bn.weight", "dmp.conv4.bn.bias", "dmp.conv4.bn.running_mean", "dmp.conv4.bn.running_var", "dmp.conv5.conv.weight", "dmp.conv5.conv.bias", "dmp.conv5.bn.weight", "dmp.conv5.bn.bias", "dmp.conv5.bn.running_mean", "dmp.conv5.bn.running_var", "dmp.conv6.conv.weight", "dmp.conv6.conv.bias", "dmp.conv6.bn.weight", "dmp.conv6.bn.bias", "dmp.conv6.bn.running_mean", "dmp.conv6.bn.running_var", "dmp.conv7.conv.weight", "dmp.conv7.conv.bias", "dmp.conv7.bn.weight", "dmp.conv7.bn.bias", "dmp.conv7.bn.running_mean", "dmp.conv7.bn.running_var", "conv_att.conv.weight", "conv_att.conv.bias", "conv_att.bn.weight", "conv_att.bn.bias", "conv_att.bn.running_mean", "conv_att.bn.running_var", "conv_out.conv.weight", "conv_out.conv.bias", "conv_out.bn.weight", 
"conv_out.bn.bias", "conv_out.bn.running_mean", "conv_out.bn.running_var".
    Unexpected key(s) in state_dict: "epoch", "model", "optimizer", "mae", "mse".

Pongpisit-Thanasutives commented 3 years ago

@knightyxp The weights are stored under the "model" key of the checkpoint, i.e. torch.load(model_path)["model"], as shown in the screenshot:

[Screenshot 2564-05-25 at 20:40:33]
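Following the author's answer, the released file is a checkpoint dictionary with the weights under "model" and the metadata listed as "Unexpected key(s)" in the traceback above stored alongside. A small hypothetical helper (not part of the repository) can accept either such a wrapper or a bare state_dict:

```python
from typing import Any, Dict, Mapping

# Scalar metadata entries seen in the released checkpoint's unexpected keys.
INFO_KEYS = ("epoch", "mae", "mse")


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the model weights from a loaded checkpoint.

    Hypothetical helper: if the checkpoint wraps the weights under "model"
    (as the M-SFANet release does), return that entry; otherwise assume the
    object is already a plain state_dict and return it unchanged.
    """
    weights = checkpoint.get("model")
    if isinstance(weights, Mapping):
        return weights
    return checkpoint


def checkpoint_info(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Collect the saved epoch and metrics, if present."""
    return {k: checkpoint[k] for k in INFO_KEYS if k in checkpoint}
```

With PyTorch this would be used as model.load_state_dict(extract_state_dict(torch.load(model_path, map_location="cpu"))). Note that the second positional parameter of load_state_dict is strict, not a device; where tensors are loaded is controlled by map_location in torch.load.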