Closed: linYDTHU closed this issue 4 months ago.
I will have a check ASAP. However, I suggest you report this issue in the SwinIR or SwinTransformer repo, as I copied the code from SwinIR. @linYDTHU
The official SwinTransformer repo implements this in a slightly different but correct way:
```python
if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
```
They create a tensor with shape (1, H, W, 1) instead of (1, 1, H, W), which ensures the correct dimensions for the attention mask.
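To make this concrete, here is a small standalone toy example (hypothetical sizes, not taken from either repo) of what the nested loop produces: the three h/w slices label a 3x3 grid of regions, and pixels with different labels must not attend to each other after the cyclic shift.

```python
import torch

# Toy example: label the regions exactly as the loop above does.
H = W = 8
window_size, shift_size = 4, 2

img_mask = torch.zeros((1, H, W, 1))  # channel-last, as in the official code
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = h_slices
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

print(img_mask[0, :, :, 0].int())
# Rows 0-3 / 4-5 / 6-7 and the matching columns form a 3x3 grid of labels 0-8.
```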
SwinIR also implements it correctly:
```python
def calculate_mask(self, x_size):
    # calculate attention mask for SW-MSA
    H, W = x_size
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask
```
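For completeness, window_partition (used above but not shown) is, as far as I recall, implemented along these lines in the official Swin Transformer and SwinIR code; treat this as a sketch rather than a verbatim copy:

```python
def window_partition(x, window_size):
    """Split a channel-last tensor (B, H, W, C) into non-overlapping windows
    of shape (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
```

This is why the (1, H, W, 1) layout matters: the reshape assumes (B, H, W, C), so a (1, 1, H, W) mask would be partitioned along the wrong axes (or fail outright). The subsequent unsqueeze(1) - unsqueeze(2) then compares every pair of pixel labels inside a window, and nonzero differences are filled with -100 so those attention logits are suppressed.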
@zsyOAOA
Thanks for your report. I will try to retrain a model after fixing this bug, which may further improve the performance of ResShift. @linYDTHU
That would be great. Looking forward to your future updates!
I have re-trained the SR model after fixing this bug, and there is only a small difference on the synthetic dataset: PSNR: 25.02/25.05, SSIM: 0.6833/0.6837, LPIPS: 0.2076/0.2096.
Hence, I will retain the original codebase. Thanks for your report again. @linYDTHU
It's great to hear that the performance remains similar after fixing the bug. Could you please share the checkpoint of your retrained model at your earliest convenience? Thank you.
This is a noteworthy phenomenon. I have previously tried removing the attention mask in SwinIR and found that it has little effect on the super-resolution PSNR metric. I think it may be due to the small window size (e.g. 16, 24, 32), which means the shift only affects the image edges.
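For intuition, here is a quick standalone check (hypothetical feature-map and window sizes) of how many windows actually end up with a nontrivial mask:

```python
import torch

# Quick check (hypothetical sizes): how many shifted windows get a nontrivial
# attention mask, i.e. contain more than one region label?
H = W = 64
window_size, shift_size = 8, 4

img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = h_slices
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# Partition into non-overlapping windows and count the "mixed" ones.
windows = img_mask.view(1, H // window_size, window_size,
                        W // window_size, window_size, 1)
windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
mixed = sum(int(w.unique().numel() > 1) for w in windows)
print(f"{mixed} / {windows.shape[0]} windows need a mask")  # 15 / 64 here
```

Only the last row and column of windows are affected, which matches the observation that the mask mostly matters at the image edges.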
@Feynman1999 Thanks for your insightful comment! Larger window sizes may make a bigger difference with different masks, especially in high-resolution settings.
Yep. And in my understanding, it has a greater impact on high-level tasks (segmentation, detection) than on low-level tasks.
Hi, you can access the re-trained model for SR via Google Drive. @linYDTHU
Thanks for sharing this checkpoint. It will greatly benefit our work as well as the research community.
Hi @zsyOAOA,
Thank you for your great work! However, I recently discovered a critical bug in the implementation of the calculate_mask function in the SwinTransformer block.
The potential issue is located at lines 214 to 228 in model/swin_transformer.py. I believe the correct way to generate img_mask is along the lines of the sketch below.
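A minimal sketch of the fix I have in mind, assuming img_mask keeps its (1, 1, H, W) shape and mirroring the attribute names (self.input_resolution, self.window_size, self.shift_size) from the official snippet above:

```python
# Sketch of the proposed fix: img_mask stays channel-first with shape
# (1, 1, H, W), but both spatial dimensions are indexed explicitly.
H, W = self.input_resolution
img_mask = torch.zeros((1, 1, H, W))  # 1, 1, H, W
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, :, h, w] = cnt  # the extra colon keeps h, w on the spatial dims
        cnt += 1
```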
The difference is in the indexing of img_mask within the nested loop: by including the additional colon, i.e. img_mask[:, :, h, w], the mask is correctly generated. I'm looking forward to hearing your thoughts and a potential fix.
Best regards, Donglin