gathierry opened 2 years ago
You can check my code. My implementation matches the parameter count exactly:
| Modules | Parameters |
|---|---:|
| 0.module_list.0.global_scale | 256 |
| 0.module_list.0.global_offset | 256 |
| 0.module_list.0.subnet.0.weight | 147456 |
| 0.module_list.0.subnet.0.bias | 128 |
| 0.module_list.0.subnet.2.weight | 294912 |
| 0.module_list.0.subnet.2.bias | 256 |
| 0.module_list.1.global_scale | 256 |
| 0.module_list.1.global_offset | 256 |
| 0.module_list.1.subnet.0.weight | 16384 |
| 0.module_list.1.subnet.0.bias | 128 |
| 0.module_list.1.subnet.2.weight | 32768 |
| 0.module_list.1.subnet.2.bias | 256 |
| 0.module_list.2.global_scale | 256 |
| 0.module_list.2.global_offset | 256 |
| 0.module_list.2.subnet.0.weight | 147456 |
| 0.module_list.2.subnet.0.bias | 128 |
| 0.module_list.2.subnet.2.weight | 294912 |
| 0.module_list.2.subnet.2.bias | 256 |
| 0.module_list.3.global_scale | 256 |
| 0.module_list.3.global_offset | 256 |
| 0.module_list.3.subnet.0.weight | 16384 |
| 0.module_list.3.subnet.0.bias | 128 |
| 0.module_list.3.subnet.2.weight | 32768 |
| 0.module_list.3.subnet.2.bias | 256 |
| 0.module_list.4.global_scale | 256 |
| 0.module_list.4.global_offset | 256 |
| 0.module_list.4.subnet.0.weight | 147456 |
| 0.module_list.4.subnet.0.bias | 128 |
| 0.module_list.4.subnet.2.weight | 294912 |
| 0.module_list.4.subnet.2.bias | 256 |
| 0.module_list.5.global_scale | 256 |
| 0.module_list.5.global_offset | 256 |
| 0.module_list.5.subnet.0.weight | 16384 |
| 0.module_list.5.subnet.0.bias | 128 |
| 0.module_list.5.subnet.2.weight | 32768 |
| 0.module_list.5.subnet.2.bias | 256 |
| 0.module_list.6.global_scale | 256 |
| 0.module_list.6.global_offset | 256 |
| 0.module_list.6.subnet.0.weight | 147456 |
| 0.module_list.6.subnet.0.bias | 128 |
| 0.module_list.6.subnet.2.weight | 294912 |
| 0.module_list.6.subnet.2.bias | 256 |
| 0.module_list.7.global_scale | 256 |
| 0.module_list.7.global_offset | 256 |
| 0.module_list.7.subnet.0.weight | 16384 |
| 0.module_list.7.subnet.0.bias | 128 |
| 0.module_list.7.subnet.2.weight | 32768 |
| 0.module_list.7.subnet.2.bias | 256 |
| 1.module_list.0.global_scale | 512 |
| 1.module_list.0.global_offset | 512 |
| 1.module_list.0.subnet.0.weight | 589824 |
| 1.module_list.0.subnet.0.bias | 256 |
| 1.module_list.0.subnet.2.weight | 1179648 |
| 1.module_list.0.subnet.2.bias | 512 |
| 1.module_list.1.global_scale | 512 |
| 1.module_list.1.global_offset | 512 |
| 1.module_list.1.subnet.0.weight | 65536 |
| 1.module_list.1.subnet.0.bias | 256 |
| 1.module_list.1.subnet.2.weight | 131072 |
| 1.module_list.1.subnet.2.bias | 512 |
| 1.module_list.2.global_scale | 512 |
| 1.module_list.2.global_offset | 512 |
| 1.module_list.2.subnet.0.weight | 589824 |
| 1.module_list.2.subnet.0.bias | 256 |
| 1.module_list.2.subnet.2.weight | 1179648 |
| 1.module_list.2.subnet.2.bias | 512 |
| 1.module_list.3.global_scale | 512 |
| 1.module_list.3.global_offset | 512 |
| 1.module_list.3.subnet.0.weight | 65536 |
| 1.module_list.3.subnet.0.bias | 256 |
| 1.module_list.3.subnet.2.weight | 131072 |
| 1.module_list.3.subnet.2.bias | 512 |
| 1.module_list.4.global_scale | 512 |
| 1.module_list.4.global_offset | 512 |
| 1.module_list.4.subnet.0.weight | 589824 |
| 1.module_list.4.subnet.0.bias | 256 |
| 1.module_list.4.subnet.2.weight | 1179648 |
| 1.module_list.4.subnet.2.bias | 512 |
| 1.module_list.5.global_scale | 512 |
| 1.module_list.5.global_offset | 512 |
| 1.module_list.5.subnet.0.weight | 65536 |
| 1.module_list.5.subnet.0.bias | 256 |
| 1.module_list.5.subnet.2.weight | 131072 |
| 1.module_list.5.subnet.2.bias | 512 |
| 1.module_list.6.global_scale | 512 |
| 1.module_list.6.global_offset | 512 |
| 1.module_list.6.subnet.0.weight | 589824 |
| 1.module_list.6.subnet.0.bias | 256 |
| 1.module_list.6.subnet.2.weight | 1179648 |
| 1.module_list.6.subnet.2.bias | 512 |
| 1.module_list.7.global_scale | 512 |
| 1.module_list.7.global_offset | 512 |
| 1.module_list.7.subnet.0.weight | 65536 |
| 1.module_list.7.subnet.0.bias | 256 |
| 1.module_list.7.subnet.2.weight | 131072 |
| 1.module_list.7.subnet.2.bias | 512 |
| 2.module_list.0.global_scale | 1024 |
| 2.module_list.0.global_offset | 1024 |
| 2.module_list.0.subnet.0.weight | 2359296 |
| 2.module_list.0.subnet.0.bias | 512 |
| 2.module_list.0.subnet.2.weight | 4718592 |
| 2.module_list.0.subnet.2.bias | 1024 |
| 2.module_list.1.global_scale | 1024 |
| 2.module_list.1.global_offset | 1024 |
| 2.module_list.1.subnet.0.weight | 262144 |
| 2.module_list.1.subnet.0.bias | 512 |
| 2.module_list.1.subnet.2.weight | 524288 |
| 2.module_list.1.subnet.2.bias | 1024 |
| 2.module_list.2.global_scale | 1024 |
| 2.module_list.2.global_offset | 1024 |
| 2.module_list.2.subnet.0.weight | 2359296 |
| 2.module_list.2.subnet.0.bias | 512 |
| 2.module_list.2.subnet.2.weight | 4718592 |
| 2.module_list.2.subnet.2.bias | 1024 |
| 2.module_list.3.global_scale | 1024 |
| 2.module_list.3.global_offset | 1024 |
| 2.module_list.3.subnet.0.weight | 262144 |
| 2.module_list.3.subnet.0.bias | 512 |
| 2.module_list.3.subnet.2.weight | 524288 |
| 2.module_list.3.subnet.2.bias | 1024 |
| 2.module_list.4.global_scale | 1024 |
| 2.module_list.4.global_offset | 1024 |
| 2.module_list.4.subnet.0.weight | 2359296 |
| 2.module_list.4.subnet.0.bias | 512 |
| 2.module_list.4.subnet.2.weight | 4718592 |
| 2.module_list.4.subnet.2.bias | 1024 |
| 2.module_list.5.global_scale | 1024 |
| 2.module_list.5.global_offset | 1024 |
| 2.module_list.5.subnet.0.weight | 262144 |
| 2.module_list.5.subnet.0.bias | 512 |
| 2.module_list.5.subnet.2.weight | 524288 |
| 2.module_list.5.subnet.2.bias | 1024 |
| 2.module_list.6.global_scale | 1024 |
| 2.module_list.6.global_offset | 1024 |
| 2.module_list.6.subnet.0.weight | 2359296 |
| 2.module_list.6.subnet.0.bias | 512 |
| 2.module_list.6.subnet.2.weight | 4718592 |
| 2.module_list.6.subnet.2.bias | 1024 |
| 2.module_list.7.global_scale | 1024 |
| 2.module_list.7.global_offset | 1024 |
| 2.module_list.7.subnet.0.weight | 262144 |
| 2.module_list.7.subnet.0.bias | 512 |
| 2.module_list.7.subnet.2.weight | 524288 |
| 2.module_list.7.subnet.2.bias | 1024 |

Total Trainable Params: 41.34 M
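As a cross-check, the per-layer counts in the table can be reproduced analytically. This is a minimal sketch, assuming a layout inferred from the table rather than from any official code: each flow step is an `AllInOneBlock` whose subnet is Conv2d(C/2, hidden, k), ReLU, Conv2d(hidden, C, k) with biases and `hidden = C/2` (hidden ratio 1.0); the only other learnable tensors are `global_scale` and `global_offset` with C entries each; and the 8 steps alternate 3x3 and 1x1 kernels. The helper names `step_params` and `flow_params` are hypothetical.

```python
def step_params(channels: int, kernel: int, hidden_ratio: float = 1.0) -> int:
    """Learnable parameters of one flow step (assumed AllInOneBlock layout)."""
    half = channels // 2                       # the subnet sees half of the channels
    hidden = int(half * hidden_ratio)          # hidden width of the subnet
    actnorm = 2 * channels                     # global_scale + global_offset
    conv0 = half * hidden * kernel * kernel + hidden        # subnet.0 weight + bias
    conv2 = hidden * channels * kernel * kernel + channels  # subnet.2 weight + bias
    return actnorm + conv0 + conv2


def flow_params(stage_channels, steps: int = 8, conv3x3_only: bool = False) -> int:
    """Sum over stages; even steps use 3x3, odd steps 1x1 unless conv3x3_only."""
    total = 0
    for channels in stage_channels:
        for i in range(steps):
            kernel = 3 if (conv3x3_only or i % 2 == 0) else 1
            total += step_params(channels, kernel)
    return total


if __name__ == "__main__":
    wrn50 = flow_params([256, 512, 1024])      # assumed WideResNet50 feature widths
    print(f"{wrn50:,} -> {wrn50 / 1e6:.2f} M")
```

With these assumptions the script prints `41,337,856 -> 41.34 M`, the same total as the table, so the table is fully explained by one subnet group per step.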
I think the difference in your case could come from using the timm backbone.
@mjack3 I was able to match WideResNet50 as well. The 45.0M you see here is because I added NormLayers. I'm pretty sure it shouldn't be like this, but I cannot reach comparable results without them.
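The gap between 45.0M and 41.34M is consistent with learnable LayerNorm layers on the three feature maps. Here is a sketch of that arithmetic. It assumes a 256x256 input, WideResNet50 features taken at strides 4, 8 and 16 (256, 512 and 1024 channels), and `nn.LayerNorm([C, H, W], elementwise_affine=True)` on each map, so each norm holds a weight and a bias with C·H·W elements. These settings are assumptions, not read from the repository, and `layernorm_params` is a hypothetical helper.

```python
def layernorm_params(input_size: int, stages, affine: bool = True) -> int:
    """Learnable parameters of LayerNorm([C, H, W]) applied to each feature map."""
    if not affine:
        return 0                               # elementwise_affine=False: no weight/bias
    total = 0
    for channels, stride in stages:
        side = input_size // stride            # spatial size of this feature map
        total += 2 * channels * side * side    # weight + bias, one entry per element
    return total


WRN50_STAGES = [(256, 4), (512, 8), (1024, 16)]  # assumed (channels, stride) pairs
FLOW_PARAMS = 41_337_856                         # flow-only total from the table

norm = layernorm_params(256, WRN50_STAGES)
print(f"LayerNorm: {norm:,}; flow + LayerNorm: {(FLOW_PARAMS + norm) / 1e6:.2f} M")
```

Under these assumptions the norms add 3,670,016 parameters, giving about 45.01M in total. With `elementwise_affine=False` they add nothing, and the count falls back to the flow alone.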
Besides, if you replace WideResNet50 with ResNet18 or one of the transformers, can you still match the parameter count?
@mjack3 BTW, timm shouldn't be a problem, since the backbone is frozen and not counted in the "additional params".
Hello @gathierry
Using ResNet18, I get 2.6M (paper: 2.7M) with 3-1 and 4.7M (paper: 4.9M) with 3-3.
Strange... it's a slight difference, but it makes me think that `AllInOneBlock` is not what we need.
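The ResNet18 figures can be checked with a closed form. This sketch assumes the same per-step layout as the WideResNet50 parameter table: an `AllInOneBlock` whose subnet is Conv2d(C/2, C/2, k), ReLU, Conv2d(C/2, C, k) with biases, plus C-element `global_scale` and `global_offset` tensors, over 8 steps. Reading "3-1" as alternating 3x3 and 1x1 kernels and "3-3" as all 3x3 is also an assumption. For a step of width C and kernel size k:

```math
\begin{aligned}
P(C,k) &= 2C + \left(\tfrac{C}{2}\cdot\tfrac{C}{2}\,k^2 + \tfrac{C}{2}\right) + \left(\tfrac{C}{2}\cdot C\,k^2 + C\right) = \tfrac{3}{4}k^2C^2 + \tfrac{7}{2}C \\
P_{3\text{-}1}(C) &= 4\,\big[P(C,3) + P(C,1)\big] = 30C^2 + 28C \\
P_{3\text{-}3}(C) &= 8\,P(C,3) = 54C^2 + 28C
\end{aligned}
```

Summing over ResNet18's widths C ∈ {64, 128, 256} (ΣC² = 86,016, ΣC = 448) gives 30·86,016 + 28·448 = 2,593,024 for 3-1 and 54·86,016 + 28·448 = 4,657,408 for 3-3. These are about 2.59M and 4.66M, in line with the 2.6M and 4.7M reported here. Under this layout the remaining gap to the paper's 2.7M and 4.9M cannot be explained by rounding.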
For the model with WideResNet50 as the feature extractor, there are 8 flow steps, and each flow step should have 2 groups of Conv2D-ReLU-Conv2D.
But in the implementation here, each flow step appears to be a single `AllInOneBlock`, which has only one group of Conv2D-ReLU-Conv2D.
Is this understanding correct? Would this affect the number of parameters?
@questionstorer Currently, `AllInOneBlock` is the only way to match the A.D. params with a x1 hidden-channel ratio. You are correct: here we have just one group of Conv2D-ReLU-Conv2D.
The FastFlow paper has not been accepted at any journal or conference yet, so we can only rely on the ideas it presents.
@questionstorer Nice catch, that confused me as well. With 2 groups in each step, the parameter count would nearly double. The numbers for DeiT and CaiT would then be closer to the paper, but for the ResNets the difference would be even larger.
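Here is a rough check of the "doubled" estimate for WideResNet50. The sketch assumes the layout shown in the parameter table. It also assumes that a second Conv2D-ReLU-Conv2D group per step would duplicate only the subnet convolutions and leave `global_scale`/`global_offset` unchanged. The second group is hypothetical and is only used to size the effect.

```python
STAGE_CHANNELS = [256, 512, 1024]   # WideResNet50 feature widths
STEPS = 8                           # flow steps per stage
TOTAL_ONE_GROUP = 41_337_856        # total from the parameter table

# global_scale + global_offset: C entries each, per step, per stage
actnorm = sum(2 * c * STEPS for c in STAGE_CHANNELS)
subnets = TOTAL_ONE_GROUP - actnorm

# Hypothetical: a second subnet group per step doubles only the conv parameters
total_two_groups = actnorm + 2 * subnets
print(f"one group: {TOTAL_ONE_GROUP / 1e6:.2f} M, two groups: {total_two_groups / 1e6:.2f} M")
```

Under this assumption the count goes from about 41.34M to about 82.65M. That is almost exactly double, because the scale/offset terms (28,672 parameters) are negligible next to the convolutions.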
@gathierry Hi, I reimplemented your FastFlow code. My wide_resnet50_2 has 41.33M A.D. params (paper: 41.3M) and my resnet18 has 4.65M (paper: 4.9M). CaiT and DeiT have the same A.D. params as in your code (7.07M; paper: 14.8M). My wide_resnet50_2 also has the LayerNorm, like yours.
@Zigars Thanks for the feedback, but how did you manage to reduce wrn50 from 45M to 41.3M without removing LayerNorm? Which part did you change?
@gathierry I just separate the model into an encoder (feature_extractor) and a decoder (FastFlow A.D.), like c-flow, and count only the decoder's A.D. params when loading the model. That gives the expected 41.3M for wrn50, matching the paper. Maybe some modules in your combined model do not have `param.requires_grad = False` set?
Also, in my own code I added an image-level AUC calculation module, and I'm testing resnet18 on MVTec, which takes some time to train. In the future, I will also add a visualization module for testing and prediction.
Thank you for your open-source code; I learned a lot from it!
@Zigars So I guess you put LayerNorm in the encoder? I count it in the A.D. params as well, since the original wrn50 has no LayerNorm.
I also tried setting `elementwise_affine=False` to remove their learnable parameters, only to find that the final AUC dropped.
Please correct me if you have different observations.
@gathierry Yes, I put the LayerNorm in the encoder; maybe the original paper did this too, because without LayerNorm the decoder (FastFlow) matches the paper's A.D. params. After all, the paper has no official code, so we have to experiment ourselves.
The model's additional parameter count does not match Table 1 in the paper.