PaddlePaddle / Paddle-Lite

PaddlePaddle High Performance Deep Learning Inference Engine for Mobile and Edge (飞桨高性能深度学习端侧推理引擎)
https://www.paddlepaddle.org.cn/lite
Apache License 2.0
6.96k stars 1.61k forks source link

gemm_int8_compute_test 测试出现段错误,段错误位置在汇编指令部分 #10355

Open xuwenlong02 opened 1 year ago

xuwenlong02 commented 1 year ago
void packb_int8(int8_t* out,
                const int8_t* in,
                const int ldin,
                const int k0,
                const int kmax,
                const int n0,
                const int nmax,
                const int8_t* zerobuf) {
                  LOG(INFO) << "packb_int8 " ;
  const int8_t* inptr = in + k0 * ldin + n0;
  const uint8_t mask_buffer[16] = {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  int x_len = nmax - n0;
  int y_len = kmax - k0;
  int kup = ROUNDUP(y_len, KBLOCK_INT8);
  int kcnt = x_len / NBLOCK_INT8_OTH;
  int rem = x_len & (NBLOCK_INT8_OTH - 1);
  int stride_out = NBLOCK_INT8_OTH * kup;

  int8x16_t vzero = vdupq_n_s8(0);
  uint8x16_t vmask = vcltq_u8(vld1q_u8(mask_buffer), vdupq_n_u8(rem));
  LITE_PARALLEL_COMMON_BEGIN(y, tid, y_len, 0, KBLOCK_INT8) {
    const int8_t* ptr0 = inptr + y * ldin;
    const int8_t* ptr1 = ptr0 + ldin;
    const int8_t* ptr2 = ptr1 + ldin;
    const int8_t* ptr3 = ptr2 + ldin;
    if (y + KBLOCK_INT8 > y_len) {
      switch (y + KBLOCK_INT8 - y_len) {
        case 3:
          ptr1 = zerobuf;
        case 2:
          ptr2 = zerobuf;
        case 1:
          ptr3 = zerobuf;
        default:
          break;
      }
    }
    int8_t* outptr_row_col = out + y * NBLOCK_INT8_OTH;
    int k = kcnt;

    LOG(INFO) << "packb_int8 asm " << kcnt ;

// clang-format off
#ifdef __aarch64__
    asm volatile(
    "ld1    {v0.16b},   [%[ptr0]],  #16\n" /* load r0 */
    "ld1    {v1.16b},   [%[ptr1]],  #16\n" /* load r1 */
    "ld1    {v2.16b},   [%[ptr2]],  #16\n" /* load r2 */
    "ld1    {v3.16b},   [%[ptr3]],  #16\n" /* load r3 */
    "cbz    %w[k], 1f\n"                   /* jump to remain */
    "0:\n"                                 /* main loop */
    /* trans 16b */
    "trn1   v4.16b, v0.16b, v1.16b\n" /* get a0,b0, a2,b2, a4,b4, a6,b6,
                                         a8,b8, a10,b10, a12,b12, a14,b14 */
    "trn2   v5.16b, v0.16b, v1.16b\n" /* get a1,b1, a3,b3, a5,b5, a7,b7,
                                         a9,b9, a11,b11, a13,b13, a15,b15 */
    "trn1   v6.16b, v2.16b, v3.16b\n" /* get c0,d0, c2,d2, c4,d4, c6,d6,
                                         c8,d8, c10,d10, c12,d12, c14,d14 */
    "trn2   v7.16b, v2.16b, v3.16b\n" /* get c1,d1, c3,d3, c5,d5, c7,d7,
                                         c9,d9, c11,d11, c13,d13, c15,d15 */
    "ld1    {v0.16b},   [%[ptr0]],  #16\n" /* load r0 */
    "ld1    {v1.16b},   [%[ptr1]],  #16\n" /* load r1 */
    "subs   %w[k], %w[k], #1\n"            /* loop cnt -1 */
    /* trans 8h */
    "trn1   v8.8h, v4.8h, v5.8h\n" /* get a0,b0, a1,b1, a4,b4, a5,b5, a8,b8,
                                      a9,b9, a12,b12, a13,b13 */
    "trn2   v9.8h, v4.8h, v5.8h\n" /* get a2,b2, a3,b3, a6,b6, a7,b7,
                                      a10,b10, a11,b11, a14,b14, a15,b15 */
    "trn1   v10.8h, v6.8h, v7.8h\n" /* get c0,d0, c1,d1, c4,d4, c5,d5,
                                       c8,d8, c9,d9, c12,d12, c13,d13 */
    "trn2   v11.8h, v6.8h, v7.8h\n" /* get c2,d2, c3,d3, c6,d6, c7,d7,
                                       c10,d10, c11,d11, c14,d14, c15,d15 */
    /* trans 4s */
    "ld1    {v2.16b},   [%[ptr2]],  #16\n" /* load r2 */
    "trn1   v4.4s, v8.4s, v9.4s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, a8,b8,
                                      a9,b9, a10,b10, a11,b11 */
    "trn2   v5.4s, v8.4s, v9.4s\n" /* get a4,b4, a5,b5, a6,b6, a7,b7,
                                      a12,b12, a13,b13, a14,b14, a15,b15 */
    "trn1   v6.4s, v10.4s, v11.4s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3,
                                        c8,d8, c9,d9, c10,d10, c11,d11 */
    "trn2   v7.4s, v10.4s, v11.4s\n" /* get c4,d4, c5,d5, c6,d6, c7,d7,
                                        c12,d12, c13,d13, c14,d14, c15,d15
                                        */
    /* trans 2d */
    "ld1    {v3.16b},   [%[ptr3]],  #16\n" /* load r3 */
    "trn1   v8.2d, v4.2d, v6.2d\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, c0,d0,
                                      c1,d1, c2,d2, c3,d3 */
    "trn2   v10.2d, v4.2d, v6.2d\n" /* get a8,b8, a9,b9, a10,b10, a11,b11,
                                       c8,d8, c9,d9, c10,d10, c11,d11 */
    "trn1   v9.2d, v5.2d, v7.2d\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, c4,d4,
                                      c5,d5, c6,d6, c7,d7 */
    "trn2   v11.2d, v5.2d, v7.2d\n" /* get a12,b12, a13,b13, a14,b14,
                                       a15,b15, c12,d12, c13,d13, c14,d14,
                                       c15,d15 */
    "st1    {v8.16b, v9.16b, v10.16b, v11.16b},   [%[ptr_out]], %[stride]\n"
    "bgt    0b\n"          /* jump to main loop */
    "1:\n"                 /* process remain */
    "cbz    %w[rem], 2f\n" /* jump to remain */
    /* bit select */
    "bif    v0.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */
    "bif    v1.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */
    "bif    v2.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */
    "bif    v3.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */
    /* trans 16b */
    "trn1   v4.16b, v0.16b, v1.16b\n" /* get a0,b0, a2,b2, a4,b4, a6,b6,
                                         a8,b8, a10,b10, a12,b12, a14,b14 */
    "trn2   v5.16b, v0.16b, v1.16b\n" /* get a1,b1, a3,b3, a5,b5, a7,b7,
                                         a9,b9, a11,b11, a13,b13, a15,b15 */
    "trn1   v6.16b, v2.16b, v3.16b\n" /* get c0,d0, c2,d2, c4,d4, c6,d6,
                                         c8,d8, c10,d10, c12,d12, c14,d14 */
    "trn2   v7.16b, v2.16b, v3.16b\n" /* get c1,d1, c3,d3, c5,d5, c7,d7,
                                         c9,d9, c11,d11, c13,d13, c15,d15 */
    /* trans 8h */
    "trn1   v8.8h, v4.8h, v5.8h\n" /* get a0,b0, a1,b1, a4,b4, a5,b5, a8,b8,
                                      a9,b9, a12,b12, a13,b13 */
    "trn2   v9.8h, v4.8h, v5.8h\n" /* get a2,b2, a3,b3, a6,b6, a7,b7,
                                      a10,b10, a11,b11, a14,b14, a15,b15 */
    "trn1   v10.8h, v6.8h, v7.8h\n" /* get c0,d0, c1,d1, c4,d4, c5,d5,
                                       c8,d8, c9,d9, c12,d12, c13,d13 */
    "trn2   v11.8h, v6.8h, v7.8h\n" /* get c2,d2, c3,d3, c6,d6, c7,d7,
                                       c10,d10, c11,d11, c14,d14, c15,d15 */
    /* trans 4s */
    "trn1   v4.4s, v8.4s, v9.4s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, a8,b8,
                                      a9,b9, a10,b10, a11,b11 */
    "trn2   v5.4s, v8.4s, v9.4s\n" /* get a4,b4, a5,b5, a6,b6, a7,b7,
                                      a12,b12, a13,b13, a14,b14, a15,b15 */
    "trn1   v6.4s, v10.4s, v11.4s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3,
                                        c8,d8, c9,d9, c10,d10, c11,d11 */
    "trn2   v7.4s, v10.4s, v11.4s\n" /* get c4,d4, c5,d5, c6,d6, c7,d7,
                                        c12,d12, c13,d13, c14,d14, c15,d15
xuwenlong02 commented 1 year ago

运行环境:RK3568,64 位系统,ARCH: aarch64

xuwenlong02 commented 1 year ago

问题解决了 test_gemm_int8

xuwenlong02 commented 1 year ago

lite/tests/math/gemm_int8_compute_test.cc 多个地方缺少了 `#pragma omp parallel for`

这个也不是治本的办法,多跑几次还是会挂

xuwenlong02 commented 1 year ago

问题已找到,参考 PR:https://github.com/PaddlePaddle/Paddle-Lite/pull/10365

xuwenlong02 commented 1 year ago

这个问题好像跟系统内存有关。如果一直运行,跑五六次后就挂了,跑十几次之后,每次都挂。

xiebaiyuan commented 7 months ago

> 这个问题好像跟系统内存有关。如果一直运行,跑五六次后就挂了,跑十几次之后,每次都挂。

关闭 OMP 其实更加合适。