ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Fuse Average Pooling with Convolution #2867

Open TedThemistokleous opened 6 months ago

TedThemistokleous commented 6 months ago

Fuse average pooling with convolution

@77 = gpu::code_object[code_object=9344,symbol_name=pad_kernel,global=262848,local=1024,](@57,@76) -> float_type, {1, 192, 37, 37}, {262848, 1369, 37, 1}
@78 = load[offset=705600,end=1646400](@1) -> float_type, {1, 192, 35, 35}, {235200, 1225, 35, 1}
@79 = gpu::pooling[mode=average,padding={0, 0, 0, 0},padding_mode=0,stride={1, 1},lengths={3, 3},dilations={1, 1},ceil_mode=0,lp_order=2,dyn_global=0](@77,@78) -> float_type, {1, 192, 35, 35}, {235200, 1225, 35, 1}
@80 = hip::hip_copy_literal[id=main:@literal:147] -> float_type, {32}, {1}
@81 = hip::hip_copy_literal[id=main:@literal:148] -> float_type, {32, 192, 1, 1}, {192, 1, 1, 1}
@82 = load[offset=1960000,end=2116800](@1) -> float_type, {1, 32, 35, 35}, {39200, 1225, 35, 1}
@83 = broadcast[axis=1,out_lens={1, 32, 35, 35}](@80) -> float_type, {1, 32, 35, 35}, {0, 1, 0, 0}
@84 = gpu::code_object[code_object=7312,symbol_name=mlir_convolution_add_relu,global=9856,local=64,](@83,@79,@81,@82) -> float_type, {1, 32, 35, 35}, {39200, 1225, 35, 1}

We could convert the average pooling to a convolution (and get rid of the pad_kernel). Then we would have a convolution feeding into another convolution. Two convolutions do not fold together as trivially as two GEMMs, but I believe we could rewrite them as a single convolution by applying one convolution to the other's weights (effectively a backwards convolution over the weights). A sketch of the first step is shown below.

Related to InceptionV3 and other models.
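
As a rough sketch of that first step (plain PyTorch rather than MIGraphX, with the channel/spatial sizes taken from the trace above and the explicit pad folded into the pooling's padding), an average pooling can be rewritten as a grouped convolution with constant weights:

import torch

# Sketch only: a 3x3 average pool (stride 1, pad 1) over C channels is the same
# linear map as a depthwise convolution (groups=C) whose 3x3 kernel is filled with 1/9.
# This holds at the borders because AvgPool2d defaults to count_include_pad=True,
# so the zero padding is counted in the denominator just like in the convolution.
C = 192
x = torch.randn(1, C, 35, 35)

pool_out = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)(x)

avg_weight = torch.full((C, 1, 3, 3), 1.0 / 9)  # one 3x3 averaging filter per channel
conv_out = torch.nn.functional.conv2d(x, avg_weight, padding=1, groups=C)

assert torch.allclose(pool_out, conv_out, atol=1e-6)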

shivadbhavsar commented 1 month ago

Here's a simple example I made using PyTorch to show how to do this mathematically.

import torch

class PoolingConv(torch.nn.Module):

    def __init__(self):
        super(PoolingConv, self).__init__()

        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.conv = torch.nn.Conv2d(in_channels=3,
                                    out_channels=2,
                                    kernel_size=(1, 1),
                                    bias=False)

    def forward(self, x):
        return self.conv(self.avg_pool(x))

if __name__ == "__main__":
    mod = PoolingConv().eval()
    x = torch.randn(1, 3, 5, 5)
    torch_out = mod(x)

    # Pad the convolution weights to make them 3x3
    conv_weights = mod.conv.weight # [2, 3, 1, 1]
    padded_conv_weights = torch.nn.functional.pad(conv_weights, (1, 1, 1, 1)) # [2, 3, 3, 3]

    # Average pooling is a per-channel convolution with constant 1/9 weights
    # (exact at the borders because AvgPool2d defaults to count_include_pad=True)
    avg_pool_weights = 1 / 9 * torch.ones(3, 1, 3, 3) # [3, 1, 3, 3]

    # Using the associativity property, we can compute a "fused" weight (this can be const-folded during compile)
    # This needs to be a grouped convolution (one channel per group, i.e. groups=3) because pooling has no inter-channel mixing
    fused_weight = torch.nn.functional.conv2d(padded_conv_weights, avg_pool_weights, padding="same", groups=3) # [2, 3, 3, 3]

    # Apply the "fused" weight to the original input
    fused_out = torch.nn.functional.conv2d(x, fused_weight, padding="same")

    # Hopefully the result is equivalent to running the PoolingConv module
    assert torch.allclose(torch_out, fused_out)

    print(torch_out)
    print(fused_out)
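
For reference, the same check with the shapes from the trace in the issue description (a 3x3 average pool feeding a 1x1, 192-to-32 convolution); only the tensor shapes are taken from that trace, the rest is an assumed sketch:

import torch

# Assumed shapes from the trace above: input (1, 192, 35, 35), 1x1 conv to 32 channels.
x = torch.randn(1, 192, 35, 35)
pool = torch.nn.AvgPool2d(3, stride=1, padding=1)
conv = torch.nn.Conv2d(192, 32, kernel_size=1, bias=False)
ref = conv(pool(x))

# Fold the two weight tensors into one 3x3 kernel (constant-foldable at compile time)
padded_w = torch.nn.functional.pad(conv.weight, (1, 1, 1, 1))  # [32, 192, 3, 3]
avg_w = torch.full((192, 1, 3, 3), 1.0 / 9)                    # [192, 1, 3, 3]
fused_w = torch.nn.functional.conv2d(padded_w, avg_w, padding="same", groups=192)
fused = torch.nn.functional.conv2d(x, fused_w, padding="same")

# Loose tolerance only to absorb float32 accumulation-order differences
assert torch.allclose(ref, fused, atol=1e-4)

Since both weight tensors are known at compile time, only fused_w would need to be computed ahead of time; at runtime the pad_kernel, the pooling, and the 1x1 convolution would collapse into a single 3x3 convolution.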