Ayews / M3Net

The implementation of 'M3Net: Multilevel, Mixed and Multistage Attention Network for Salient Object Detection'.
MIT License

Not able to download pretrained weights from Baidu link #2

Closed: anilsathyan7 closed this issue 1 year ago

anilsathyan7 commented 1 year ago

The current link does not seem to be working. Is there an alternate link (Google Drive)?

Ayews commented 1 year ago

Thank you for your interest in our work! We apologize for the inconvenience caused by the download link. To make it easier for everyone to download, we have added a Google Drive link. You can download the pretrained weights from the following address:

M3Net-R50_224 M3Net-SwinB_384

If you encounter any other issues during the download or usage process, please feel free to contact us.

anilsathyan7 commented 1 year ago

It isn't available for download to anyone with the link; it still requires separate access permission. I guess you might need to set the sharing permission to 'Anyone with the link' in the sharing options.

Ayews commented 1 year ago

Ah, thank you for the reminder! We have modified the access permissions.

anilsathyan7 commented 1 year ago

Error loading the model with the default config for checkpoint M3Net-SwinB_384.pth:

!python train_test.py --test True --data_root '/content/images' --save_model '/content/M3Net/pretrained_model/'

Starting test.
/usr/local/lib/python3.10/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/content/M3Net/train_test.py", line 34, in <module>
    testing(args=args)
  File "/content/M3Net/test.py", line 49, in testing
    model.load_state_dict(torch.load(args.save_model+args.method+'.pth'))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for M3Net:
    size mismatch for interact1.interact1.q1.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([384, 256]).
    size mismatch for interact1.interact1.proj.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 384]).
    size mismatch for interact1.interact1.k2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 512]).
    size mismatch for interact1.interact1.v2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 512]).
    size mismatch for interact2.interact1.q1.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for interact2.interact1.proj.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 384]).
    size mismatch for interact2.interact1.k2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([384, 256]).
    size mismatch for interact2.interact1.v2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([384, 256]).
    size mismatch for interact2.interact2.q1.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for interact2.interact2.proj.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 384]).
    size mismatch for interact2.interact2.k2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 512]).
    size mismatch for interact2.interact2.v2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 512]).
    size mismatch for interact3.interact1.q1.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([384, 64]).
    size mismatch for interact3.interact1.proj.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 384]).
    size mismatch for interact3.interact1.k2.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for interact3.interact1.v2.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for interact3.interact2.q1.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([384, 64]).
    size mismatch for interact3.interact2.proj.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 384]).
    size mismatch for interact3.interact2.k2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([384, 256]).
    size mismatch for interact3.interact2.v2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([384, 256]).
    size mismatch for decoder.mixatt1.mlp1.0.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for decoder.mixatt1.mlp1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.mlp1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.mlp1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.windowatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.windowatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.windowatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt1.blocks.0.windowatt.attn.qkv.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1152]).
    size mismatch for decoder.mixatt1.blocks.0.windowatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.0.windowatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.globalatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.globalatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.globalatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt1.blocks.0.globalatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.0.globalatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.mlp.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.0.mlp.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.0.mlp.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.0.mlp.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.windowatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.windowatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.windowatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt1.blocks.1.windowatt.attn.qkv.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1152]).
    size mismatch for decoder.mixatt1.blocks.1.windowatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.1.windowatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.globalatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.globalatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.globalatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt1.blocks.1.globalatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.1.globalatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.mlp.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.1.mlp.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.blocks.1.mlp.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt1.blocks.1.mlp.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt1.mlp2.0.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([128, 384]).
    size mismatch for decoder.mixatt2.mlp1.0.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([384, 64]).
    size mismatch for decoder.mixatt2.mlp1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.mlp1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.mlp1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.windowatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.windowatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.windowatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt2.blocks.0.windowatt.attn.qkv.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1152]).
    size mismatch for decoder.mixatt2.blocks.0.windowatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.0.windowatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.globalatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.globalatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.globalatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt2.blocks.0.globalatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.0.globalatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.mlp.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.0.mlp.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.0.mlp.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.0.mlp.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.windowatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.windowatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.windowatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt2.blocks.1.windowatt.attn.qkv.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1152]).
    size mismatch for decoder.mixatt2.blocks.1.windowatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.1.windowatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.globalatt.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.globalatt.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.globalatt.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
    size mismatch for decoder.mixatt2.blocks.1.globalatt.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.1.globalatt.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.mlp.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.1.mlp.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.blocks.1.mlp.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for decoder.mixatt2.blocks.1.mlp.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for decoder.mixatt2.mlp2.0.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 384]).

Ayews commented 1 year ago

Just modify the embed_dim parameter so that it matches the checkpoint. For example, create the model with this line: model = M3Net(embed_dim=384, dim=64, img_size=img_size, method='M3Net-S')
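Judging from the error log in this thread, every mismatched layer is 512 wide in M3Net-SwinB_384.pth but 384 wide in the default model, so the value to pass as embed_dim can be read off the checkpoint itself. The sketch below does that under two assumptions: that the width of `decoder.mixatt1.norm2.weight` (a key copied from the log) equals embed_dim, and that the file is a plain state_dict, as the `torch.load` call in test.py suggests. `load_checkpoint`, `infer_embed_dim` and `shape_mismatches` are hypothetical helpers, not part of the M3Net code.

```python
from typing import List, Mapping, Tuple

# Assumption: this LayerNorm has embed_dim channels (512 in M3Net-SwinB_384.pth,
# 384 in the default model, according to the error log in this thread).
EMBED_DIM_KEY = "decoder.mixatt1.norm2.weight"


def load_checkpoint(path: str) -> Mapping:
    """Load a saved state_dict onto the CPU.

    torch is imported here so the pure shape helpers below can be used without it.
    """
    import torch
    return torch.load(path, map_location="cpu")


def infer_embed_dim(state_dict: Mapping, key: str = EMBED_DIM_KEY) -> int:
    """Return the first dimension of `key`, assumed to equal the model's embed_dim."""
    return int(state_dict[key].shape[0])


def shape_mismatches(model_sd: Mapping, ckpt_sd: Mapping) -> List[Tuple[str, tuple, tuple]]:
    """List (name, checkpoint_shape, model_shape) for every shared key whose shapes differ."""
    out = []
    for name, param in model_sd.items():
        if name in ckpt_sd and tuple(ckpt_sd[name].shape) != tuple(param.shape):
            out.append((name, tuple(ckpt_sd[name].shape), tuple(param.shape)))
    return out
```

For example, `infer_embed_dim(load_checkpoint('/content/M3Net/pretrained_model/M3Net-SwinB_384.pth'))` should, going by the log, return 512. After building M3Net with that embed_dim, an empty `shape_mismatches(model.state_dict(), ckpt)` is a cheap check before calling `model.load_state_dict(ckpt)`.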

anilsathyan7 commented 1 year ago

It is working, thank you!

hungtooc commented 11 months ago

Hi, could you share a Drive link for the EfficientNet backbone weights?

Ayews commented 11 months ago

Unfortunately, we are unable to provide pre-trained weights for it, but you can try training it yourself; EfficientNet is already implemented in our code. If you want to use another EfficientNet variant, modify the input channel size of self.proj in M3Net.py and specify the input image size you want when instantiating the model.
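Choosing the new input channel size for self.proj requires knowing how wide each EfficientNet stage is. The sketch below computes per-stage output widths for b0 to b8 from the compound-scaling width coefficients and the divisor-8 rounding used by the reference EfficientNet implementation. `round_filters` mirrors the helper of the same name in the EfficientNet-PyTorch package; `stage_channels` is a hypothetical wrapper. This thread does not say which stage (or stages) self.proj consumes, so check M3Net.py before picking a number.

```python
from typing import Dict, List

# Width coefficients from EfficientNet compound scaling (b0..b8).
WIDTH_COEFFICIENT: Dict[str, float] = {
    "b0": 1.0, "b1": 1.0, "b2": 1.1, "b3": 1.2, "b4": 1.4,
    "b5": 1.6, "b6": 1.8, "b7": 2.0, "b8": 2.2,
}

# Output channels of the seven MBConv stages plus the 1x1 head conv in EfficientNet-b0.
B0_CHANNELS: List[int] = [16, 24, 40, 80, 112, 192, 320, 1280]


def round_filters(filters: int, width: float, divisor: int = 8) -> int:
    """Scale a channel count by `width` and round it to a multiple of `divisor`,
    bumping up one step if rounding would lose more than 10% of the scaled width."""
    scaled = filters * width
    new = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
    if new < 0.9 * scaled:
        new += divisor
    return int(new)


def stage_channels(variant: str) -> List[int]:
    """Per-stage output widths (last entry: head conv) for an EfficientNet variant like 'b7'."""
    width = WIDTH_COEFFICIENT[variant]
    return [round_filters(c, width) for c in B0_CHANNELS]
```

For example, `stage_channels('b7')` gives `[32, 48, 80, 160, 224, 384, 640, 2560]`. These are exactly twice the b0 widths, because b7's width coefficient is 2.0 and every doubled value is already a multiple of 8. For b2 (coefficient 1.1), the rounding matters: the stages become `[16, 24, 48, 88, 120, 208, 352, 1408]`.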

Alternatively, tell me which pre-trained variant you need (EfficientNet comes in several sizes, b0-b8), and I will see if I can find time to train one.

hungtooc commented 11 months ago

Hi @Ayews. If possible, I would like EfficientNet-b7; it would save me a lot of research time!

Ayews commented 11 months ago

M3Net-Eb7_224 preds

Note that the pre-trained weights we use are adv-efficientnet-b7-4652b6dd.pth, and the input size is 224x224. Training takes about 12 hours on an RTX 3090; when the input is enlarged to 384x384, the training time increases to about 40 hours.
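The arithmetic below is a quick sanity check on those two training times, using only the numbers quoted above (it is not a measurement). Going from 224x224 to 384x384 multiplies the pixel count by about 2.94, while the reported training time grows by about 3.3x:

```latex
\left(\frac{384}{224}\right)^{2} = \left(\frac{12}{7}\right)^{2} = \frac{144}{49} \approx 2.94,
\qquad
\frac{40\,\mathrm{h}}{12\,\mathrm{h}} \approx 3.33
```

So the cost grows somewhat faster than the pixel count. One plausible reason is that the global-attention blocks in the decoder scale worse than linearly with the number of tokens, but this is an assumption and was not verified in this thread.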

The evaluation results are shown below (evaluation log dated 2023-12-19 10:36:26.414899; method: M3Net).

Dataset     mae    maxf   avgf   adpf   maxe   avge   adpe   sm     wfm
DUT-O       0.057  0.818  0.796  0.794  0.894  0.883  0.894  0.857  0.784
DUTS-TE     0.035  0.890  0.868  0.855  0.941  0.933  0.933  0.902  0.856
ECSSD       0.028  0.949  0.933  0.935  0.967  0.959  0.963  0.938  0.927
HKU-IS      0.025  0.940  0.921  0.919  0.968  0.960  0.966  0.931  0.915
PASCAL-S    0.060  0.866  0.846  0.846  0.913  0.904  0.910  0.869  0.830
SOD         0.080  0.851  0.842  0.841  0.878  0.856  0.852  0.828  0.799
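In the usual salient-object-detection convention, the mae column is the mean absolute difference between the predicted saliency map (scaled to [0, 1]) and the binary ground-truth mask, averaged over the images of each dataset. The evaluation script that produced these numbers is not shown in the thread, so the following is only a minimal sketch of that standard definition, not the tool's code. It uses plain nested lists so it runs without NumPy, and `mae` and `dataset_mae` are illustrative names.

```python
from typing import Iterable, Sequence, Tuple

Grid = Sequence[Sequence[float]]


def mae(pred: Grid, gt: Grid) -> float:
    """Per-image MAE between a saliency map in [0, 1] and a same-sized binary mask."""
    if len(pred) != len(gt) or any(len(p) != len(g) for p, g in zip(pred, gt)):
        raise ValueError("prediction and ground truth must have the same shape")
    total = sum(abs(p - g) for prow, grow in zip(pred, gt) for p, g in zip(prow, grow))
    count = sum(len(row) for row in gt)
    return total / count


def dataset_mae(pairs: Iterable[Tuple[Grid, Grid]]) -> float:
    """Average the per-image MAE over a dataset, as SOD benchmarks typically report it."""
    scores = [mae(pred, gt) for pred, gt in pairs]
    return sum(scores) / len(scores)
```

Averaging per image, rather than pooling all pixels, keeps large images from dominating the score. This is the common benchmark practice, though whether this particular tool does the same is an assumption.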