llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
29.17k stars 12.03k forks source link

[AVX2+] Vectorized `1 << u3` in a byte vector should turn into `vpshufb` #110317

Closed Validark closed 1 month ago

Validark commented 1 month ago

This code: (Godbolt link)

/// Returns, for every byte of `chunk`, the value 1 shifted left by that byte's low 3 bits.
/// `@truncate` narrows each byte to the shift-amount type of a `u8` shift (u3); in the
/// generated LLVM IR this appears as an `and` with 7 followed by a vector `shl`.
export fn foo(chunk: @Vector(32, u8)) @TypeOf(chunk) {
    return @as(@TypeOf(chunk), @splat(1)) << @truncate(chunk);
}
; LLVM IR for `foo`. The return attribute `range(i8 1, -127)` is the half-open, wrapping
; range [1, 129) over i8 (-127 == 129 unsigned), i.e. every result byte lies in [1, 128].
define dso_local range(i8 1, -127) <32 x i8> @foo(<32 x i8> %0) local_unnamed_addr {
Entry:
  ; Clamp every shift amount to [0, 7]: the lowering of Zig's @truncate to u3.
  %1 = and <32 x i8> %0, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
  ; 1 << amount per byte; `nuw` records that shifting 1 by at most 7 never wraps an i8.
  %2 = shl nuw <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, %1
  ret <32 x i8> %2
}

Compiles like so for Zen 3:

# 32 bytes of 16 (= 1 << 4): the blend source when shift-amount bit 2 is set.
.LCPI0_1:
        .zero   32,16
# 0xFC per byte: clears the low 2 bits that a 16-bit-lane shift by 2 carries in from the lower byte.
.LCPI0_2:
        .zero   32,252
# 0xE0 per byte: keeps bits 5-7, where the shift-amount bits are relocated to.
.LCPI0_3:
        .zero   32,224
# Scalar 1, broadcast to every byte as the value being shifted.
.LCPI0_4:
        .byte   1
# Computes 1 << (amount & 7) per byte with a three-step vpblendvb ladder. vpblendvb selects
# its second source wherever the mask byte's top bit is set, so the amount bits are
# moved into bit 7 one at a time (bit 2, then bit 1, then bit 0).
foo:
# Move amount bits 0-2 of every byte to bits 5-7; amount bit 2 now sits in bit 7.
        vpsllw  ymm0, ymm0, 5
        vpbroadcastb    ymm1, byte ptr [rip + .LCPI0_4]
# value = bit2 ? 16 : 1
        vpblendvb       ymm1, ymm1, ymmword ptr [rip + .LCPI0_1], ymm0
# Drop the bits the 16-bit shift carried in from the neighbouring byte.
        vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI0_3]
# ymm2 = value << 2 per byte (16-bit shift, then mask off the carried-in bits).
        vpsllw  ymm2, ymm1, 2
        vpand   ymm2, ymm2, ymmword ptr [rip + .LCPI0_2]
# Amount bit 1 -> bit 7; value = bit1 ? value << 2 : value
        vpaddb  ymm0, ymm0, ymm0
        vpblendvb       ymm1, ymm1, ymm2, ymm0
# Amount bit 0 -> bit 7; ymm2 = value << 1; result = bit0 ? value << 1 : value
        vpaddb  ymm0, ymm0, ymm0
        vpaddb  ymm2, ymm1, ymm1
        vpblendvb       ymm0, ymm1, ymm2, ymm0
        ret

However, because the bytes resulting from `@truncate(chunk)` are in the range [0, 7], we can precompute all 8 possible answers and look each one up with `vpshufb` (which selects a byte of its table operand by index) instead (Godbolt, full code):

/// Computes the same per-byte `1 << (byte & 7)` as `foo`, but as a `vpshufb` table lookup
/// instead of a variable shift.
export fn foo2(chunk: @Vector(32, u8)) @TypeOf(chunk) {
    // `foo` evaluated at comptime on the bytes 0..15 repeated to the vector width, so each
    // 16-byte lane holds 1 << (k & 7) at position k (1, 2, ..., 128, 1, 2, ..., 128).
    const table = comptime foo(std.simd.repeat(@sizeOf(@TypeOf(chunk)), std.simd.iota(u8, 16)));
    // Truncating to u3 keeps every index in [0, 7]; the u3 vector is then coerced to
    // `vpshufb`'s index type (the same type as `table`).
    return vpshufb(table, @as(@Vector(32, u3), @truncate(chunk)));
}

