flann-lib / flann

Fast Library for Approximate Nearest Neighbors
http://people.cs.ubc.ca/~mariusm/flann
Other
2.22k stars 647 forks source link

Optimize L2 norm computation with optional vectorization. #454

Open legrosbuffle opened 4 years ago

legrosbuffle commented 4 years ago

Right now only float is vectorized; other specializations will be added in a subsequent commit.

To benchmark the query, I modified flann_example_cpp to run the query loop 1000 times instead of once.

Before this change, computing the norm accounts for 36.4% of execution time. Even though the loop is unrolled by a factor of 4, the loads, subtractions, multiplications and additions are still scalar (`movss`/`subss`/`mulss`/`addss`). After the change, the loop is vectorized (`movups`/`subps`/`mulps`). When a maximum distance is given, we still have to reduce the vector accumulator to a scalar at every iteration so that it can be compared against `worst_dist`. Otherwise, a single reduction at the end suffices. In the former case (worst_dist >= 0), computing the norm drops to 29.8% of execution time, and total execution time drops from 35.1 s to 31.8 s (a ~10% improvement). In the latter case (worst_dist < 0), computing the norm drops to 24.2% of execution time, and total execution time drops from 35.1 s to 28.2 s (a ~20% improvement).

Before:

ROUTINE ======================== flann::L2::operator()
47564889135 47564889135 (flat, cum) 36.44% of Total
  35092264   35092264      1f830: lea    (%rsi,%rcx,4),%rcx               ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:154
 328058643  328058643      1f834: movaps %xmm0,%xmm5                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:150
  20314866   20314866      1f837: lea    -0xc(%rcx),%rax                  ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:155
    940905     940905      1f83b: cmp    %rax,%rsi                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:158
         .          .      1f83e: jae    1f8e1 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0xb1>
  18337652   18337652      1f844: pxor   %xmm6,%xmm6                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:152
 241505735  241505735      1f848: movaps %xmm6,%xmm0
   8351632    8351632      1f84b: nopl   0x0(%rax,%rax,1)
2483873273 2483873273      1f850: movss  (%rsi),%xmm1                     ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:159
 220146242  220146242      1f854: movss  0x4(%rsi),%xmm4                  ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:160
   3650929    3650929      1f859: add    $0x10,%rdx                       ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:165
3336460614 3336460614      1f85d: add    $0x10,%rsi                       ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:164
 393866511  393866511      1f861: subss  -0x10(%rdx),%xmm1                ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:159
 366575932  366575932      1f866: subss  -0xc(%rdx),%xmm4                 ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:160
   3634832    3634832      1f86b: movss  -0x8(%rsi),%xmm3                 ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:161
9095107561 9095107561      1f870: subss  -0x8(%rdx),%xmm3
 256805437  256805437      1f875: movss  -0x4(%rsi),%xmm2                 ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:162
 297745047  297745047      1f87a: subss  -0x4(%rdx),%xmm2
   3712756    3712756      1f87f: comiss %xmm6,%xmm5                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:167
7525622349 7525622349      1f882: mulss  %xmm1,%xmm1                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:163
 617814475  617814475      1f886: mulss  %xmm4,%xmm4
 111702640  111702640      1f88a: mulss  %xmm3,%xmm3
 261873539  261873539      1f88e: mulss  %xmm2,%xmm2
7422685724 7422685724      1f892: addss  %xmm4,%xmm1
 970925099  970925099      1f896: addss  %xmm3,%xmm1
5221453445 5221453445      1f89a: addss  %xmm2,%xmm1
4987699692 4987699692      1f89e: addss  %xmm1,%xmm0
                                                                          ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:167
1526384608 1526384608      1f8a2: jbe    1f8a9 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x79>
         .          .      1f8a4: comiss %xmm5,%xmm0
         .          .      1f8a7: ja     1f8d8 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0xa8>
         .          .      1f8a9: cmp    %rsi,%rax                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:158
