gathierry / FastFlow

Apache License 2.0

Unmatched model A.D. parameter count #2

Open gathierry opened 2 years ago

gathierry commented 2 years ago

The model's additional parameter counts do not match Table 1 in the paper:

             wide-resnet-50   resnet18   DeiT    CaiT
paper        41.3M            4.9M       14.8M   14.8M
this implem  45.0M            5.6M       7.1M    7.1M
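
For reference, a minimal sketch of how the A.D. (additional) parameter count can be measured, assuming a PyTorch model whose backbone parameters are frozen (`fastflow_model` below is a hypothetical name):

```python
import torch

def count_additional_params(model: torch.nn.Module) -> float:
    # "Additional" parameters are the trainable ones; the frozen
    # backbone's weights have requires_grad == False and are skipped.
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return n / 1e6  # in millions

# e.g. print(f"{count_additional_params(fastflow_model):.1f}M")
```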
mjack3 commented 2 years ago

You can check my code. My implementation matches the exact parameter counts:

Each of the three stages has 8 blocks; within a stage, the even blocks (0/2/4/6) and the odd blocks (1/3/5/7) repeat the same shapes:

+-------+------------------------------+-----------------------+----------------------+
| Stage | Parameter                    | Even blocks (0/2/4/6) | Odd blocks (1/3/5/7) |
+-------+------------------------------+-----------------------+----------------------+
| 0     | global_scale / global_offset | 256 / 256             | 256 / 256            |
| 0     | subnet.0.weight / .bias      | 147456 / 128          | 16384 / 128          |
| 0     | subnet.2.weight / .bias      | 294912 / 256          | 32768 / 256          |
| 1     | global_scale / global_offset | 512 / 512             | 512 / 512            |
| 1     | subnet.0.weight / .bias      | 589824 / 256          | 65536 / 256          |
| 1     | subnet.2.weight / .bias      | 1179648 / 512         | 131072 / 512         |
| 2     | global_scale / global_offset | 1024 / 1024           | 1024 / 1024          |
| 2     | subnet.0.weight / .bias      | 2359296 / 512         | 262144 / 512         |
| 2     | subnet.2.weight / .bias      | 4718592 / 1024        | 524288 / 1024        |
+-------+------------------------------+-----------------------+----------------------+
Total Trainable Params: 41.34 M

I think the mismatch in your case could come from using the timm backbone.
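
For reference, a table in the format above can be produced with the prettytable package; a minimal sketch, assuming a PyTorch model with a frozen backbone:

```python
from prettytable import PrettyTable

def print_param_table(model):
    # One row per named parameter tensor, as in the table above.
    table = PrettyTable(["Modules", "Parameters"])
    total = 0
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue  # skip the frozen backbone
        table.add_row([name, p.numel()])
        total += p.numel()
    print(table)
    print(f"Total Trainable Params: {total / 1e6:.2f} M")
```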

gathierry commented 2 years ago

@mjack3 I was able to match WideResNet50 as well. The 45.0M you see here is because I added norm layers. I'm pretty sure it shouldn't be like this, but I cannot reach comparable results without them. Besides, if you replace wideresnet50 with resnet18 or one of the transformers, can you still match the parameters?

gathierry commented 2 years ago

@mjack3 BTW, timm shouldn't be a problem since the backbone is fixed and not counted in "additional params"

mjack3 commented 2 years ago

Hello @gathierry

Using ResNet18 I match 2.6M (paper: 2.7M) using the 3-1 (alternating 3x3/1x1 kernel) configuration, and 4.7M (paper: 4.9M) using 3-3.

Odd... it's a slight difference, which makes me think that AllInOneBlock is not what we need.

questionstorer commented 2 years ago

For the model with WideResNet50 as the feature extractor, there are 8 flow steps. Each flow step should have 2 groups of Conv2D-ReLU-Conv2D, but in the flow step implemented here, it looks like every flow step has an AllInOneBlock that contains only one group of Conv2D-ReLU-Conv2D. Is this understanding correct? Is this going to have an impact on the number of parameters?
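
For concreteness, the shapes in the parameter table above are consistent with a single Conv2D-ReLU-Conv2D group per block. A sketch of the arithmetic for the first wrn50 stage (256 channels, split into two halves of 128 by the coupling); the layer sizes are read off the table, the module itself is illustrative:

```python
import torch.nn as nn

# Hypothetical subnet for one block in the first wrn50 stage: the
# 256-channel map is split in half (c_in = 128), and the output
# predicts scale and shift for the other half (c_out = 256).
subnet = nn.Sequential(
    nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 128*128*3*3 = 147456 weights + 128 biases
    nn.ReLU(),
    nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 128*256*3*3 = 294912 weights + 256 biases
)
print(sum(p.numel() for p in subnet.parameters()))  # 442752
# A second such group per flow step, as read from the paper, would
# roughly double the subnet parameters.
```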

mjack3 commented 2 years ago

@questionstorer Currently, AllInOneBlock is the only way to match the A.D. count with the 1x hidden-channel ratio. You are correct: here we just have one group of Conv2D-ReLU-Conv2D.

The FastFlow paper has not yet been accepted by any journal or conference, so we can only trust the idea as presented.
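
For context, a minimal sketch of the AllInOneBlock construction under discussion, using FrEIA's SequenceINN with the alternating 3x3/1x1 kernels ("3-1") mentioned earlier; the channel and spatial sizes are illustrative:

```python
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

def subnet(kernel):
    def constructor(c_in, c_out):
        # One Conv2D-ReLU-Conv2D group per AllInOneBlock,
        # with a 1x hidden-channel ratio (hidden = c_in).
        return nn.Sequential(
            nn.Conv2d(c_in, c_in, kernel, padding=kernel // 2),
            nn.ReLU(),
            nn.Conv2d(c_in, c_out, kernel, padding=kernel // 2),
        )
    return constructor

# e.g. the first wrn50 feature map: 256 channels at 64x64
inn = Ff.SequenceINN(256, 64, 64)
for step in range(8):
    # alternate 3x3 and 1x1 kernels between flow steps
    inn.append(Fm.AllInOneBlock, subnet_constructor=subnet(3 if step % 2 == 0 else 1))
```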

gathierry commented 2 years ago

@questionstorer Nice catch, and that's something that confused me as well. If we have 2 groups in each step, then the parameter count is doubled. The numbers for DeiT and CaiT would get closer to the paper, but for ResNet the difference would become even larger.

Zigars commented 2 years ago

@gathierry Hi, I reimplemented your FastFlow code: my wide_resnet50_2 has 41.33M (paper: 41.3M) A.D. params, and resnet18 has 4.65M (paper: 4.9M) A.D. params. CaiT and DeiT have the same A.D. params as your code (7.07M; paper: 14.8M), and my wide_resnet50_2 has the LayerNorm like yours.

gathierry commented 2 years ago

@Zigars Thanks for the feedback, but how did you manage to reduce wrn50 from 45M to 41.3M without removing LayerNorm? Which part did you update?

Zigars commented 2 years ago

@gathierry I just separated the model into an encoder (feature extractor) and a decoder (the FastFlow A.D. part), like C-Flow, and only count the decoder's A.D. params at model loading; that gives the correct 41.3M for wrn50, matching the paper. Maybe your combined model has some modules that do not set param.requires_grad = False?
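
A sketch of the separation being described, assuming a timm backbone; the names are illustrative:

```python
import timm

# Encoder: frozen feature extractor, excluded from the A.D. count.
encoder = timm.create_model("wide_resnet50_2", pretrained=True, features_only=True)
for p in encoder.parameters():
    p.requires_grad = False

# Decoder: the FastFlow blocks (e.g. the SequenceINN sketched above).
# Only its parameters enter the A.D. count:
#   ad_params = sum(p.numel() for p in decoder.parameters()) / 1e6
```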

Also, in my own code I added an image-level AUC calculation module, and I'm testing resnet18 on MVTec, which takes some time to train. In the future, I will also add a visualization module for testing and prediction.
Thank you for your open-source code; I learned a lot from it!

gathierry commented 2 years ago

@Zigars So I guess you put the LayerNorm in the encoder? I count it in the A.D. params as well, since the original wrn50 has no LayerNorm. I also tried setting elementwise_affine=False to remove only their learnable parameters, but found that the final AUC dropped. Please correct me if you have different observations.
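
For reference, a LayerNorm constructed this way keeps no learnable parameters, so it adds nothing to the A.D. count; the shape below is illustrative:

```python
import torch.nn as nn

# elementwise_affine=False drops the learnable weight and bias.
ln = nn.LayerNorm([256, 64, 64], elementwise_affine=False)
print(sum(p.numel() for p in ln.parameters()))  # 0
```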

Zigars commented 2 years ago

@gathierry Yes, I put the LayerNorm in the encoder; maybe the original paper did this as well, because without LayerNorm the decoder (FastFlow) can match the paper's A.D. params. After all, there is no official code for the paper, so we can only try things ourselves.