prakashjayy closed this issue 2 years ago.
I think this is a typo in the paper: it should be a 3x3 conv with padding. Below is a Python check that a 3x3 conv with `padding=1` and a centred one-hot kernel returns its input unchanged:
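```python
import torch
import torch.nn as nn

# A 3x3 conv with padding=1 keeps the 10x10 spatial size unchanged.
conv3 = nn.Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False)

with torch.no_grad():
    # Zero every weight, then put a 1 at the kernel centre (1, 1) for each
    # output channel i that reads from the same input channel i.
    conv3.weight.zero_()
    for i in range(conv3.weight.shape[0]):
        conv3.weight[i, i, 1, 1] = 1

x = torch.randn((3, 24, 10, 10))
with torch.no_grad():
    y = conv3(x)

print(y.shape)               # torch.Size([3, 24, 10, 10])
print(torch.allclose(y, x))  # True: the conv acts as an identity map
```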
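The same check can be written without PyTorch so the indexing is explicit. This is a minimal sketch assuming NumPy; `conv2d_naive` is a hypothetical helper that mimics the stride-1, zero-padded, bias-free cross-correlation that `nn.Conv2d` performs.

```python
import numpy as np

def conv2d_naive(x, w, padding=1):
    """Hypothetical helper: stride-1 cross-correlation like nn.Conv2d (no bias).

    x: (N, C_in, H, W), w: (C_out, C_in, kH, kW)
    """
    n, c_in, h, width = x.shape
    c_out, _, kh, kw = w.shape
    # Zero-pad only the two spatial dimensions.
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    out_h = h + 2 * padding - kh + 1
    out_w = width + 2 * padding - kw + 1
    y = np.zeros((n, c_out, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = xp[:, :, i:i + kh, j:j + kw]  # (N, C_in, kH, kW)
            # Sum over input channels and kernel taps for every output channel.
            y[:, :, i, j] = np.einsum("nchw,ochw->no", patch, w)
    return y

C = 24
w = np.zeros((C, C, 3, 3))
for c in range(C):
    w[c, c, 1, 1] = 1.0  # centre tap, same channel in and out

x = np.random.randn(3, C, 10, 10)
y = conv2d_naive(x, w, padding=1)
print(y.shape)            # (3, 24, 10, 10)
print(np.allclose(y, x))  # True
```

With `padding=0` the same helper shrinks the output to 8x8 and returns only the interior `x[:, :, 1:-1, 1:-1]`, which is why the padding is needed for the shapes to match.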
In the paper it says:
Can someone help me with the math?
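One way to see why the kernel in the code reproduces its input, assuming the stride-1, zero-padded cross-correlation that `nn.Conv2d` computes with bias disabled, kernel indices $u, v \in \{0, 1, 2\}$, and $\tilde{x}$ denoting the input padded by one pixel of zeros on each side:

$$
\begin{aligned}
y_{n,o,h,w} &= \sum_{c=0}^{C-1} \sum_{u=0}^{2} \sum_{v=0}^{2} W_{o,c,u,v}\, \tilde{x}_{n,c,h+u,w+v},
\qquad \tilde{x}_{n,c,h',w'} = x_{n,c,h'-1,w'-1} \ \text{(zero outside the image)}, \\
W_{o,c,u,v} &= \delta_{oc}\,\delta_{u1}\,\delta_{v1}
\;\Longrightarrow\;
y_{n,o,h,w} = \tilde{x}_{n,o,h+1,w+1} = x_{n,o,h,w}, \\
H_{\text{out}} &= H + 2p - k + 1 = H + 2 - 3 + 1 = H \quad (p = 1,\ k = 3).
\end{aligned}
$$

Only the centre tap of the matching channel is non-zero, so every output pixel copies the corresponding input pixel, and padding by 1 keeps the spatial size fixed.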