iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

RHS tensor.pack op codegen generates inefficient store #16313

Open hanhanW opened 9 months ago

hanhanW commented 9 months ago
    14ed:   62 f1 fd 48 6f 05 89    vmovdqa64 zmm0,ZMMWORD PTR [rip+0xffffffffffffee89]        # 380 <__unnamed_1-0xb0>
    14f4:   ee ff ff 
    14f7:   62 f1 fd 48 6f 0d bf    vmovdqa64 zmm1,ZMMWORD PTR [rip+0xffffffffffffeebf]        # 3c0 <__unnamed_1-0x70>
    14fe:   ee ff ff 
    1501:   c4 e2 7d 5a 15 16 ef    vbroadcasti128 ymm2,XMMWORD PTR [rip+0xffffffffffffef16]        # 420 <__unnamed_1-0x10>
    1508:   ff ff 
    150a:   eb 17                   jmp    1523 <pack_i8_dispatch_0_pack_i8+0x93>
    150c:   0f 1f 40 00             nop    DWORD PTR [rax+0x0]
    1510:   49 ff c3                inc    r11
    1513:   48 83 c7 10             add    rdi,0x10
    1517:   4d 01 c8                add    r8,r9
    151a:   49 39 c3                cmp    r11,rax
    151d:   0f 84 a0 00 00 00       je     15c3 <pack_i8_dispatch_0_pack_i8+0x133>
    1523:   48 85 f6                test   rsi,rsi
    1526:   7e e8                   jle    1510 <pack_i8_dispatch_0_pack_i8+0x80>
    1528:   4d 89 de                mov    r14,r11
    152b:   49 c1 e6 04             shl    r14,0x4
    152f:   49 89 d7                mov    r15,rdx
    1532:   4d 29 f7                sub    r15,r14
    1535:   49 83 ff 10             cmp    r15,0x10
    1539:   4c 0f 4d fb             cmovge r15,rbx
    153d:   62 d2 fd 48 7c df       vpbroadcastq zmm3,r15
    1543:   4d 89 c6                mov    r14,r8
    1546:   49 89 cf                mov    r15,rcx
    1549:   49 89 fc                mov    r12,rdi
    154c:   49 89 f5                mov    r13,rsi
    154f:   eb 48                   jmp    1599 <pack_i8_dispatch_0_pack_i8+0x109>
    1551:   66 66 66 66 66 66 2e    data16 data16 data16 data16 data16 cs nop WORD PTR [rax+rax*1+0x0]
    1558:   0f 1f 84 00 00 00 00 
    155f:   00 
    1560:   62 d1 7f 8a 6f 24 24    vmovdqu8 xmm4{k2}{z},XMMWORD PTR [r12]
    1567:   62 d1 7f 89 6f 2c 14    vmovdqu8 xmm5{k1}{z},XMMWORD PTR [r12+rdx*1]
    156e:   c4 e3 5d 38 e5 01       vinserti128 ymm4,ymm4,xmm5,0x1
    1574:   c4 e3 fd 00 e4 d8       vpermq ymm4,ymm4,0xd8
    157a:   c4 e2 5d 00 e2          vpshufb ymm4,ymm4,ymm2
    157f:   c4 c1 7d 7f 66 e2       vmovdqa YMMWORD PTR [r14-0x1e],ymm4
    1585:   4d 01 d4                add    r12,r10
    1588:   49 83 c7 fe             add    r15,0xfffffffffffffffe
    158c:   49 83 c6 20             add    r14,0x20
    1590:   49 ff cd                dec    r13
    1593:   0f 84 77 ff ff ff       je     1510 <pack_i8_dispatch_0_pack_i8+0x80>
    1599:   62 f2 e5 48 37 c0       vpcmpgtq k0,zmm3,zmm0
    159f:   62 f2 e5 48 37 c9       vpcmpgtq k1,zmm3,zmm1
    15a5:   c5 f5 4b c8             kunpckbw k1,k1,k0
    15a9:   c4 e1 f8 90 d1          kmovq  k2,k1
    15ae:   4d 85 ff                test   r15,r15
    15b1:   7f 04                   jg     15b7 <pack_i8_dispatch_0_pack_i8+0x127>
    15b3:   c5 fc 47 d0             kxorw  k2,k0,k0
    15b7:   49 83 ff 02             cmp    r15,0x2
    15bb:   7d a3                   jge    1560 <pack_i8_dispatch_0_pack_i8+0xd0>
    15bd:   c5 fc 47 c8             kxorw  k1,k0,k0
    15c1:   eb 9d                   jmp    1560 <pack_i8_dispatch_0_pack_i8+0xd0>
    15c3:   31 c0                   xor    eax,eax
    15c5:   5b                      pop    rbx
    15c6:   41 5c                   pop    r12
    15c8:   41 5d                   pop    r13
    15ca:   41 5e                   pop    r14
    15cc:   41 5f                   pop    r15
    15ce:   5d                      pop    rbp
    15cf:   c5 f8 77                vzeroupper 
    15d2:   c3                      ret    
    15d3:   cc                      int3   
    15d4:   cc                      int3   

VectorTransferLowering generates an inefficient sequence of vector.store ops because the innermost dimension is 2xi8: the store is fully unrolled, with one vector<2xi8> store per row.

// The pattern below is repeated 16 times, once per row:
%58 = vector.extract %43[14] : vector<2xi8> from vector<16x2xi8>
vector.store %58, %subview_2[%c14, %c0] : memref<16x2xi8, affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 + s0)>, #hal.descriptor_type<storage_buffer>>, vector<2xi8>
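To see why collapsing is legal here, it helps to look at where those 16 stores land. The memref layout map (d0 * 2 + d1 + s0) gives strides (2, 1), so each row's vector<2xi8> store starts exactly where the previous one ends. The Python below is a toy model, not IREE code: it assumes full unrolling emits one 1-D store per index of the leading dimensions, assumes an innermost stride of 1, ignores the dynamic offset s0, and uses hypothetical helper names (`unrolled_stores`, `is_contiguous`).

```python
from itertools import product


def unrolled_stores(shape, strides):
    """Toy model of fully unrolling a vector store.

    Emits one 1-D store per index of the leading dims and returns
    (element_offset, length) pairs. Assumes the innermost stride is 1 and
    ignores the dynamic base offset (s0 in the layout map).
    """
    *lead, inner = shape
    stores = []
    for idx in product(*(range(n) for n in lead)):
        offset = sum(i * s for i, s in zip(idx, strides[:-1]))
        stores.append((offset, inner))
    return stores


def is_contiguous(stores):
    """True if each store starts exactly where the previous one ends."""
    return all(
        off + length == next_off
        for (off, length), (next_off, _) in zip(stores, stores[1:])
    )


# memref<16x2xi8> with layout (d0 * 2 + d1 + s0), i.e. strides (2, 1).
stores = unrolled_stores([16, 2], [2, 1])
print(len(stores), stores[:3], is_contiguous(stores))
# prints: 16 [(0, 2), (2, 2), (4, 2)] True
```

Because the 16 pieces tile the 32 bytes with no gaps, the same data could be written by a single vector<32xi8> store into a collapsed 32xi8 memref. With a padded row stride such as (4, 1), `is_contiguous` returns False and collapsing would not be legal.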

A potential fix is to flatten the innermost dimensions (with memref.collapse_shape) up to the vector length. There is some support for this upstream, but we need to add a control to the patterns; otherwise they generate one big 1-D vector.
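A minimal sketch of the kind of control meant here, assuming a simple policy: merge innermost dimensions only while the merged vector still fits in a given bit width, and treat "no limit" as the current behavior of flattening everything into one 1-D vector. `flatten_innermost` and its parameters are hypothetical; they do not correspond to an upstream MLIR API, and the sketch assumes the dimensions being merged are already known to be contiguous in memory.

```python
from math import prod
from typing import List, Optional, Sequence


def flatten_innermost(shape: Sequence[int], elem_bits: int,
                      max_vector_bits: Optional[int]) -> List[int]:
    """Collapse trailing dims of a contiguous vector shape (hypothetical policy).

    Keeps merging the innermost dims while the merged vector fits in
    max_vector_bits. max_vector_bits=None means "no limit", which merges
    every dim into one 1-D vector.
    """
    if not shape:
        return []
    n = 1  # number of trailing dims merged so far
    while n < len(shape):
        merged_elems = prod(shape[-(n + 1):])
        if max_vector_bits is not None and merged_elems * elem_bits > max_vector_bits:
            break
        n += 1
    return list(shape[:-n]) + [prod(shape[-n:])]


print(flatten_innermost([16, 2], 8, 512))      # [32]
print(flatten_innermost([4, 16, 2], 8, 512))   # [4, 32]
print(flatten_innermost([4, 16, 2], 8, None))  # [128]
```

With a 512-bit limit (the zmm register width on cascadelake), the 16x2xi8 store collapses into a single vector<32xi8>, while a 4x16x2xi8 vector stops at 4x32 instead of growing into one 128-element vector.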

To reproduce, save the input below as ~/repro.mlir and run: iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/repro.mlir -o /tmp/a.vmfb

func.func @pack_i8(%source: tensor<?x?xi8>) -> tensor<?x?x16x2xi8> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %source, %c0 : tensor<?x?xi8>
  %d1 = tensor.dim %source, %c1 : tensor<?x?xi8>
  %c16 = arith.constant 16 : index
  %c2 = arith.constant 2 : index
  %tiled_d0 = arith.ceildivui %d0, %c2 : index
  %tiled_d1 = arith.ceildivui %d1, %c16 : index

  %zero = arith.constant 0 : i8
  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x2xi8>
  %pack = tensor.pack %source
    padding_value(%zero: i8)
    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 2]
    into %init_pack : tensor<?x?xi8> -> tensor<?x?x16x2xi8>
  return %pack : tensor<?x?x16x2xi8>
}
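When checking the generated code, it can help to have an oracle for what the pack should produce. Below is a toy Python model of tensor.pack, written from my reading of the op's attributes rather than from IREE's implementation: each dimension listed in inner_dims_pos is split into tiles of the matching inner_tiles size, outer_dims_perm orders the resulting tile-count dimensions, and out-of-bounds positions receive padding_value. Tensors are plain dicts from index tuples to values, and `pack_reference` is a hypothetical name.

```python
from itertools import product
from typing import Dict, Optional, Sequence, Tuple


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def pack_reference(src: Dict[Tuple[int, ...], int], src_shape: Sequence[int],
                   inner_dims_pos: Sequence[int], inner_tiles: Sequence[int],
                   outer_dims_perm: Optional[Sequence[int]] = None,
                   padding_value: int = 0):
    """Toy model of tensor.pack on dict-backed tensors.

    Returns (dest_shape, dest), where dest maps result index tuples to values.
    """
    rank = len(src_shape)
    if outer_dims_perm is None:
        outer_dims_perm = list(range(rank))
    tile_of = dict(zip(inner_dims_pos, inner_tiles))
    # Number of tiles along each source dim (untiled dims keep their size).
    outer = [ceil_div(n, tile_of.get(d, 1)) for d, n in enumerate(src_shape)]
    dest_shape = [outer[d] for d in outer_dims_perm] + list(inner_tiles)
    dest = {}
    for idx in product(*(range(n) for n in dest_shape)):
        src_idx = [0] * rank
        # Result outer dim k iterates over the tiles of source dim outer_dims_perm[k].
        for k, d in enumerate(outer_dims_perm):
            src_idx[d] = idx[k] * tile_of.get(d, 1)
        # Inner tile dims follow the order of inner_dims_pos.
        for t, d in zip(idx[rank:], inner_dims_pos):
            src_idx[d] += t
        in_bounds = all(i < n for i, n in zip(src_idx, src_shape))
        dest[idx] = src[tuple(src_idx)] if in_bounds else padding_value
    return dest_shape, dest


# @pack_i8 on a 3x17 source: result element [o1, o0, t1, t0] reads
# source[o0 * 2 + t0][o1 * 16 + t1], or 0 past the edges.
src = {(i, j): 100 * i + j for i in range(3) for j in range(17)}
shape, dest = pack_reference(src, [3, 17], inner_dims_pos=[1, 0],
                             inner_tiles=[16, 2], outer_dims_perm=[1, 0])
print(shape)               # [2, 2, 16, 2]
print(dest[(0, 0, 5, 1)])  # 105  (source[1][5])
print(dest[(1, 0, 1, 0)])  # 0    (source[0][17] is padding)
```

Note that the innermost 16x2 tile of the result interleaves two source rows, which is why the lowering ends up with a 2-element innermost vector per store.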
hanhanW commented 9 months ago

Note: the flattening is needed for LHS packing as well.

hanhanW commented 9 months ago

I will use the three cases below to drive the optimization work.

func.func @pack_i8(%source: tensor<?x?xi8>) -> tensor<?x?x16x2xi8> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %source, %c0 : tensor<?x?xi8>
  %d1 = tensor.dim %source, %c1 : tensor<?x?xi8>
  %c16 = arith.constant 16 : index
  %c2 = arith.constant 2 : index
  %tiled_d0 = arith.ceildivui %d0, %c2 : index
  %tiled_d1 = arith.ceildivui %d1, %c16 : index

  %zero = arith.constant 0 : i8
  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x2xi8>
  %pack = tensor.pack %source
    padding_value(%zero: i8)
    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 2]
    into %init_pack : tensor<?x?xi8> -> tensor<?x?x16x2xi8>
  return %pack : tensor<?x?x16x2xi8>
}

func.func @pack_bf16(%source: tensor<?x?xbf16>) -> tensor<?x?x16x2xbf16> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %source, %c0 : tensor<?x?xbf16>
  %d1 = tensor.dim %source, %c1 : tensor<?x?xbf16>
  %c16 = arith.constant 16 : index
  %c2 = arith.constant 2 : index
  %tiled_d0 = arith.ceildivui %d0, %c2 : index
  %tiled_d1 = arith.ceildivui %d1, %c16 : index

  %zero = arith.constant 0.000000e+00 : bf16
  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x2xbf16>
  %pack = tensor.pack %source
    padding_value(%zero: bf16)
    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 2]
    into %init_pack : tensor<?x?xbf16> -> tensor<?x?x16x2xbf16>
  return %pack : tensor<?x?x16x2xbf16>
}

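// i4 cannot be used as an input or output type, so this function takes and
// returns i8 and converts to/from i4 with arith.trunci and arith.extsi.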
func.func @pack_i4(%source: tensor<?x?x?xi8>) -> tensor<?x?x?x32x8xi8> {
  %source_i4 = arith.trunci %source : tensor<?x?x?xi8> to tensor<?x?x?xi4>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %source_i4, %c0 : tensor<?x?x?xi4>
  %d1 = tensor.dim %source_i4, %c1 : tensor<?x?x?xi4>
  %d2 = tensor.dim %source_i4, %c2 : tensor<?x?x?xi4>
  %c32 = arith.constant 32 : index
  %c8 = arith.constant 8 : index
  %tiled_d0 = arith.ceildivui %d0, %c32 : index
  %tiled_d2 = arith.ceildivui %d2, %c8 : index

  %zero = arith.constant 0 : i4
  %init_pack = tensor.empty(%d1, %tiled_d0, %tiled_d2) : tensor<?x?x?x32x8xi4>
  %pack = tensor.pack %source_i4
    padding_value(%zero: i4)
    outer_dims_perm = [1, 0, 2] inner_dims_pos = [0, 2] inner_tiles = [32, 8]
    into %init_pack : tensor<?x?x?xi4> -> tensor<?x?x?x32x8xi4>

  %res = arith.extsi %pack : tensor<?x?x?x32x8xi4> to tensor<?x?x?x32x8xi8>
  return %res : tensor<?x?x?x32x8xi8>
}
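Because @pack_i4 wraps the pack in arith.trunci and arith.extsi, each i8 output element should equal the sign-extended low nibble of the corresponding source element, or 0 where padding was inserted. A small sketch of that round trip, usable as an oracle when checking results; the helper names are hypothetical, and the model relies only on the standard meaning of integer truncation (keep the low bits) and sign extension (replicate the top bit).

```python
def trunci_i8_to_i4(x: int) -> int:
    """Model of arith.trunci i8 -> i4: keep the low 4 bits, returned as 0..15."""
    return x & 0xF


def extsi_i4_to_i8(nibble: int) -> int:
    """Model of arith.extsi i4 -> i8: treat bit 3 as the sign bit."""
    return nibble - 16 if nibble & 0x8 else nibble


def pack_i4_element(x: int) -> int:
    """Expected i8 value of a non-padding element after trunci -> pack -> extsi."""
    return extsi_i4_to_i8(trunci_i8_to_i4(x))


print([pack_i4_element(x) for x in (7, 8, -1, 127, 16)])  # [7, -8, -1, -1, 0]
```

The pack itself only moves elements, so any value outside -8..7 in the i8 output points at a bug in the i4 handling rather than in the data layout.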
dcaballe commented 9 months ago

Didn't we already have a pass to make the innermost dimension larger?

hanhanW commented 9 months ago


Yes, we do. The patterns make the innermost dimension as large as possible, i.e., they flatten it into one big 1-D vector. Flattening to 1-D vectors seems to cause a huge compilation-time issue (https://github.com/openxla/iree/pull/16239). I need some time to investigate it, and I think we want some control here.

dcaballe commented 9 months ago

That's surprising, as we should effectively be unrolling less. It's only one model, so maybe there's a collateral effect happening.

hanhanW commented 5 months ago

https://github.com/iree-org/iree/pull/16456 should address the issue. I'll revisit how to land the PR.