1558693597 1558693597      1f8ac: ja     1f850 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x20>
         .          .      1f8ae: cmp    %rsi,%rcx                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:172
   8300423    8300423      1f8b1: jbe    1f8e0 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0xb0>
         .          .      1f8b3: add    $0x4,%rdx                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:173
         .          .      1f8b7: movss  (%rsi),%xmm1
         .          .      1f8bb: add    $0x4,%rsi
         .          .      1f8bf: subss  -0x4(%rdx),%xmm1
         .          .      1f8c4: mulss  %xmm1,%xmm1                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:174
         .          .      1f8c8: addss  %xmm1,%xmm0
         .          .      1f8cc: cmp    %rsi,%rcx                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:172
         .          .      1f8cf: ja     1f8b3 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x83>
         .          .      1f8d1: retq                                    ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:177
         .          .      1f8d2: nopw   0x0(%rax,%rax,1)
         .          .      1f8d8: retq
         .          .      1f8d9: nopl   0x0(%rax)
 237552713  237552713      1f8e0: retq
         .          .      1f8e1: pxor   %xmm0,%xmm0                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:152
         .          .      1f8e5: jmp    1f8ae <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x7e>
         .          .      1f8e7: nopw   0x0(%rax,%rax,1)

After (worst_dist >= 0):

ROUTINE ======================== flann::L2::Compute
34754208773 34754208773 (flat, cum) 29.84% of Total
  48641080   48641080      13d20: pxor   %xmm3,%xmm3                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
 333401185  333401185      13d24: lea    (%rdi,%rdx,4),%r8                ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:189
  11949349   11949349      13d28: comiss %xmm3,%xmm0                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
   1843594    1843594      13d2b: lea    -0xc(%r8),%rcx
  26596752   26596752      13d2f: ja     13d75 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x55>
 263415728  263415728      13d31: jmp    13da8 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x88>

                           [...non-taken branch...]

  31882241   31882241      13d99: cmp    %rdi,%r8                         ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:200
                                                                          ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf
         .          .      13d9c: ja     13d80 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x60>
  34863218   34863218      13d9e: movaps %xmm3,%xmm0                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:205
 157914789  157914789      13da1: retq
         .          .      13da2: nopw   0x0(%rax,%rax,1)                 ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf
   7396959    7396959      13da8: mov    %rsi,%rdx                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
   3658216    3658216      13dab: mov    %rdi,%rax
  45037071   45037071      13dae: cmp    %rcx,%rdi
                                                                          ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff
         .          .      13db1: jae    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
 294492579  294492579      13db3: nopl   0x0(%rax,%rax,1)                 ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
2729755072 2729755072      13db8: movups (%rax),%xmm0                     ;_Z10_mm_sub_psDv4_fS_ dist.h:196
1834682910 1834682910      13dbb: movups (%rdx),%xmm5
 645605525  645605525      13dbe: add    $0x10,%rax                       ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
 903030920  903030920      13dc2: add    $0x10,%rdx
2299620094 2299620094      13dc6: subps  %xmm5,%xmm0                      ;_Z10_mm_sub_psDv4_fS_ dist.h:196
2684229688 2684229688      13dc9: mulps  %xmm0,%xmm0                      ;_Z10_mm_mul_psDv4_fS_ dist.h:196
3352147473 3352147473      13dcc: movaps %xmm0,%xmm2                      ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
1097632435 1097632435      13dcf: movaps %xmm0,%xmm1
1530297598 1530297598      13dd2: unpckhps %xmm0,%xmm2
1760337677 1760337677      13dd5: shufps $0xff,%xmm0,%xmm1
3750657802 3750657802      13dd9: addss  %xmm2,%xmm1
 769888830  769888830      13ddd: movaps %xmm0,%xmm2
 956681123  956681123      13de0: shufps $0x55,%xmm0,%xmm2
