csguoh / MambaIR

[ECCV2024] An official PyTorch implementation of the paper "MambaIR: A simple baseline for image restoration with state-space model".
Apache License 2.0

Lightweight models broken as of 15 October #69

Closed: umbertov closed this issue 6 days ago

umbertov commented 1 week ago

As of 15 October (commits 06dc6cdd2fd87df0c4462603daa6bb6d1c43e7b3 and c308c30f1e8d81153547378012e61ab86d7e2ef4), the architecture of the lightweight models is incompatible with the one used by the Google Drive checkpoints.

Trying to load the model gives errors like:

Missing key(s) in state_dict: "layers.0.residual_group.blocks.0.conv_blk.cab.1.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.1.bias", "layers.0.residual_group.blocks.0.conv_blk.cab.3.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.3.bias", "layers.0.residual_group.blocks.0.conv_blk.cab.4.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.4.bias", "layers.0.residual_group.blocks.0.conv_blk.cab.5.attention.1.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.5.attention.1.bias", "layers.0.residual_group.blocks.0.conv_blk.cab.5.attention.3.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.5.attention.3.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.1.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.1.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.3.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.3.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.4.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.4.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.5.attention.1.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.5.attention.1.bias", ...[omitted]

Unexpected key(s) in state_dict: "layers.0.residual_group.blocks.0.conv_blk.cab.2.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.2.bias", "layers.0.residual_group.blocks.0.conv_blk.cab.3.attention.1.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.3.attention.1.bias", "layers.0.residual_group.blocks.0.conv_blk.cab.3.attention.3.weight", "layers.0.residual_group.blocks.0.conv_blk.cab.3.attention.3.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.2.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.2.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.3.attention.1.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.3.attention.1.bias", "layers.0.residual_group.blocks.1.conv_blk.cab.3.attention.3.weight", "layers.0.residual_group.blocks.1.conv_blk.cab.3.attention.3.bias", "layers.0.residual_group.blocks.2.conv_blk.cab.2.weight", "layers.0.residual_group.blocks.2.conv_blk.cab.2.bias", "layers.0.residual_group.blocks.2.conv_blk.cab.3.attention.1.weight", "layers.0.residual_group.blocks.2.conv_blk.cab.3.attention.1.bias", "layers.0.residual_group.blocks.2.conv_blk.cab.3.attention.3.weight", "layers.0.residual_group.blocks.2.conv_blk.cab.3.attention.3.bias", "layers.0.residual_group.blocks.3.conv_blk.cab.2.weight", "layers.0.residual_group.blocks.3.conv_blk.cab.2.bias", "layers.0.residual_group.blocks.3.conv_blk.cab.3.attention.1.weight", "layers.0.residual_group.blocks.3.conv_blk.cab.3.attention.1.bias", "layers.0.residual_group.blocks.3.conv_blk.cab.3.attention.3.weight", "layers.0.residual_group.blocks.3.conv_blk.cab.3.attention.3.bias", "layers.0.residual_group.blocks.4.conv_blk.cab.2.weight", "layers.0.residual_group.blocks.4.conv_blk.cab.2.bias", "layers.0.residual_group.blocks.4.conv_blk.cab.3.attention.1.weight", "layers.0.residual_group.blocks.4.conv_blk.cab.3.attention.1.bias", "layers.0.residual_group.blocks.4.conv_blk.cab.3.attention.3.weight", "layers.0.residual_group.blocks.4.conv_blk.cab.3.attention.3.bias", "layers.0.residual_group.blocks.5.conv_blk.cab.2.weight", "layers.0.residual_group.blocks.5.conv_blk.cab.2.bias", "layers.0.residual_group.blocks.5.conv_blk.cab.3.attention.1.weight", "layers.0.residual_group.blocks.5.conv_blk.cab.3.attention.1.bias", ...[omitted]
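For anyone hitting this, the mismatch can be reproduced by comparing the checkpoint keys against the current architecture's state_dict. A minimal sketch: the `type` name, `upscale` value and checkpoint filename are assumptions (they are not shown in the config diff below), the remaining network_g options are taken from the current test_MambaIR_lightSR_x4.yml.

```python
import torch
from basicsr.archs import build_network  # BasicSR registry helper

# network_g options from the current options/test/test_MambaIR_lightSR_x4.yml;
# 'type' and 'upscale' sit above the lines shown in the diff below and are assumed here.
net_opt = dict(type='MambaIR', upscale=4, in_chans=3, img_size=64, img_range=1.,
               d_state=16, depths=[6, 6, 6, 6], embed_dim=60, mlp_ratio=1.5,
               upsampler='pixelshuffledirect', resi_connection='1conv')
model = build_network(net_opt)

# Placeholder filename for the lightweight x4 weights downloaded from Google Drive;
# BasicSR checkpoints usually nest the weights under a 'params' key.
ckpt = torch.load('MambaIR_lightSR_x4.pth', map_location='cpu')
state_dict = ckpt.get('params', ckpt)

# Compare the two key sets directly to see which parameters each side expects.
model_keys = set(model.state_dict())
ckpt_keys = set(state_dict)
print('missing from checkpoint :', sorted(model_keys - ckpt_keys)[:4])
print('unexpected in checkpoint:', sorted(ckpt_keys - model_keys)[:4])
```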

The solution was to revert the changes from the aforementioned commits (shown here for the x4 model):

diff --git a/options/test/test_MambaIR_lightSR_x4.yml b/options/test/test_MambaIR_lightSR_x4.yml
index ea60385..2477d45 100644
--- a/options/test/test_MambaIR_lightSR_x4.yml
+++ b/options/test/test_MambaIR_lightSR_x4.yml
@@ -60,10 +60,10 @@ network_g:
   in_chans: 3
   img_size: 64
   img_range: 1.
-  d_state: 16
+  d_state: 10
   depths: [6, 6, 6, 6]
   embed_dim: 60
-  mlp_ratio: 1.5
+  mlp_ratio: 1.2
   upsampler: 'pixelshuffledirect'
   resi_connection: '1conv'

and

diff --git a/basicsr/archs/mambair_arch.py b/basicsr/archs/mambair_arch.py
index da44128..43e503c 100644
--- a/basicsr/archs/mambair_arch.py
+++ b/basicsr/archs/mambair_arch.py
@@ -40,14 +40,23 @@ class ChannelAttention(nn.Module):
 class CAB(nn.Module):
     def __init__(self, num_feat, is_light_sr= False, compress_ratio=3,squeeze_factor=30):
         super(CAB, self).__init__()
-        if is_light_sr: # a larger compression ratio is used for light-SR
-            compress_ratio = 6
-        self.cab = nn.Sequential(
-            nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
-            nn.GELU(),
-            nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
-            ChannelAttention(num_feat, squeeze_factor)
-        )
+        if is_light_sr: # we use dilated-conv & DWConv for lightSR for a large ERF
+            compress_ratio = 2
+            self.cab = nn.Sequential(
+                nn.Conv2d(num_feat, num_feat // compress_ratio, 1, 1, 0),
+                nn.Conv2d(num_feat//compress_ratio, num_feat // compress_ratio, 3, 1, 1,groups=num_feat//compress_ratio),
+                nn.GELU(),
+                nn.Conv2d(num_feat // compress_ratio, num_feat, 1, 1, 0),
+                nn.Conv2d(num_feat, num_feat, 3,1,padding=2,groups=num_feat,dilation=2),
+                ChannelAttention(num_feat, squeeze_factor)
+            )
+        else:
+            self.cab = nn.Sequential(
+                nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
+                nn.GELU(),
+                nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
+                ChannelAttention(num_feat, squeeze_factor)
+            )

     def forward(self, x):
         return self.cab(x)

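For reference, the key-index shift in the error above follows directly from the two extra convolutions in the new light-SR branch: in the checkpoints' CAB the second conv sits at index 2 and ChannelAttention at index 3, while in the new light-SR CAB the parameterised layers sit at indices 0, 1, 3, 4 with ChannelAttention at index 5; hence the missing `cab.4.*`/`cab.5.attention.*` and unexpected `cab.2.*`/`cab.3.attention.*` keys. A standalone sketch that reproduces the two key layouts (the ChannelAttention stand-in is reconstructed from the key names in the error message, not copied from the repo):

```python
import torch.nn as nn

num_feat, squeeze_factor = 60, 30  # embed_dim=60, squeeze_factor=30 as in the configs

# Stand-in for ChannelAttention: the parameter indices (attention.1, attention.3)
# match the key names in the error; the exact layer choice is an assumption.
class ChannelAttention(nn.Module):
    def __init__(self, num_feat, squeeze_factor):
        super().__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1),
            nn.Sigmoid())

    def forward(self, x):
        return x * self.attention(x)

# CAB body the Google Drive lightweight checkpoints were trained with
# (compress_ratio = 6 for light-SR in the pre-October code).
old_cab = nn.Sequential(
    nn.Conv2d(num_feat, num_feat // 6, 3, 1, 1),
    nn.GELU(),
    nn.Conv2d(num_feat // 6, num_feat, 3, 1, 1),
    ChannelAttention(num_feat, squeeze_factor))

# CAB body in the current light-SR code path (1x1 conv, DWConv, 1x1 conv,
# dilated DWConv, channel attention), as shown in the diff above.
new_cab = nn.Sequential(
    nn.Conv2d(num_feat, num_feat // 2, 1, 1, 0),
    nn.Conv2d(num_feat // 2, num_feat // 2, 3, 1, 1, groups=num_feat // 2),
    nn.GELU(),
    nn.Conv2d(num_feat // 2, num_feat, 1, 1, 0),
    nn.Conv2d(num_feat, num_feat, 3, 1, padding=2, groups=num_feat, dilation=2),
    ChannelAttention(num_feat, squeeze_factor))

print(sorted(old_cab.state_dict()))  # '0.*', '2.*', '3.attention.1.*', '3.attention.3.*'
print(sorted(new_cab.state_dict()))  # '0.*', '1.*', '3.*', '4.*', '5.attention.1.*', '5.attention.3.*'
```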
Are you planning to release checkpoints for the updated architecture, or will you revert the changes in the repository?

csguoh commented 6 days ago

Hi, thanks for the reminder. We have updated the checkpoint download link in the README to match the architecture in the existing code. This lightSR version is smaller than the previous one, and its performance is consistent with our latest arXiv version.

umbertov commented 6 days ago

Hey, thanks for the fast reaction! Glad to hear that.