ChristophReich1996 / Swin-Transformer-V2

PyTorch reimplementation of the paper "Swin Transformer V2: Scaling Up Capacity and Resolution" [CVPR 2022].
https://arxiv.org/abs/2111.09883
MIT License

problem about DeformableSwinTransformerBlock #4

Closed nullxjx closed 2 years ago

nullxjx commented 2 years ago

When I feed input data of shape [1, 3, 240, 250], an error is raised at https://github.com/ChristophReich1996/Swin-Transformer-V2/blob/d1b89227ef0045c3ab667a4f2cdea9ec4f240236/swin_transformer_v2/model_parts.py#L574 . The error is listed below:

RuntimeError: The size of tensor a (48) must match the size of tensor b (256) at non-singleton dimension 2

It appears that the shapes of `self.default_grid.repeat_interleave(repeats=offsets.shape[0], dim=0)` and `offsets` are different. I wonder if you have run into the same problem; it would be of great help if you could help me solve it, thank you~

ChristophReich1996 commented 2 years ago

Hi @nullxjx,

I guess there are two issues with your code. First, when you initialize the Swin Transformer V2, you need to specify the input resolution you would like to use. Second, the resolution [1, 3, 240, 250] is invalid: due to the hierarchical architecture of the Swin Transformer, the feature maps need to be divisible by two four times in the optimal case. To be sure what the problem is, please share your full code, if possible.
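
For illustration, here is a minimal sketch of this divisibility constraint (assuming a patch size of 4 and four stages, as in the default Tiny configuration; window-size constraints are ignored here):

def resolution_is_valid(height: int, width: int, patch_size: int = 4, num_stages: int = 4) -> bool:
    # Patch embedding requires the image size to be divisible by the patch size.
    if height % patch_size or width % patch_size:
        return False
    # Each remaining stage halves the feature map, so it must stay even throughout.
    h, w = height // patch_size, width // patch_size
    for _ in range(num_stages - 1):
        if h % 2 or w % 2:
            return False
        h, w = h // 2, w // 2
    return True

print(resolution_is_valid(240, 250))  # False: 250 is not divisible by 4
print(resolution_is_valid(256, 256))  # True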

Cheers Christoph

nullxjx commented 2 years ago

Hi @ChristophReich1996 ,

Here is the thing: I am currently using your code for an image super-resolution task. The difference between your classification task and mine is that I don't have a hierarchical architecture, that is, after every stage the shape of the feature maps remains the same. Your code serves as the backbone in my network architecture.

My full code is available at https://github.com/nullxjx/Swinir-V2.

The above error appears at https://github.com/nullxjx/Swinir-V2/blob/98c33cc9479f13823b6a841bad1eec7ac15018a9/models/utils_for_swinir_v2.py#L619

Please check the code starting from https://github.com/nullxjx/Swinir-V2/blob/98c33cc9479f13823b6a841bad1eec7ac15018a9/models/network_swinir_v2.py#L739

Thank you~ XJX

ChristophReich1996 commented 2 years ago

Hi @nullxjx,

I can't retrace your full code, but I think the error occurs due to an incorrect image shape given to the model's constructor. Also, be aware that the patch size must be compatible with the image shape you want to process. Padding the input image to a resolution of [1, 3, 256, 256] and setting the image shape to (256, 256) would be the easiest fix, I guess. Please refer to the provided example for the correct usage.
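
For example, a minimal padding sketch (assuming an NCHW input tensor no larger than the target; the 256x256 target matches the suggestion above):

import torch
import torch.nn.functional as F

def pad_to(x: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    # Zero-pad the bottom/right of an NCHW tensor up to the target resolution.
    return F.pad(x, (0, target_w - x.shape[-1], 0, target_h - x.shape[-2]))

x = pad_to(torch.rand(1, 3, 240, 250), 256, 256)
print(x.shape)  # torch.Size([1, 3, 256, 256])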

Cheers Christoph

Breeze-Zero commented 2 years ago

Hi @ChristophReich1996 ,

When using the DeformableSwinTransformerBlock, I also found a problem. Here's a quick example: when I wrap the forward pass in `with torch.cuda.amp.autocast():` (with use_deformable_block=True, and after deleting the `: torch.Tensor` annotations from your code, which otherwise cause a half-precision error), it still fails with a dtype mismatch. This is the line of code that finally raises the error: https://github.com/ChristophReich1996/Swin-Transformer-V2/blob/d1b89227ef0045c3ab667a4f2cdea9ec4f240236/swin_transformer_v2/model_parts.py#L579

ChristophReich1996 commented 2 years ago

Hi @834799106, please provide the full error message and the code; otherwise, I can only guess what the bug might be.

Breeze-Zero commented 2 years ago

Hi @ChristophReich1996 ,

I've changed a lot of the code for my own tasks now, but the core is pretty much the same. Here's the traceback (the format might be a little messy):

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_3054492/3751280697.py in <module>
      5 with torch.cuda.amp.autocast():
----> 6     d = model(a)

/model_parts3d.py in forward(self, input)
    586         # Apply sampling grid
    587         input_resampled = F.grid_sample(input=input, grid=offset_grid.clip(min=-1, max=1),
--> 588                                         mode="bilinear", align_corners=True, padding_mode="reflection")
    589         # Reshape resampled tensor again to [batch size, channels, spatial, height, width]
    590         input_resampled = input_resampled.view(batch_size, channels, spatial, height, width)

~/.conda/envs/X/lib/python3.7/site-packages/torch/nn/functional.py in grid_sample(input, grid, mode, padding_mode, align_corners)
   3834         align_corners = False
   3835
-> 3836     return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
   3837
   3838

RuntimeError: grid_sampler(): expected input and grid to have same dtype, but input has c10::Half and grid has float

ChristophReich1996 commented 2 years ago

I think simply casting the default grid to the same data type as the input should get the job done (I thought autocast would do this automatically):

# Cast offset grid to input data type
if input.dtype != self.default_grid.dtype:
    self.default_grid = self.default_grid.type(input.dtype)

This fix works for me when performing the following forward and backward pass:

from typing import List
import torch
# import path assumed from this repo's package layout (see example.py)
from swin_transformer_v2 import SwinTransformerV2, swin_transformer_v2_t

input = torch.rand(2, 3, 256, 256).cuda().half()
swin_transformer: SwinTransformerV2 = swin_transformer_v2_t(in_channels=3,
                                                              window_size=8,
                                                              input_resolution=(256, 256),
                                                              sequential_self_attention=False,
                                                              use_checkpoint=False,
                                                              use_deformable_block=True)
swin_transformer.cuda()
# Perform forward pass
with torch.cuda.amp.autocast():
    features: List[torch.Tensor] = swin_transformer(input)
# Print shape of features
for feature in features:
    print(feature.shape)
features[-1].sum().backward()

Related PyTorch issue.

Breeze-Zero commented 2 years ago

After modification, it can work normally. Thanks!

nullxjx commented 2 years ago

Hi, I wonder if the problem is here: when I define the model, I set input_resolution to [48, 48], but when I test it on an input of size [240, 250], the error occurs. The problem is that I train the model on 48x48 images but test it on arbitrary images whose sizes are unknown in advance. How should I handle this?

My full code is available at https://github.com/nullxjx/Swinir-V2; you can run network_swinir_v2.py to see the error.

ChristophReich1996 commented 2 years ago

Hi @nullxjx, yes, as mentioned in my previous response, this seems to be the problem. The model needs to know the input resolution. If you want to change the resolution, use the respective method (update_resolution), as explained in the README.
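
For the train-on-48x48/test-on-arbitrary-sizes setting, one possible pattern is a sketch like the following (the update_resolution signature is the one used in the example further down this thread; padding to a multiple of 32 and window_size=8 are assumptions matching the default configuration, adjust for your variant):

import torch.nn.functional as F

def forward_arbitrary(model, image, multiple: int = 32, window_size: int = 8):
    # Pad the NCHW test image up to the next valid multiple, then tell the
    # model about the new resolution before running the forward pass.
    h, w = image.shape[-2:]
    new_h = -(-h // multiple) * multiple  # ceil to the next multiple
    new_w = -(-w // multiple) * multiple
    image = F.pad(image, (0, new_w - w, 0, new_h - h))
    model.update_resolution(new_window_size=window_size, new_input_resolution=(new_h, new_w))
    return model(image)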

nullxjx commented 2 years ago

Hi, I solved the above-mentioned problem, but I found a new one: when I put both the input and the model on a CUDA device, as in the following code from https://github.com/ChristophReich1996/Swin-Transformer-V2/blob/75e5ac9ebb177f5b0accca31460ced323fa7b0e1/example.py#L10

def main() -> None:
    # Make input tensor and init Swin Transformer V2, for the custom deformable version set use_deformable_block=True
    input = torch.rand(2, 3, 256, 256).cuda(0)
    swin_transformer: SwinTransformerV2 = swin_transformer_v2_t(in_channels=3,
                                                                window_size=8,
                                                                input_resolution=(256, 256),
                                                                sequential_self_attention=False,
                                                                use_checkpoint=False).cuda(0)
    # Perform forward pass
    output = swin_transformer(input)
    print(output.shape)

    # Update the resolution and window size of the Swin Transformer V2 and init new input
    swin_transformer.update_resolution(new_window_size=16, new_input_resolution=(512, 512))
    input = torch.rand(2, 3, 512, 512).cuda(0)
    # Perform forward pass
    output = swin_transformer(input)
    print(output.shape)

the following error occurs:

Traceback (most recent call last):
  File "C:/Users/XJX/Desktop/Swin-Transformer-V2/example.py", line 29, in <module>
    main()
  File "C:/Users/XJX/Desktop/Swin-Transformer-V2/example.py", line 24, in main
    output = swin_transformer(input)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model.py", line 107, in forward
    output: torch.Tensor = stage(output)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 756, in forward
    output: torch.Tensor = block(output)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 449, in forward
    output_attention: torch.Tensor = self.window_attention(output_patches, mask=self.attention_mask)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 304, in forward
    mask=mask)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 211, in __self_attention
    attention_map: torch.Tensor = attention_map + self.__get_relative_positional_encodings()
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 180, in __get_relative_positional_encodings
    relative_position_bias: torch.Tensor = self.meta_network(self.relative_coordinates_log)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\container.py", line 117, in forward
    input = module(input)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for addmm)

The problem is that at https://github.com/ChristophReich1996/Swin-Transformer-V2/blob/75e5ac9ebb177f5b0accca31460ced323fa7b0e1/swin_transformer_v2/model_parts.py#L180 the variable self.relative_coordinates_log is not on the CUDA device; all tensors should be placed on the same device.
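
One common way to avoid this class of device bug (a sketch, not necessarily the fix applied in the repo) is to register such tensors as buffers, so that .cuda() and .to() move them together with the parameters:

import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        coords = torch.rand(64, 64, 2)  # hypothetical precomputed coordinates
        # A buffer moves with the module on .cuda()/.to(), unlike a plain tensor attribute.
        self.register_buffer("relative_coordinates_log", coords)

module = Example().cuda(0)
print(module.relative_coordinates_log.device)  # cuda:0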

ChristophReich1996 commented 2 years ago

Fixed.

nullxjx commented 2 years ago

There is a new problem with your recent fix at https://github.com/ChristophReich1996/Swin-Transformer-V2/blob/d54e0cfc480afad1ba1c7d2a818849e5658e5e6e/swin_transformer_v2/model_parts.py#L537:

D:\anaconda3\envs\swin_v2\python.exe C:/Users/XJX/Desktop/Swin-Transformer-V2/example.py
Traceback (most recent call last):
  File "C:/Users/XJX/Desktop/Swin-Transformer-V2/example.py", line 30, in <module>
    main()
  File "C:/Users/XJX/Desktop/Swin-Transformer-V2/example.py", line 16, in main
    use_deformable_block=True).cuda(0)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model.py", line 137, in swin_transformer_v2_t
    **kwargs)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model.py", line 77, in __init__
    use_deformable_block=use_deformable_block and (index > 0)
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 732, in __init__
    for index in range(depth)])
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 732, in <listcomp>
    for index in range(depth)])
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 521, in __init__
    self.__make_default_offsets()
  File "C:\Users\XJX\Desktop\Swin-Transformer-V2\swin_transformer_v2\model_parts.py", line 537, in __make_default_offsets
    device=self.tau.device)
  File "D:\anaconda3\envs\swin_v2\lib\site-packages\torch\nn\modules\module.py", line 779, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'DeformableSwinTransformerBlock' object has no attribute 'tau'

Process finished with exit code 1
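
The traceback shows the root cause: __make_default_offsets reads self.tau.device before any attribute named tau exists on the block. A minimal sketch of the general pattern (not the author's actual fix) is to assign the attribute a helper depends on before calling that helper in __init__:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        # Define the attribute the helper reads *before* invoking the helper.
        self.tau = nn.Parameter(torch.ones(1))  # shape is illustrative
        self.__make_default_offsets()

    def __make_default_offsets(self):
        # Safe now: self.tau exists, so its device can be queried.
        self.default_grid = torch.zeros(1, device=self.tau.device)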

ChristophReich1996 commented 2 years ago

Also fixed.