TedThemistokleous opened 6 months ago
Here's a simple example I made using PyTorch to show how to do this mathematically.
```python
import torch


class PoolingConv(torch.nn.Module):
    def __init__(self):
        super(PoolingConv, self).__init__()
        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.conv = torch.nn.Conv2d(in_channels=3,
                                    out_channels=2,
                                    kernel_size=(1, 1),
                                    bias=False)

    def forward(self, x):
        return self.conv(self.avg_pool(x))


if __name__ == "__main__":
    mod = PoolingConv().eval()
    x = torch.randn(1, 3, 5, 5)
    torch_out = mod(x)

    # Pad the convolution weights to make them 3x3
    conv_weights = mod.conv.weight  # [2, 3, 1, 1]
    padded_conv_weights = torch.nn.functional.pad(conv_weights, (1, 1, 1, 1))  # [2, 3, 3, 3]
    avg_pool_weights = 1 / 9 * torch.ones(3, 1, 3, 3)  # [3, 1, 3, 3]

    # Using the associativity property, we can compute a "fused" weight
    # (this can be const-folded during compile).
    # This needs to be grouped (one channel per group) because pooling has
    # no inter-channel correlation.
    fused_weight = torch.nn.functional.conv2d(padded_conv_weights, avg_pool_weights, padding="same", groups=3)  # [2, 3, 3, 3]

    # Apply the "fused" weight to the original input
    fused_out = torch.nn.functional.conv2d(x, fused_weight, padding="same")

    # Hopefully the result is equivalent to running the PoolingConv module
    assert torch.allclose(torch_out, fused_out)
    print(torch_out)
    print(fused_out)
```
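One subtlety worth noting (my own observation, not stated in the snippet above): the fusion relies on `AvgPool2d`'s default `count_include_pad=True`, which always divides by 9. With `count_include_pad=False` the pool divides border windows by the number of valid elements, which a single constant 1/9 kernel cannot reproduce, so the fused result diverges at the borders. A minimal check:

```python
import torch

x = torch.randn(1, 3, 5, 5)
# Same setup as the example: 1x1 conv weight zero-padded to 3x3,
# composed with a constant 1/9 averaging kernel (one channel per group).
w = torch.nn.functional.pad(torch.randn(2, 3, 1, 1), (1, 1, 1, 1))
avg = (1 / 9) * torch.ones(3, 1, 3, 3)
fused = torch.nn.functional.conv2d(w, avg, padding="same", groups=3)
out = torch.nn.functional.conv2d(x, fused, padding="same")

# Default pooling (count_include_pad=True) matches the fused kernel...
pool_true = torch.nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=True)
ref = torch.nn.functional.conv2d(pool_true(x), w[:, :, 1:2, 1:2])
assert torch.allclose(ref, out, atol=1e-6)

# ...but count_include_pad=False renormalizes the border windows,
# so the results differ there.
pool_false = torch.nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
ref2 = torch.nn.functional.conv2d(pool_false(x), w[:, :, 1:2, 1:2])
assert not torch.allclose(ref2, out, atol=1e-6)
```

So a rewrite pass would need to guard on `count_include_pad` (or on there being no padding at all) before applying this fusion.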
Fuse average pooling with convolution
We could convert the average pooling to a convolution (and get rid of the pad_kernel). Then we would have a convolution feeding into another convolution. Convolution is not commutative the way GEMMs are, but I believe we could rewrite it as a backwards convolution where we apply the convolution to the weights.
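As a sketch of that first step (the helper name `avg_pool_as_conv` is mine, not an existing pass): an `AvgPool2d` with `count_include_pad=True` is exactly a depthwise `Conv2d` whose weight is a constant 1/(kh*kw), so the rewrite could start by materializing that weight:

```python
import torch

def avg_pool_as_conv(channels, kernel_size, stride, padding):
    """Build a depthwise Conv2d equivalent to AvgPool2d (count_include_pad=True)."""
    kh, kw = kernel_size
    conv = torch.nn.Conv2d(channels, channels, kernel_size,
                           stride=stride, padding=padding,
                           groups=channels, bias=False)
    with torch.no_grad():
        # Every tap averages: constant 1/(kh*kw) per channel.
        conv.weight.fill_(1.0 / (kh * kw))
    return conv

x = torch.randn(1, 3, 5, 5)
pool = torch.nn.AvgPool2d((3, 3), stride=(1, 1), padding=(1, 1))
conv = avg_pool_as_conv(3, (3, 3), (1, 1), (1, 1))
assert torch.allclose(pool(x), conv(x), atol=1e-6)
```

After this rewrite, the depthwise conv followed by the 1x1 conv can be collapsed into one kernel, as in the example above.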
Related to inceptionV3 and other models