huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

[BUG] SwinTransformer Padding Backwards in PatchMerge #2284

Closed: collinmccarthy closed this issue 2 weeks ago

collinmccarthy commented 2 weeks ago

Describe the bug

In this line, the padding for H/W is backwards. I found this out by passing in an image size of (648, 888) during validation, but it is also apparent from the torch docs and the code:

```python
from typing import Callable, Optional

import torch.nn as nn


class PatchMerging(nn.Module):
    """ Patch Merging Layer.
    """

    def __init__(
            self,
            dim: int,
            out_dim: Optional[int] = None,
            norm_layer: Callable = nn.LayerNorm,
    ):
        """
        Args:
            dim: Number of input channels.
            out_dim: Number of output channels (or 2 * dim if None)
            norm_layer: Normalization layer.
        """
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim or 2 * dim
        self.norm = norm_layer(4 * dim)
        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)

    def forward(self, x):
        B, H, W, C = x.shape

        pad_values = (0, 0, 0, W % 2, 0, H % 2)  # Originally (0, 0, 0, H % 2, 0, W % 2) which is wrong
        x = nn.functional.pad(x, pad_values)
        _, H, W, _ = x.shape

        # Group each 2x2 spatial neighborhood and concatenate along the channel dim
        x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
        x = self.norm(x)
        x = self.reduction(x)
        return x
```
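
As a quick sanity check (the dim and feature-map shape below are just illustrative, not taken from the (648, 888) run), the corrected layer handles an odd spatial dimension:

```python
import torch

merge = PatchMerging(dim=96)
x = torch.randn(1, 81, 112, 96)  # (B, H, W, C) with odd H, even W

out = merge(x)
print(out.shape)  # torch.Size([1, 41, 56, 192]) -- H padded 81 -> 82, both dims halved, channels 96 -> 192
```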

Since the input is (B, H, W, C) and nn.functional.pad consumes the padding tuple starting from the last dimension and working backwards, the padding should be specified in reverse order: (C_front, C_back, W_front, W_back, H_front, H_back).
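
A minimal illustration of that ordering (tensor shapes here are chosen only for demonstration):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 5, 8, 3)  # (B, H, W, C) with odd H, even W

# F.pad reads the tuple from the last dimension backwards, so for (B, H, W, C)
# it is (C_front, C_back, W_front, W_back, H_front, H_back).
right = F.pad(x, (0, 0, 0, 8 % 2, 0, 5 % 2))  # pads H: shape (1, 6, 8, 3)
wrong = F.pad(x, (0, 0, 0, 5 % 2, 0, 8 % 2))  # pads W instead: shape (1, 5, 9, 3), H stays odd

print(right.shape, wrong.shape)
```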

Thanks, -Collin

rwightman commented 2 weeks ago

@collinmccarthy Indeed, I got it right in the attention block and wrong in patch merging, weird. Thanks for catching that! Fixed in #2285.

collinmccarthy commented 2 weeks ago

Awesome, thanks for the quick fix!