zsyOAOA / ResShift

ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting (NeurIPS@2023 Spotlight, TPAMI@2024)

Critical Bug in 'calculate_mask' Function of SwinTransformer Block #81

Closed: linYDTHU closed this issue 4 months ago

linYDTHU commented 4 months ago

Hi @zsyOAOA ,

Thank you for your great work! However, I recently discovered a critical bug in the implementation of the calculate_mask function in the SwinTransformer block.

The potential issue is in lines 214 to 228 of model/swin_transformer.py:

def calculate_mask(self, x_size):
    # calculate attention mask for SW-MSA
    H, W = x_size
    img_mask = torch.zeros((1, 1, H, W))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

I believe the correct way to generate the img_mask should be:

def calculate_mask(self, x_size):
    # calculate attention mask for SW-MSA
    H, W = x_size
    img_mask = torch.zeros((1, 1, H, W))  # 1 1 H W
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, :, h, w] = cnt
            cnt += 1

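The only difference is the indexing of img_mask inside the nested loop. Since img_mask has shape (1, 1, H, W), the slices must be applied to the last two dimensions, as in img_mask[:, :, h, w]. The original img_mask[:, h, w, :] applies h to the singleton second dimension and w to the height dimension, so the mask is generated incorrectly.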
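To make the indexing difference concrete, here is a minimal sketch that fills a (1, 1, H, W) mask both ways. The helper name fill_region_labels and the sizes H = W = 8, window_size = 4, shift_size = 2 are illustrative assumptions, not taken from the repo; w_slices reuses the same slices because the windows are assumed square.

```python
import torch


def fill_region_labels(H, W, window_size, shift_size, buggy):
    """Label the 3x3 shifted-window regions on a (1, 1, H, W) mask.

    Hypothetical helper: buggy=True uses img_mask[:, h, w, :] (original indexing),
    buggy=False uses img_mask[:, :, h, w] (proposed fix). Returns the H x W map.
    """
    img_mask = torch.zeros((1, 1, H, W))  # 1 1 H W
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = h_slices  # square windows: same bands along both axes
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            if buggy:
                # h lands on the size-1 dim (only the last band is non-empty there)
                # and w lands on the H axis, so columns are never distinguished.
                img_mask[:, h, w, :] = cnt
            else:
                img_mask[:, :, h, w] = cnt
            cnt += 1
    return img_mask[0, 0]


buggy = fill_region_labels(8, 8, window_size=4, shift_size=2, buggy=True)
fixed = fill_region_labels(8, 8, window_size=4, shift_size=2, buggy=False)
print(torch.unique(buggy).tolist())  # [6.0, 7.0, 8.0]
print(torch.unique(fixed).tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```

With the original indexing only three labels survive and every row is constant, so pixels are separated by row band only; the fixed indexing yields the intended nine regions.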

I'm looking forward to hearing your thoughts and a potential fix.

Best regards, Donglin

zsyOAOA commented 4 months ago

I will check it as soon as possible. However, I suggest you report this issue in the SwinIR or SwinTransformer repo, since I copied the code from SwinIR. @linYDTHU

linYDTHU commented 4 months ago

The official SwinTransformer repo implements this in a slightly different but correct way:

if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

They create a tensor with shape (1, H, W, 1) instead of (1, 1, H, W), so img_mask[:, h, w, :] indexes the height and width dimensions as intended and the resulting attention mask is correct. SwinIR also implements it correctly:

def calculate_mask(self, x_size):
    # calculate attention mask for SW-MSA
    H, W = x_size
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask
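For an end-to-end check, the following is a self-contained sketch of the SwinIR-style mask computation written as plain functions instead of methods. The function names and the 8x8 / 16x16 sizes are illustrative assumptions, and window_partition is assumed to follow the usual Swin (B, H, W, C) convention. It also checks that the proposed (1, 1, H, W) fix with img_mask[:, :, h, w] holds the same labels as the SwinIR layout once permuted to (1, H, W, 1).

```python
import torch


def window_partition(x, window_size):
    """Split (B, H, W, C) into (B * num_windows, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def shift_slices(window_size, shift_size):
    """The three bands along one axis of the cyclically shifted image."""
    return (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))


def region_labels(H, W, window_size, shift_size):
    """SwinIR layout: labels stored in a (1, H, W, 1) tensor."""
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    cnt = 0
    for h in shift_slices(window_size, shift_size):
        for w in shift_slices(window_size, shift_size):
            img_mask[:, h, w, :] = cnt
            cnt += 1
    return img_mask


def region_labels_channel_first(H, W, window_size, shift_size):
    """Proposed fix: labels stored in a (1, 1, H, W) tensor."""
    img_mask = torch.zeros((1, 1, H, W))  # 1 1 H W
    cnt = 0
    for h in shift_slices(window_size, shift_size):
        for w in shift_slices(window_size, shift_size):
            img_mask[:, :, h, w] = cnt
            cnt += 1
    return img_mask


def calculate_mask(H, W, window_size, shift_size):
    """Additive attention mask of shape (num_windows, ws * ws, ws * ws)."""
    img_mask = region_labels(H, W, window_size, shift_size)
    mask_windows = window_partition(img_mask, window_size)  # nW, ws, ws, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # Token pairs from different regions get -100; same-region pairs get 0.
    return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))


mask = calculate_mask(8, 8, window_size=4, shift_size=2)
print(tuple(mask.shape))  # (4, 16, 16)
fixed = region_labels_channel_first(16, 16, 4, 2).permute(0, 2, 3, 1)
print(torch.equal(fixed, region_labels(16, 16, 4, 2)))  # True
```

In the 8x8 example only the top-left window is mask-free; the other three straddle region boundaries and receive -100 entries.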

@zsyOAOA

zsyOAOA commented 4 months ago

Thanks for your report. I will try to retrain a model after fixing this bug, which may further improve the performance of ResShift. @linYDTHU

linYDTHU commented 4 months ago

That would be great. Looking forward to your future updates!

zsyOAOA commented 4 months ago

I have retrained the SR model after fixing this bug, and there is only a small difference on the synthetic dataset: PSNR 25.02/25.05, SSIM 0.6833/0.6837, LPIPS 0.2076/0.2096.

Hence, I will keep the original codebase. Thanks again for your report. @linYDTHU

linYDTHU commented 4 months ago

It's great to hear that the performance remains similar after fixing the bug. Could you please share the checkpoint of your retrained model at your earliest convenience? Thank you.

Feynman1999 commented 4 months ago

This is a noteworthy phenomenon. I previously tried removing the attention mask in SwinIR and found that it had little effect on the super-resolution PSNR. I think this may be due to the small window size (e.g., 16, 24, 32): after the shift, the mask only affects windows at the image edges.
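This edge effect can be quantified by counting windows. After the cyclic shift, only windows in the last row or last column of the window grid mix pixels from different regions, so for an H x W input split into (H/ws) x (W/ws) windows, (H/ws) + (W/ws) - 1 windows are affected. The sketch below counts them directly; the 256x256 input, the window sizes 8 and 32, and shift = ws / 2 are hypothetical choices, not ResShift's settings, and square windows are assumed.

```python
import torch


def window_partition(x, window_size):
    """Split (B, H, W, C) into (B * num_windows, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def masked_window_count(H, W, window_size, shift_size):
    """Return (windows that contain any masked token pair, total windows)."""
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    labels = window_partition(img_mask, window_size).view(-1, window_size * window_size)
    # A window has masked pairs iff its tokens do not all share one region label;
    # this avoids materializing the (nW, ws*ws, ws*ws) attention mask.
    mixed = (labels != labels[:, :1]).any(dim=1)
    return int(mixed.sum()), labels.shape[0]


for window_size in (8, 32):
    print(window_size, masked_window_count(256, 256, window_size, window_size // 2))
# 8 (63, 1024)
# 32 (15, 64)
```

For this hypothetical input, about 6% of windows (63/1024) are touched by the mask at window size 8 versus about 23% (15/64) at window size 32. That fits the intuition that larger windows make the mask matter more, although a window count alone says nothing about the effect on PSNR.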

linYDTHU commented 4 months ago

@Feynman1999 Thanks for your insightful comment! Larger window sizes may make a bigger difference with different masks, especially in high-resolution settings.

Feynman1999 commented 4 months ago

Yep. In my understanding, the mask matters more for high-level tasks (segmentation, detection) than for low-level tasks.

zsyOAOA commented 4 months ago

Hi, you can access the retrained SR model via Google Drive. @linYDTHU

linYDTHU commented 4 months ago

Thanks for sharing this checkpoint. It will greatly benefit our work as well as the research community.