/// Byte shuffle with x86 `pshufb` semantics. For each output byte `i`, the result is 0 if
/// bit 7 of `indices[i]` is set. Otherwise it is the byte of `table` selected by the low
/// 4 bits of `indices[i]`, within the same 16-byte lane that contains byte `i`.
///
/// At comptime (`@inComptime()`) the shuffle is emulated in software. At runtime it calls
/// the LLVM intrinsic for the vector width: 64, 32 or 16 `u8` lanes, requiring avx512bw,
/// avx2 or ssse3 respectively. A compile error is emitted if the target CPU lacks the
/// feature or the type is not one of those vectors.
fn vpshufb(table: anytype, indices: @TypeOf(table)) @TypeOf(table) {
    if (@inComptime()) {
        var result: @TypeOf(indices) = undefined;
        for (0..@bitSizeOf(@TypeOf(indices)) / 8) |i| {
            const index = indices[i];
            // The instruction never crosses 128-bit lanes: only the low 4 index bits are used,
            // relative to the start of the 16-byte lane containing output byte `i`. Indexing
            // with `index % len` would read across lanes for 32/64-byte vectors and disagree
            // with the runtime (intrinsic) path.
            const lane_start = i & ~@as(usize, 0x0F);
            result[i] = if (index >= 0x80) 0 else table[lane_start | @as(usize, index & 0x0F)];
        }

        return result;
    }

    const methods = struct {
        extern fn @"llvm.x86.avx512.pshuf.b.512"(@Vector(64, u8), @Vector(64, u8)) @Vector(64, u8);
        extern fn @"llvm.x86.avx2.pshuf.b"(@Vector(32, u8), @Vector(32, u8)) @Vector(32, u8);
        extern fn @"llvm.x86.ssse3.pshuf.b.128"(@Vector(16, u8), @Vector(16, u8)) @Vector(16, u8);
    };

    return switch (@TypeOf(table)) {
        @Vector(64, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx512bw)) methods.@"llvm.x86.avx512.pshuf.b.512"(table, indices) else @compileError("CPU target lacks support for vpshufb512"),
        @Vector(32, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) methods.@"llvm.x86.avx2.pshuf.b"(table, indices) else @compileError("CPU target lacks support for vpshufb256"),
        @Vector(16, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3)) methods.@"llvm.x86.ssse3.pshuf.b.128"(table, indices) else @compileError("CPU target lacks support for vpshufb128"),
        else => @compileError(std.fmt.comptimePrint("Invalid argument type passed to vpshufb: {}\n", .{@TypeOf(table)})),
    };
}
# 32 bytes of 7: the mask implementing the @truncate to u3.
.LCPI0_0:
        .zero   32,7
# Removed dead vector data. See https://github.com/llvm/llvm-project/issues/110305
# The 8-entry lookup table: byte k holds 1 << k.
.LCPI0_2:
        .byte   1
        .byte   2
        .byte   4
        .byte   8
        .byte   16
        .byte   32
        .byte   64
        .byte   128
foo2:
# Clamp every index to [0, 7].
        vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI0_0]
# Replicate the 8-byte table into all four qwords, so each 16-byte lane starts with it.
        vpbroadcastq    ymm1, qword ptr [rip + .LCPI0_2]
# Per byte: look up ymm1 at the index (0..7) within the byte's own 128-bit lane -> 1 << index.
        vpshufb ymm0, ymm1, ymm0
        ret
llvmbot commented 1 month ago

@llvm/issue-subscribers-backend-x86

Author: Niles Salter (Validark)

