Describe the bug
In the line marked in the forward pass below, the padding for H/W is backwards. I found this out by passing an image size of (648, 888) during validation, but it's also obvious from the torch docs and the code.
from typing import Callable, Optional

import torch.nn as nn


class PatchMerging(nn.Module):
    """ Patch Merging Layer.
    """

    def __init__(
            self,
            dim: int,
            out_dim: Optional[int] = None,
            norm_layer: Callable = nn.LayerNorm,
    ):
        """
        Args:
            dim: Number of input channels.
            out_dim: Number of output channels (or 2 * dim if None)
            norm_layer: Normalization layer.
        """
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim or 2 * dim
        self.norm = norm_layer(4 * dim)
        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)

    def forward(self, x):
        B, H, W, C = x.shape
        pad_values = (0, 0, 0, W % 2, 0, H % 2)  # Originally (0, 0, 0, H % 2, 0, W % 2), which is wrong
        x = nn.functional.pad(x, pad_values)
        _, H, W, _ = x.shape
        x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
        x = self.norm(x)
        x = self.reduction(x)
        return x
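For reference, a quick sanity check I'd expect to pass with the corrected padding; the dim value and input shape below are just illustrative, not from my actual run:

    import torch

    # H=41 is odd, W=56 is even, so only H needs padding before the 2x2 merge.
    layer = PatchMerging(dim=96)
    x = torch.randn(1, 41, 56, 96)  # (B, H, W, C)
    out = layer(x)
    print(out.shape)                # expected: torch.Size([1, 21, 28, 192])
    # With the original (swapped) pad_values, W gets padded to 57 instead and the
    # reshape to (B, H // 2, 2, W // 2, 2, C) fails because H=41 stays odd.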
Since the input is (B, H, W, C), the padding tuple has to list the dimensions in reverse order, i.e. (C_front, C_back, W_front, W_back, H_front, H_back).
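A minimal illustration of that argument order with a toy tensor (my own example, not from the repo):

    import torch
    import torch.nn.functional as F

    x = torch.zeros(1, 5, 8, 3)  # (B, H, W, C): H is odd, W is even
    # F.pad consumes padding pairs for the last dimensions first:
    # (C_front, C_back, W_front, W_back, H_front, H_back)
    fixed = F.pad(x, (0, 0, 0, 8 % 2, 0, 5 % 2))  # pads H -> torch.Size([1, 6, 8, 3])
    buggy = F.pad(x, (0, 0, 0, 5 % 2, 0, 8 % 2))  # pads W -> torch.Size([1, 5, 9, 3])
    print(fixed.shape, buggy.shape)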
Thanks, -Collin