ChristophReich1996 / Swin-Transformer-V2

PyTorch reimplementation of the paper "Swin Transformer V2: Scaling Up Capacity and Resolution" [CVPR 2022].
https://arxiv.org/abs/2111.09883
MIT License

Dropping Bias in DeformableSwinBlock #15

Open Matagi1996 opened 6 months ago

Matagi1996 commented 6 months ago

First of all, thank you for this wonderful implementation; it has never been easier to follow a paper line by line in code. This is also how I noticed the following passage in the original paper:

> The input features are first passed through a 5×5 depthwise convolution to capture local features. Then, GELU activation and a 1×1 convolution is adopted to get the 2D offsets. It is also worth noticing that the bias in 1×1 convolution is dropped to alleviate the compulsive shift for all locations.

In your implementation it says:

```python
self.offset_network: nn.Module = nn.Sequential(
    nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=5,
              stride=offset_downscale_factor, padding=3, groups=in_channels, bias=True),
    nn.GELU(),
    nn.Conv2d(in_channels=in_channels, out_channels=2 * self.number_of_heads,
              kernel_size=1, stride=1, padding=0, bias=True)
)
```

In the repo linked by the paper, https://github.com/LeapLabTHU/DAT/blob/main/models/dat_blocks.py:

```python
self.conv_offset = nn.Sequential(
    nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size,
              groups=self.n_group_channels),
    LayerNormProxy(self.n_group_channels),
    nn.GELU(),
    nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
)
```
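To make the difference concrete: the bias term of the final 1×1 convolution is broadcast over all spatial locations, so it acts as one learnable constant offset added everywhere, which I understand is the "compulsive shift for all locations" the paper wants to avoid. A minimal check (illustration only, not code from either repo):

```python
import torch
import torch.nn as nn

# The bias of a 1x1 convolution is broadcast over all spatial locations,
# so it acts as one constant 2D offset added everywhere.
conv = nn.Conv2d(8, 2, kernel_size=1, bias=True)
x = torch.randn(1, 8, 4, 4)
with torch.no_grad():
    # Output minus the bias-free convolution isolates the bias contribution.
    shift = conv(x) - nn.functional.conv2d(x, conv.weight)
# The isolated term equals the bias at every spatial location.
print(torch.allclose(shift, conv.bias.view(1, 2, 1, 1).expand_as(shift)))  # True
```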

Should the bias of the 1×1 convolution not be set to False, or does the change have a practical purpose I am not aware of?
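For reference, here is what the module would look like with only the bias of the final 1×1 convolution dropped, keeping all other arguments as quoted above (a sketch of the proposed change, not a tested patch; the factory function is just for illustration):

```python
import torch
import torch.nn as nn

def make_offset_network(in_channels: int, number_of_heads: int,
                        offset_downscale_factor: int) -> nn.Module:
    # Same structure as self.offset_network above, with bias=False on the
    # final 1x1 convolution as in the DAT reference code.
    return nn.Sequential(
        # 5x5 depthwise convolution capturing local features
        nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                  kernel_size=5, stride=offset_downscale_factor,
                  padding=3, groups=in_channels, bias=True),
        nn.GELU(),
        # 1x1 convolution predicting 2 offsets per head; bias dropped so no
        # learnable constant shift is applied uniformly to all locations
        nn.Conv2d(in_channels=in_channels, out_channels=2 * number_of_heads,
                  kernel_size=1, stride=1, padding=0, bias=False),
    )

# Shape check with toy sizes (padding copied from the snippet above):
net = make_offset_network(in_channels=32, number_of_heads=4,
                          offset_downscale_factor=1)
print(net(torch.randn(1, 32, 8, 8)).shape)  # torch.Size([1, 8, 10, 10])
```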