This code: ([Godbolt link](https://zig.godbolt.org/#g:!((g:!((g:!((h:codeEditor,i:(filename:'1',fontScale:14,fontUsePx:'0',j:3,lang:zig,selection:(endColumn:2,endLineNumber:3,positionColumn:2,positionLineNumber:3,selectionStartColumn:2,selectionStartLineNumber:3,startColumn:2,startLineNumber:3),source:'export+fn+foo(chunk:+@Vector(32,+u8))+@TypeOf(chunk)+%7B%0A++++return+@as(@TypeOf(chunk),+@splat(1))+%3C%3C+@truncate(chunk+%3E%3E+@splat(4))%3B%0A%7D'),l:'5',n:'1',o:'Zig+source+%233',t:'0')),header:(),k:50.61449749453956,l:'4',m:100,n:'0',o:'',s:0,t:'0'),(g:!((h:compiler,i:(compiler:ztrunk,filters:(b:'0',binary:'1',binaryObject:'1',commentOnly:'0',debugCalls:'1',demangle:'0',directives:'0',execute:'1',intel:'0',libraryCode:'0',trim:'1',verboseDemangling:'0'),flagsViewOpen:'1',fontScale:14,fontUsePx:'0',j:1,lang:zig,libs:!(),options:'-O+ReleaseFast+-target+x86_64-linux+-mcpu%3Dznver3',overrides:!(),selection:(endColumn:33,endLineNumber:18,positionColumn:33,positionLineNumber:18,selectionStartColumn:33,selectionStartLineNumber:18,startColumn:33,startLineNumber:18),source:3),l:'5',n:'0',o:'+zig+trunk+(Editor+%233)',t:'0')),header:(),k:49.38550250546045,l:'4',m:100,n:'0',o:'',s:0,t:'0')),l:'2',n:'0',o:'',t:'0')),version:4)) ```zig export fn foo(chunk: @Vector(32, u8)) @TypeOf(chunk) { return @as(@TypeOf(chunk), @splat(1)) << @truncate(chunk >> @splat(4)); } ``` ```llvm define dso_local range(i8 1, -127) <32 x i8> @foo(<32 x i8> %0) local_unnamed_addr { Entry: %1 = lshr <32 x i8> %0, <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4> %2 = and <32 x i8> %1, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7> %3 = shl nuw <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 
1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, %2 ret <32 x i8> %3 } ``` Compiles like so for Zen 3: ```asm .LCPI0_1: .zero 32,16 .LCPI0_2: .zero 32,252 .LCPI0_3: .zero 32,224 .LCPI0_4: .byte 1 foo: vpsllw ymm0, ymm0, 1 vpbroadcastb ymm1, byte ptr [rip + .LCPI0_4] vpblendvb ymm1, ymm1, ymmword ptr [rip + .LCPI0_1], ymm0 vpand ymm0, ymm0, ymmword ptr [rip + .LCPI0_3] vpsllw ymm2, ymm1, 2 vpand ymm2, ymm2, ymmword ptr [rip + .LCPI0_2] vpaddb ymm0, ymm0, ymm0 vpblendvb ymm1, ymm1, ymm2, ymm0 vpaddb ymm0, ymm0, ymm0 vpaddb ymm2, ymm1, ymm1 vpblendvb ymm0, ymm1, ymm2, ymm0 ret ``` However, because the bytes resulting from `chunk >> @splat(4)` are in the range [0, 15], we can precompute all 16 possible answers and use vpshufb instead ([Godbolt, full code](https://zig.godbolt.org/#z:OYLghAFBqd5QCxAYwPYBMCmBRdBLAF1QCcAaPECAMzwBtMA7AQwFtMQByARg9KtQYEAysib0QXACx8BBAKoBnTAAUAHpwAMvAFYgAzKVpMGoAF55gpJfWQE8Ayo3QBhVLQCuLBiABMpJwAyeAyYAHKeAEaYxPoAbKQADqgKhPYMrh5eEonJqQJBIeEsUTEAHFaYNnYCQgRMxAQZnt5%2B1pi2abX1BAVhkdFxVnUNTVlcQ929RSUgpQCUVqjuxMjsHGgMCgQA1Fvo2wCkegAi2wACeCxJDRAHPj57dz5zRwBCBxoAghtb2xHudDsDEOJ3Ol2uBFu93%2BgOCTxeeneXw%2BnyowP4qAgyAQ7gYAGsQOcAGrtIjECB6PzbdzzObnAAqAE8EpgAPJULE4/F0g4AdiRn22Qu2xEwBGWwLOTAUEDOTJZ7M5uLxC3OCgSRkhXDmPL0ziOznOBGIuNEBEwSvxIOwR2wao1TEhkh1bxRfOObq%2BmFUEO2aL9qFQPktBOJpJIFKpNJ1DOZbI52OVPP5KOF2x%2BOzqEXoINOaCudjYAcxewAdCkWOhS6KWY7ZSlTPHZfKm4nuaqyxWq/Y6hAaaRtlxYrr9Xr7ZqIM6XYjU8LReLiMCAG4JBQ4qgRCBZ%2BgDtt46228d1qeu5G8j3Ir7%2Bldr9wbrdMbPsbbGRkEOMD4L4VYKQlyuOKtumAxv%2BCockByYCmmeBUNssrBK4BaXBaMZ8lBabCku9QipgCjuLQBB/i2ipfngP66qcuJYDQIToKegoYUK/DEHBGilqWZwRIQQh4I2iqgU2pHkXSAD02zzIcvLOHgfKGmhs6MWmGbbF%2B3q5ipDDfrhBwAKyvDJOkejOXyKRhop4QRun6bppxHKcMFwapqgHqCGiqKUGh0ho2yVEo2xAVZTmHD4OlwZx3G8U2AmAY%2B9AxmJ8w2fRpnup6DGMfOEo4RZBDJcKqWXul6YCL8bAEAgGAKOpWwmrYknoYx3rmoufqSk8tC0EuLClu5sSlkwS6qDpXA%2BKWq7rqWESlsNPhPLKJK2BGsTSNS8wDmcC1khAy0DtGdIbeG5I7atCINRhTXROibX3B1XU9aUfUDaoo3jXek1zQd
i3kpSu1rWGX2Rr9IGbRGP0nXljWqM1V3nO1nXdb15YKEoehjbeVCTaWI2lB9IPkkOQPrXjEAEyd%2B3E6T0YQ5JRkCgpOELsCCgAO6ENiYXEeBsXAZB9Npp9W3HdGuaHg5WKoEhRZlvS9TAGK919VQmCOssmBCGKAAS0oQDCBHBKWyAJO4pZKyr5kDv1g0zREzMxmVFXoAoHFw3diNPTNaMTVNM1zUBn6aWRuF0r5mDnPmCR0DgxDEBGTzOMocj%2BbLYrbEYyB4lVeEJL6zHbDe64RD79wLHzwoC6DUYSXZouweLkuh9LycEArJvKwuaua9rutAgbRut2buEW09zx0vblXOzd8Mt8PntvREvvc/7WkKMHtB%2BWc4eR9g0ex/c8eJ8Mcs7GnGe7O42ckDsuf53eEQhbE8KkKXQrl/j8QnSLeh2mL4eFg3BBVhlsQI%2BLdTbt3VgQLWMpu760NsbMBqsFAW2RijO2YoHZOzOC7BGD0kYo1nhjKa2MF5PiXoHFePk16hw3hLCO9Bt4x3JHHBOSdgEpxPpnc%2BOcSB51ehuYhxcn4mUUiHT%2BdoaFXC3jvckZYqAsGbr/ZCyhiDBEhE8AAkgwLCtA8D7Flp4RgmY4zbASNKJQ%2BwiC8PRhEQk8lzy6WcAwJ4Fs0LRS5k%2BF454dRCKKqlYynx3QcAWLQTgOleDeG4LwVAnAABaFhdhLBWKHO4eg9C8EIhwLQ3iEDKywDECACwCQ6TYj4XkeguAaAAJzVN5DpMpUhpAhI4JIXgLAJAaA0KQCJWhSDRI4LwX8nSMlZNIHAWASBN70DIBQOudCBimGNMqPggJoi/h1poXgXFmDEEZJwHgpAtn1EZKyCI2hSR7N4PmNgghWQMFoLszJvAsD/GAM4MQa8LmkCwCwYwwBxCPK%2BXgUUHQly4Q2f4VQ7R3Dmk%2Baoyo4KdERGIEc1wWBwXGkuJ80FxAIjJEwMcTAPyTA6JMBshYVAjDAAUESPAmBmashZBE/Z/BBAiDEOwKQMhBCKBUOoAFuhxhGFJeYSwiLfyQAWKgBI1RNicAALSsm2AAJUqMrJQAAxaUOw5WHxTr1AA%2BstOVJL3DOTlSwOBdlTBaOiGk3p2KVFYHFQUioVQ0iOE0qMbwXByiBGCH0YoAwuAGCSCkGVXqJDlFDXkBgUx%2BgxGDa60knRhiNDcM0SNSaOg1FTXGwNCaDBbG6BGn1EwGh5pmMGhYChEmrH0ME0J4TwV9O2CK/yJorQQFwIQHhKS9BzHSWShYOSmB5MoIUkAkhKmll5D4WIpRKkdI0D4SQc7Si8kMJwFppA2kVM6d0qJnABkgCGUOzdHAfBNoBX0wdjzvHYpSA4SQQA%3D)): ```zig export fn foo2(chunk: @Vector(32, u8)) @TypeOf(chunk) { const table = comptime foo(std.simd.repeat(@sizeOf(@TypeOf(chunk)), std.simd.iota(u8, 16) << @splat(4))); return vpshufb(table, chunk >> @splat(4)); } fn vpshufb(table: anytype, indices: @TypeOf(table)) @TypeOf(table) { if (@inComptime()) { var result: @TypeOf(indices) = undefined; for (0..@bitSizeOf(@TypeOf(indices)) / 8) |i| { const index = indices[i]; result[i] = if (index >= 0x80) 0 else table[index % (@bitSizeOf(@TypeOf(table)) / 8)]; } return result; } const methods = struct { extern fn @"llvm.x86.avx512.pshuf.b.512"(@Vector(64, u8), 
@Vector(64, u8)) @Vector(64, u8); extern fn @"llvm.x86.avx2.pshuf.b"(@Vector(32, u8), @Vector(32, u8)) @Vector(32, u8); extern fn @"llvm.x86.ssse3.pshuf.b.128"(@Vector(16, u8), @Vector(16, u8)) @Vector(16, u8); }; return switch (@TypeOf(table)) { @Vector(64, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx512bw)) methods.@"llvm.x86.avx512.pshuf.b.512"(table, indices) else @compileError("CPU target lacks support for vpshufb512"), @Vector(32, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) methods.@"llvm.x86.avx2.pshuf.b"(table, indices) else @compileError("CPU target lacks support for vpshufb256"), @Vector(16, u8) => if (comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3)) methods.@"llvm.x86.ssse3.pshuf.b.128"(table, indices) else @compileError("CPU target lacks support for vpshufb128"), else => @compileError(std.fmt.comptimePrint("Invalid argument type passed to vpshufb: {}\n", .{@TypeOf(table)})), }; } ``` ```asm .LCPI0_0: .zero 32,15 .byte 1 .byte 2 .byte 4 .byte 8 .byte 16 .byte 32 .byte 64 .byte 128 .byte 1 .byte 2 .byte 4 .byte 8 .byte 16 .byte 32 .byte 64 .byte 128 .byte 1 .byte 2 .byte 4 .byte 8 .byte 16 .byte 32 .byte 64 .byte 128 .byte 1 .byte 2 .byte 4 .byte 8 .byte 16 .byte 32 .byte 64 .byte 128 .LCPI0_2: .byte 1 .byte 2 .byte 4 .byte 8 .byte 16 .byte 32 .byte 64 .byte 128 foo2: vpsrlw ymm0, ymm0, 4 vpand ymm0, ymm0, ymmword ptr [rip + .LCPI0_0] vpbroadcastq ymm1, qword ptr [rip + .LCPI0_2] vpshufb ymm0, ymm1, ymm0 ret ```
RKSimon commented 1 month ago

The AND to clamp the shift amount is irrelevant for the SHL -> PSHUFB lowering as anything out of bounds would be poison anyway. All we need is to be shifting a vXi8 splat constant for this to work.

Validark commented 1 month ago

> The AND to clamp the shift amount is irrelevant for the SHL -> PSHUFB lowering as anything out of bounds would be poison anyway. All we need is to be shifting a vXi8 splat constant for this to work.

I included it because you can't really do an out-of-bounds shift in Zig. You have to use `@truncate`, which keeps only the low log2(bit width) bits of the shift amount (3 bits when shifting a `u8`), or `@intCast`, which is a promise that the value is already in range; and oftentimes an AND gets inserted anyway.