2762013893 2762013893      13de4: addss  %xmm2,%xmm1
3084482538 3084482538      13de8: addss  %xmm1,%xmm0
1995863847 1995863847      13dec: addss  %xmm0,%xmm3
   3697571    3697571      13df0: cmp    %rcx,%rax
 991839621  991839621      13df3: jb     13db8 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x98>
  21978923   21978923      13df5: mov    %r8,%rax
  27542392   27542392      13df8: sub    %rdi,%rax
  36804713   36804713      13dfb: sub    $0xd,%rax
   5487291    5487291      13dff: and    $0xfffffffffffffff0,%rax
  32055141   32055141      13e03: add    $0x10,%rax
  32198989   32198989      13e07: add    %rax,%rdi
 176303775  176303775      13e0a: add    %rax,%rsi
   8280171    8280171      13e0d: jmp    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>

After (worst_dist < 0):

ROUTINE ======================== flann::L2::Compute
25758491778 25758491778 (flat, cum) 24.17% of Total
 240673135  240673135      13d20: pxor   %xmm3,%xmm3                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
 137020253  137020253      13d24: lea    (%rdi,%rdx,4),%r8                ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:189
  11939237   11939237      13d28: comiss %xmm3,%xmm0                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
  58564715   58564715      13d2b: lea    -0xc(%r8),%rcx
 111559377  111559377      13d2f: ja     13d75 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x55>
 118986257  118986257      13d31: jmp    13da8 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x88>

                           [...non-taken branch...]

 70503557   70503557      13da8: cmp    %rcx,%rdi                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
                                                                          ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13dab: jae    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
    876666     876666      13dad: mov    %rsi,%rdx                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
 124448567  124448567      13db0: mov    %rdi,%rax
 136135059  136135059      13db3: pxor   %xmm1,%xmm1
  79507393   79507393      13db7: nopw   0x0(%rax,%rax,1)
2822080992 2822080992      13dc0: movups (%rax),%xmm0                     ;_Z10_mm_sub_psDv4_fS_ dist.h:196
  31839811   31839811      13dc3: movups (%rdx),%xmm6
 363014116  363014116      13dc6: add    $0x10,%rax                       ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
  22856826   22856826      13dca: add    $0x10,%rdx
1031249812 1031249812      13dce: subps  %xmm6,%xmm0                      ;_Z10_mm_sub_psDv4_fS_ dist.h:196
1682412937 1682412937      13dd1: mulps  %xmm0,%xmm0                      ;_Z10_mm_mul_psDv4_fS_ dist.h:196
17546296134 17546296134      13dd4: addps  %xmm0,%xmm1                      ;_Z10_mm_add_psDv4_fS_ dist.h:196
   1780455    1780455      13dd7: cmp    %rcx,%rax                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
  36550095   36550095      13dda: jb     13dc0 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0xa0>
    891637     891637      13ddc: movaps %xmm1,%xmm3
  24502724   24502724      13ddf: movaps %xmm1,%xmm0
    905376     905376      13de2: mov    %r8,%rax
  12790369   12790369      13de5: shufps $0x55,%xmm1,%xmm0
 342813117  342813117      13de9: addss  %xmm0,%xmm3
    918634     918634      13ded: movaps %xmm1,%xmm0
         .          .      13df0: sub    %rdi,%rax                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13df3: unpckhps %xmm1,%xmm0
   2675031    2675031      13df6: sub    $0xd,%rax                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
         .          .      13dfa: shufps $0xff,%xmm1,%xmm1                ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13dfe: and    $0xfffffffffffffff0,%rax
 449641847  449641847      13e02: addss  %xmm0,%xmm3                      ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
   2732463    2732463      13e06: add    $0x10,%rax
         .          .      13e0a: add    %rax,%rdi                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13e0d: add    %rax,%rsi
 290497884  290497884      13e10: addss  %xmm1,%xmm3                      ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
                                                                          ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:200
   1827302    1827302      13e14: jmp    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
         .          .      13e16: nopw   %cs:0x0(%rax,%rax,1)             ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf.constprop.0