rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org

String::to_lowercase does not get vectorized well contrary to code comments #123712

Closed: jhorstmann closed this issue 1 month ago

jhorstmann commented 7 months ago

I'm looking into the performance of to_lowercase / to_uppercase on mostly ASCII strings, using a small microbenchmark added to library/alloc/benches/string.rs.

#[bench]
fn bench_to_lowercase(b: &mut Bencher) {
    let s = "Hello there, the quick brown fox jumped over the lazy dog! \
              Lorem ipsum dolor sit amet, consectetur. ";
    b.iter(|| s.to_lowercase())
}
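`#[bench]` and `Bencher` come from the unstable `test` crate, so the benchmark above needs a nightly toolchain. As a rough sketch for timing the same call on stable Rust, one could loop over `str::to_lowercase` using `std::time::Instant` and `std::hint::black_box`. The helper name `time_to_lowercase` and the iteration count are illustrative choices, not part of the original benchmark, and wall-clock numbers from such a loop are noisier than a proper harness.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Hypothetical helper: times `iters` calls of `str::to_lowercase` on `s` and
/// returns the last result together with the average time per call.
fn time_to_lowercase(s: &str, iters: u32) -> (String, Duration) {
    assert!(iters > 0, "need at least one iteration");
    let mut out = String::new();
    let start = Instant::now();
    for _ in 0..iters {
        // black_box hides the input and output from the optimizer so the
        // call is not hoisted out of the loop or removed entirely.
        out = black_box(black_box(s).to_lowercase());
    }
    (out, start.elapsed() / iters)
}

fn main() {
    // Same input string as the #[bench] function.
    let s = "Hello there, the quick brown fox jumped over the lazy dog! \
             Lorem ipsum dolor sit amet, consectetur. ";
    let (lower, per_call) = time_to_lowercase(s, 100_000);
    // The input is pure ASCII, so Unicode-aware lowercasing must agree
    // with plain ASCII lowercasing.
    assert_eq!(lower, s.to_ascii_lowercase());
    println!("{per_call:?} per call");
}
```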

Using Linux perf tooling, I see that the hot part of the code is the following large loop, which, despite heavy use of SSE2 instructions, only seems to process 32 bytes per iteration.

       │ d0:┌─→mov        r9,QWORD PTR [r14+r15*1]
       │    │  movdqu     xmm3,XMMWORD PTR [r14+r15*1]
  0,12 │    │  pshufd     xmm12,xmm3,0xee
  2,56 │    │  movq       rdx,xmm12
       │    │  mov        rsi,rdx
       │    │  or         rsi,r9
  0,59 │    │  test       rsi,rcx
       │    │↓ jne        319
  2,10 │    │  mov        rsi,r9
       │    │  mov        rdi,r9
       │    │  mov        r8,r9
       │    │  mov        r10,r9
  2,23 │    │  shr        r9d,0x8
       │    │  movd       xmm12,r9d
       │    │  shr        r10,0x20
  0,12 │    │  pshufd     xmm13,xmm3,0x44
  1,05 │    │  movdqa     xmm14,xmm3
       │    │  psrlq      xmm14,0x10
       │    │  psrlq      xmm13,0x18
  0,12 │    │  movsd      xmm13,xmm14
  1,75 │    │  movd       xmm14,r10d
       │    │  shr        r8,0x28
       │    │  punpcklqdq xmm3,xmm12
       │    │  movd       xmm12,r8d
  1,52 │    │  andpd      xmm13,xmm0
       │    │  pand       xmm3,xmm0
       │    │  packuswb   xmm3,xmm13
  0,35 │    │  pshufd     xmm13,xmm14,0x50
  1,05 │    │  movdqa     xmm14,XMMWORD PTR [rip+0x403c7]
       │    │  pandn      xmm14,xmm13
       │    │  psllq      xmm12,0x28
       │    │  movdqa     xmm13,XMMWORD PTR [rip+0x403c3]
  2,94 │    │  pandn      xmm13,xmm12
       │    │  shr        rdi,0x30
       │    │  por        xmm13,xmm14
  0,35 │    │  movd       xmm12,edi
  2,22 │    │  shr        rsi,0x38
       │    │  packuswb   xmm3,xmm1
       │    │  packuswb   xmm3,xmm1
  0,47 │    │  por        xmm13,xmm3
  4,21 │    │  psllq      xmm12,0x30
       │    │  movdqa     xmm3,xmm4
       │    │  pandn      xmm3,xmm12
  0,35 │    │  movd       xmm12,esi
  2,47 │    │  mov        esi,edx
       │    │  shr        esi,0x8
       │    │  pand       xmm13,xmm4
  0,35 │    │  por        xmm3,xmm13
  2,10 │    │  pand       xmm3,xmm5
       │    │  psllq      xmm12,0x38
       │    │  movdqa     xmm13,xmm5
       │    │  pandn      xmm13,xmm12
  2,34 │    │  por        xmm13,xmm3
       │    │  movd       xmm3,edx
       │    │  pshufd     xmm3,xmm3,0x44
  0,53 │    │  movdqa     xmm12,xmm6
  2,47 │    │  pandn      xmm12,xmm3
       │    │  movd       xmm3,esi
       │    │  mov        esi,edx
  0,23 │    │  shr        esi,0x10
  2,64 │    │  pand       xmm13,xmm6
       │    │  por        xmm12,xmm13
       │    │  pslldq     xmm3,0x9
  0,12 │    │  movdqa     xmm13,xmm7
  2,45 │    │  pandn      xmm13,xmm3
       │    │  movd       xmm3,esi
       │    │  mov        esi,edx
  0,51 │    │  shr        esi,0x18
  2,60 │    │  pand       xmm12,xmm7
       │    │  por        xmm13,xmm12
       │    │  pslldq     xmm3,0xa
       │    │  movdqa     xmm12,xmm8
  1,76 │    │  pandn      xmm12,xmm3
       │    │  movd       xmm3,esi
       │    │  mov        rsi,rdx
  0,47 │    │  shr        rsi,0x20
  2,34 │    │  pand       xmm13,xmm8
       │    │  por        xmm12,xmm13
       │    │  pslldq     xmm3,0xb
  0,23 │    │  movdqa     xmm13,xmm9
  1,99 │    │  pandn      xmm13,xmm3
       │    │  movd       xmm3,esi
       │    │  mov        rsi,rdx
  0,35 │    │  shr        rsi,0x28
  2,97 │    │  pand       xmm12,xmm9
       │    │  por        xmm13,xmm12
       │    │  pshufd     xmm3,xmm3,0x0
  0,12 │    │  movdqa     xmm12,xmm10
  2,11 │    │  pandn      xmm12,xmm3
       │    │  movd       xmm3,esi
       │    │  shr        rdx,0x30
       │    │  pand       xmm13,xmm10
  1,87 │    │  por        xmm12,xmm13
       │    │  pand       xmm12,xmm11
       │    │  pslldq     xmm3,0xd
  0,23 │    │  movdqa     xmm13,xmm11
  2,23 │    │  pandn      xmm13,xmm3
       │    │  por        xmm13,xmm12
       │    │  pand       xmm13,XMMWORD PTR [rip+0x40320]
  0,12 │    │  movd       xmm3,edx
  2,80 │    │  pslldq     xmm3,0xe
       │    │  por        xmm3,xmm13
       │    │  pand       xmm3,XMMWORD PTR [rip+0x4031a]
  0,23 │    │  movzx      edx,BYTE PTR [r14+r15*1+0xf]
  3,31 │    │  movd       xmm12,edx
       │    │  pslldq     xmm12,0xf
       │    │  por        xmm12,xmm3                                                                                                                                                                                                                                                                                       ▒
  0,12 │    │  movdqa     xmm3,xmm12                                                                                                                                                                                                                                                                                       ▒
  2,92 │    │  paddb      xmm3,XMMWORD PTR [rip+0x31d97]        # 1009a0 <anon.cf73386a2f5127d166baeac25be116f0.63.llvm.16014458289627072720+0x459>                                                                                                                                                                        ▒
       │    │  movdqa     xmm13,xmm3                                                                                                                                                                                                                                                                                       ▒
       │    │  pminub     xmm13,xmm15                                                                                                                                                                                                                                                                                      ▒
  0,47 │    │  pcmpeqb    xmm13,xmm3                                                                                                                                                                                                                                                                                       ▒
  1,53 │    │  pand       xmm13,xmm2                                                                                                                                                                                                                                                                                       ▒
  0,36 │    │  por        xmm13,xmm12                                                                                                                                                                                                                                                                                      ▒
       │    │  movdqu     XMMWORD PTR [rax+r15*1],xmm13                                                                                                                                                                                                                                                                    ▒
  0,23 │    │  lea        rdx,[r15+0x10]                                                                                                                                                                                                                                                                                   ▒
  2,34 │    │  add        r15,0x20                                                                                                                                                                                                                                                                                         ▒
  0,12 │    │  cmp        r15,rbx                                                                                                                                                                                                                                                                                          ▒
       │    │  mov        r15,rdx                                                                                                                                                                                                                                                                                          ▒
  1,64 │    └──jbe        d0             

I don't see an easy way to improve the autovectorization of this code, but it should be relatively easy to explicitly vectorize it using portable_simd, and I would like to prepare such a PR if there are no objections. As far as I know, portable_simd is already in use inside core, for example by #103779.

the8472 commented 7 months ago

The way the loop ORs the bits for the is-ASCII check into a single usize doesn't seem ideal for vectorization. Keeping the lanes independent should yield better autovectorization results.
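To make this point concrete, here is a minimal sketch of the two shapes a per-chunk ASCII check can take. It is not the actual std code. The first shape ORs whole `usize` words into one accumulator. The second keeps one independent flag per byte lane. The names `WORD`, `N` and `NONASCII_MASK`, and both function names, are hypothetical. Both functions return the same result, and the sketch makes no claim about the codegen either one produces.

```rust
/// Hypothetical sizes: one machine word, and a chunk of two words.
const WORD: usize = std::mem::size_of::<usize>();
const N: usize = 2 * WORD;
/// 0x80 in every byte: a byte is non-ASCII iff its top bit is set.
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; WORD]);

/// Shape 1: OR every word of the chunk into a single `usize`, then test
/// all top bits at once. The byte lanes are merged into one scalar early.
fn chunk_is_ascii_or_words(chunk: &[u8; N]) -> bool {
    let mut bits = 0usize;
    for word in chunk.chunks_exact(WORD) {
        let mut buf = [0u8; WORD];
        buf.copy_from_slice(word); // `chunks_exact` guarantees `word.len() == WORD`
        bits |= usize::from_ne_bytes(buf);
    }
    bits & NONASCII_MASK == 0
}

/// Shape 2: one independent flag per byte lane, reduced only at the end.
fn chunk_is_ascii_lanes(chunk: &[u8; N]) -> bool {
    let mut lane_ok = [false; N];
    for j in 0..N {
        lane_ok[j] = chunk[j] <= 127;
    }
    lane_ok.iter().all(|&ok| ok)
}
```

For example, if the last byte of the chunk is `0xC3` (a UTF-8 lead byte), both functions return `false`.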

jhorstmann commented 7 months ago

@the8472 Thanks for the suggestion. I had already started on a SIMD implementation, but I checked again whether the autovectorization could be improved. With the following main loop:

    while slice.len() >= N {
        let chunk = &slice[..N];

        let mut is_ascii = true;
        for j in 0..N {
            is_ascii &= chunk[j] <= 127;
        }
        if !is_ascii {
            break;
        }

        for j in 0..N {
            out_slice[j] = MaybeUninit::new(convert(&chunk[j]));
        }

        i += N;
        slice = &slice[N..];
        out_slice = &mut out_slice[N..];
    }
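Here is a self-contained, safe version of the loop above that can be tried in isolation. It is a sketch under stated assumptions. The function name `convert_ascii_chunks` and `N = 16` are hypothetical. It also uses `Vec::extend` instead of writing into a `MaybeUninit` output slice, so the generated code will not be identical to the snippet above.

```rust
/// Hypothetical chunk width: 16 bytes, the size of one SSE2 register.
const N: usize = 16;

/// Converts `input` in whole `N`-byte chunks for as long as each chunk is
/// entirely ASCII, and returns the converted prefix. A trailing partial chunk,
/// or anything from the first non-ASCII chunk on, is left to the caller.
fn convert_ascii_chunks(input: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
    let mut out = Vec::with_capacity(input.len());
    let mut slice = input;

    while slice.len() >= N {
        let chunk = &slice[..N];

        // ASCII check: a single flag folded over all N bytes.
        let mut is_ascii = true;
        for j in 0..N {
            is_ascii &= chunk[j] <= 127;
        }
        if !is_ascii {
            break;
        }

        // Conversion: a fixed trip count of N, independent per byte.
        out.extend(chunk.iter().map(convert));
        slice = &slice[N..];
    }
    out
}
```

Passing `u8::to_ascii_lowercase` as `convert` gives the lowercase case. A 32-byte all-ASCII input comes back fully converted, while a 5-byte input comes back empty because it never fills a chunk.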

The assembly and performance are indeed better, but there is still some weird shuffling and shifting going on:

  0,08 │ 80:┌─→movdqu  xmm3,XMMWORD PTR [rbx+rax*1]
  0,07 │    │  pshufd  xmm4,xmm3,0xee
  0,37 │    │  por     xmm4,xmm3
  6,22 │    │  pshufd  xmm5,xmm4,0x55
  0,16 │    │  por     xmm5,xmm4
       │    │  movdqa  xmm4,xmm5
  2,02 │    │  psrld   xmm4,0x10
  4,48 │    │  por     xmm4,xmm5
  0,07 │    │  movdqa  xmm5,xmm4
  0,08 │    │  psrlw   xmm5,0x8
  0,60 │    │  por     xmm5,xmm4
  6,58 │    │  movd    ecx,xmm5
       │    │  test    cl,cl
       │    │↓ js      ef
       │    │  movdqa  xmm4,xmm3
  1,80 │    │  paddb   xmm4,xmm0
  4,70 │    │  movdqa  xmm5,xmm4
       │    │  pminub  xmm5,xmm1
       │    │  pcmpeqb xmm5,xmm4
  0,37 │    │  pand    xmm5,xmm2
  8,50 │    │  por     xmm5,xmm3
       │    │  movdqu  XMMWORD PTR [r13+rax*1+0x0],xmm5
       │    │  add     rax,0x10
  1,94 │    │  add     r12,0xfffffffffffffff0
  1,05 │    ├──cmp     r12,0xf
  4,19 │    └──ja      80
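One way to read the conversion part of this loop (`paddb` through `por`) is as a branchless per-byte lowercase. The scalar sketch below models a single byte lane. It assumes that `xmm0` holds `-b'A'` (0xBF), `xmm1` holds 25 and `xmm2` holds 0x20 in every lane. These constants are inferred from the instruction sequence, not read from the binary, and the names are hypothetical.

```rust
// Assumed per-lane constants (inferred from the instruction sequence, not the binary).
const MINUS_A: u8 = 0u8.wrapping_sub(b'A'); // xmm0: 0xBF in every lane
const LETTERS_MINUS_ONE: u8 = 25;           // xmm1: b'Z' - b'A'
const CASE_BIT: u8 = 0x20;                  // xmm2: the ASCII case bit

/// One byte lane of the loop body, step by step.
fn lower_lane(x: u8) -> u8 {
    let shifted = x.wrapping_add(MINUS_A);                    // paddb: x - b'A', wrapping
    let is_upper = shifted.min(LETTERS_MINUS_ONE) == shifted; // pminub + pcmpeqb
    let mask: u8 = if is_upper { 0xFF } else { 0x00 };        // pcmpeqb yields all-ones or all-zeros
    x | (mask & CASE_BIT)                                     // pand, then por with the input
}

/// Applies `lower_lane` to 16 bytes, i.e. one iteration of the loop.
fn lower_block(block: [u8; 16]) -> [u8; 16] {
    let mut out = [0u8; 16];
    for j in 0..16 {
        out[j] = lower_lane(block[j]);
    }
    out
}
```

Because `pminub` is an unsigned minimum, `min(x - b'A', 25) == x - b'A'` holds exactly when the wrapped difference is at most 25, which means `x` is in `b'A'..=b'Z'`. Bytes below `'A'` wrap around to large values and fail the test, so only uppercase letters get the case bit.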

The explicit SIMD version looks a bit better, mostly because the ASCII check translates directly to a pmovmskb instruction:

        const LANES: usize = 16;

        let simd_range_start = Simd::splat(range_start);
        let simd_range_end = Simd::splat(range_end);
        let simd_xor_value = Simd::splat(xor_value);

        while slice.len() >= LANES {
            let chunk = Simd::<u8, LANES>::from_slice(slice);
            let is_ascii = chunk.cast::<i8>().simd_ge(Simd::splat(0));

            if is_ascii.all() {
                let is_in_range = chunk.simd_ge(simd_range_start) & chunk.simd_le(simd_range_end);
                let converted = is_in_range.select(chunk ^ simd_xor_value, chunk);
                // SAFETY: output has enough capacity and we never read the uninitialized slice
                unsafe {
                    let out_slice = core::slice::from_raw_parts_mut(out_ptr, LANES);
                    converted.copy_to_slice(out_slice);
                    out_ptr = out_ptr.add(LANES);
                }
                i += LANES;
                slice = &slice[LANES..];
            } else {
                break;
            }
        }

       │ 80:┌─→movdqu   xmm3,XMMWORD PTR [r14+rcx*1]
  0,12 │    │  pmovmskb edx,xmm3
  0,12 │    │  test     edx,edx
       │    │↓ jne      d4
  5,82 │    │  movdqa   xmm4,xmm3
  0,04 │    │  paddb    xmm4,xmm0
       │    │  movdqa   xmm5,xmm4
  0,08 │    │  pminub   xmm5,xmm1
 10,05 │    │  pcmpeqb  xmm5,xmm4
       │    │  movdqa   xmm4,xmm5
       │    │  pandn    xmm4,xmm3
       │    │  pxor     xmm3,xmm2
  7,02 │    │  pand     xmm3,xmm5
  0,04 │    │  por      xmm3,xmm4
       │    │  movdqu   XMMWORD PTR [rax+rcx*1],xmm3
       │    │  add      rcx,0x10
  9,69 │    │  mov      rdx,rsi
       │    │  add      rdx,0xfffffffffffffff0
       │    │  mov      rsi,rdx
       │    ├──cmp      rdx,0xf
  5,03 │    └──ja       80
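The two pieces that matter in this listing can be modeled in stable Rust without `std::simd`, which requires the nightly `portable_simd` feature. The sketch below is an illustration, not the proposed code.

- `movemask` mimics what `pmovmskb` computes, so `is_ascii.all()` amounts to `movemask(chunk) == 0` (the `test edx,edx` / `jne` pair).
- `select_convert` mimics `is_in_range.select(chunk ^ simd_xor_value, chunk)`, which compiles to the `pandn`/`pxor`/`pand`/`por` sequence.

The function names are hypothetical. The values `b'A'`, `b'Z'` and `0x20` for `range_start`, `range_end` and `xor_value` are assumptions for the lowercase case.

```rust
/// Scalar model of SSE2 `pmovmskb`: bit j of the result is the top bit of byte j.
fn movemask(chunk: &[u8; 16]) -> u16 {
    let mut mask = 0u16;
    for j in 0..16 {
        mask |= ((chunk[j] >> 7) as u16) << j;
    }
    mask
}

/// Scalar model of the masked select in the explicit SIMD loop: lanes inside
/// [range_start, range_end] become `x ^ xor_value`, the others pass through.
fn select_convert(chunk: &[u8; 16], range_start: u8, range_end: u8, xor_value: u8) -> [u8; 16] {
    let mut out = *chunk;
    for j in 0..16 {
        let x = chunk[j];
        let in_range = x >= range_start && x <= range_end;
        out[j] = if in_range { x ^ xor_value } else { x };
    }
    out
}
```

The listing shows LLVM folding the two range comparisons into the same `paddb`/`pminub`/`pcmpeqb` trick used by the autovectorized version. The main visible difference is the single `pmovmskb` replacing the shuffle-and-OR reduction.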

Do you have a preference here between autovectorization and explicit SIMD?

the8472 commented 7 months ago

If the SIMD impl results in reasonable code across architectures, including some without SIMD, then that should be fine. I think it would be the first time we use portable SIMD unconditionally, so it could use some additional scrutiny.

If it's only meant to target x86 with SSE2, then that'd mean having multiple paths, and in that case we'd still have to tweak the generic path anyway. At that point we might as well just optimize the latter further.

> The assembly and performance are indeed better, but there is still some weird shuffling and shifting going on:

You can probably get rid of those too by keeping multiple bools in a small array. That way the optimizer will more easily see that it can shove them into independent SIMD lanes. Increasing the unroll count might also help to fill the SIMD registers better.

We do get pretty decent autovectorization in other places in the standard library by massaging the code into a shape that basically looks like arrays used in place of SIMD types. For example:

https://github.com/rust-lang/rust/blob/40cf1f9257628d40e38a4eca9e1b8ea03a3abcd1/library/core/src/str/iter.rs#L61-L73
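Both suggestions, per-lane bool arrays and a larger unroll count, can be sketched with fixed-size arrays standing in for SIMD types. This only illustrates the style. `LANES`, `UNROLL`, `BLOCK` and `lowercase_ascii_blocks` are hypothetical names. Whether LLVM actually keeps the flags in independent lanes has to be checked in the generated assembly.

```rust
/// Hypothetical parameters: 16-byte lanes (one SSE2 register), unrolled twice.
const LANES: usize = 16;
const UNROLL: usize = 2;
const BLOCK: usize = LANES * UNROLL;

/// Lowercases whole `BLOCK`-sized, all-ASCII blocks from the start of `input`,
/// stopping at the first block that contains a non-ASCII byte.
fn lowercase_ascii_blocks(input: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(input.len());
    for block in input.chunks_exact(BLOCK) {
        // One independent flag per byte lane, combined only after the loop.
        let mut is_ascii = [false; BLOCK];
        for j in 0..BLOCK {
            is_ascii[j] = block[j] <= 127;
        }
        if !is_ascii.iter().all(|&ok| ok) {
            break;
        }
        // Per-lane conversion with a constant trip count.
        let mut lowered = [0u8; BLOCK];
        for j in 0..BLOCK {
            lowered[j] = block[j].to_ascii_lowercase();
        }
        out.extend_from_slice(&lowered);
    }
    out
}
```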

jhorstmann commented 7 months ago

> You can probably get rid of those too by keeping multiple bools in a small array. That way the optimizer will more easily see that it can shove them into independent SIMD lanes.

It seems LLVM is too "smart" for this trick: on its own, it generates exactly the same code. The sum trick does work, although it seems a bit fragile. The following produces a pmovmskb instruction just like the SIMD version, but a slightly different variant that uses a non_ascii array and compares against 0 does not:

        let mut is_ascii = [false; N];

        for j in 0..N {
            is_ascii[j] = chunk[j] <= 127;
        }

        if is_ascii.into_iter().map(|x| x as u8).sum::<u8>() as usize != N {
            break;
        }
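Here are the sum trick and the variant that reportedly does not produce `pmovmskb`, written as two standalone functions that return the same answer. The exact form of the `non_ascii` variant is a guess from the one-line description above, and `N = 16` is assumed. `.iter().map(|&x| x as u8)` replaces `into_iter()` only so the snippet compiles the same way on every edition. That change could itself affect codegen.

```rust
/// Assumed chunk width. The `u8` sums below cannot overflow because N <= 255;
/// a larger N would need a wider accumulator type.
const N: usize = 16;

/// The sum trick: count the ASCII lanes as `u8`s and compare the count to N.
fn all_ascii_sum(chunk: &[u8; N]) -> bool {
    let mut is_ascii = [false; N];
    for j in 0..N {
        is_ascii[j] = chunk[j] <= 127;
    }
    is_ascii.iter().map(|&x| x as u8).sum::<u8>() as usize == N
}

/// A guessed form of the `non_ascii` variant: flag the non-ASCII lanes and
/// compare their count to 0.
fn all_ascii_non_ascii_sum(chunk: &[u8; N]) -> bool {
    let mut non_ascii = [false; N];
    for j in 0..N {
        non_ascii[j] = chunk[j] > 127;
    }
    non_ascii.iter().map(|&x| x as u8).sum::<u8>() == 0
}
```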

> If the SIMD impl results in reasonable code across architectures, including some without SIMD, then that should be fine.

It only needs to be better than the autovectorized version on those platforms ;)

> If it's only meant to target x86 with SSE2, then that'd mean having multiple paths, and in that case we'd still have to tweak the generic path anyway. At that point we might as well just optimize the latter further.

My initial idea was to gate the SIMD code on either SSE2 or Neon (possibly also AltiVec and RISC-V). I'd also add a scalar loop for the non-multiple-of-N remainder, so that all-ASCII strings are handled entirely by specialized code. Currently, this remainder goes through the generic char_indices and conversions::to_lower code path. On other platforms only the scalar loop would be used, since I assume autovectorization would generate suboptimal code there as well.
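Here is a rough sketch of the layering described above, in safe stable Rust. A plain chunked loop over `[u8]` stands in for the SIMD code, and the gate is a compile-time `cfg!` check rather than runtime detection. `USE_CHUNKED`, `N` and `to_lowercase_sketch` are hypothetical names, not std API.

```rust
/// Hypothetical compile-time gate for the chunked loop (no runtime detection).
const USE_CHUNKED: bool = cfg!(any(target_feature = "sse2", target_feature = "neon"));
/// Hypothetical chunk width.
const N: usize = 16;

/// Sketch of the proposed layering for mostly-ASCII input:
/// 1. gated loop over whole N-byte ASCII chunks,
/// 2. scalar loop over the remaining ASCII bytes (the non-multiple-of-N tail),
/// 3. the generic Unicode path only if a non-ASCII byte is found.
fn to_lowercase_sketch(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;

    if USE_CHUNKED {
        while i + N <= bytes.len() {
            let chunk = &bytes[i..i + N];
            if !chunk.is_ascii() {
                break;
            }
            out.extend(chunk.iter().map(u8::to_ascii_lowercase));
            i += N;
        }
    }

    // Scalar loop: handles the tail (and everything, on ungated targets).
    while i < bytes.len() && bytes[i].is_ascii() {
        out.push(bytes[i].to_ascii_lowercase());
        i += 1;
    }

    if i == bytes.len() {
        // All-ASCII input never reaches the generic path.
        return String::from_utf8(out).expect("ASCII bytes are valid UTF-8");
    }

    // Non-ASCII found. To stay correct for context-sensitive mappings such as
    // final sigma, this sketch simply redoes the whole string with the std
    // generic path; a real implementation would presumably resume at `i`.
    s.to_lowercase()
}
```

With this layering, an all-ASCII input is handled entirely by the first two loops, whatever its length.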

But I agree that if similar code quality can be achieved with autovectorization, that would be preferable. I'll open a PR after polishing the code a bit more.