intel / xFasterTransformer

Apache License 2.0
[bug] Segmentation fault occurs at large batch sizes #140

Open aurora327 opened 9 months ago

aurora327 commented 9 months ago

Segmentation fault occurs at large batch sizes

  1. Command Line: ./run_benchmark.sh -m llama-7b -d bf16 -s 1 -bs 100 -in 512 -out 256 -i 1

    Functions with errors: onednn_amx_sgemm_f32bf16f32_compute_biasadd

    Matmul matrix shape: M = 51200, N = 12288, K = 4096, transA = 0, alpha = 1.000000, lda = 4096, beta = 0.000000, ldc = 12288

    oneDNN verbose:
    onednn_verbose,info,oneDNN v3.2.0 (commit 04b180b9a58a78cf1a1cd2329671a5060c2be8de)
    onednn_verbose,info,cpu,runtime:OpenMP,nthr:48
    onednn_verbose,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and Intel AMX with bfloat16 and 8-bit integer support
    onednn_verbose,info,gpu,runtime:none
    onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time

  2. Command Line: ./run_benchmark.sh -m llama-7b -d bf16 -s 1 -bs 100 -in 32 -out 32 -i 1

    Functions with errors:
    hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &residential, bool isMaster) {
        TimeLine t("DownProj")
        assert(input.Rows() == output.Rows());
    (ASSERT FAILED: input.Cols()=22016, downWeight.Rows()=11008)

    Matmul matrix shape: M = 3200, N = 12288, K = 4096, transA = 0, alpha = 1.000000, lda = 4096, beta = 0.000000, ldc = 12288

    Verbose:
    xft_verbose,exec,cpu,api,onednn_amx_sgemm_f32bf16f32_compute_biasadd,m3200n12288k4096,29.308059
    xft_verbose,exec,cpu,api,onednn_amx_sgemm_f32bf16f32_compute_residential,m3200n4096k4096,12.953664
    xft_verbose,exec,cpu,api,onednn_amx_sgemm_f32bf16f32_compute,m3200n22016k4096,42.813326
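
As a sanity check on the two shapes above, here is a small Python sketch of the arithmetic. It assumes M is simply batch size times input length (100 * 512 = 51200 and 100 * 32 = 3200, which matches both reports) and that the output is f32, as the function name suggests. `gemm_sizes` is a hypothetical helper, not part of xFasterTransformer, and the comparison against the signed 32-bit limit is only an observation about the sizes; nothing in this report establishes it as the cause of the crash.

```python
# Size arithmetic for the two reported matmul shapes.
# Assumption: M = batch_size * input_len, and output elements are 4-byte f32.
INT32_MAX = 2**31 - 1


def gemm_sizes(batch_size, input_len, n, k, bytes_per_elem=4):
    """Return element and byte counts for an (M x K) * (K x N) matmul."""
    m = batch_size * input_len
    out_bytes = m * n * bytes_per_elem
    return {
        "M": m,
        "in_elems": m * k,
        "out_elems": m * n,
        "out_bytes": out_bytes,
        "out_bytes_exceeds_int32": out_bytes > INT32_MAX,
    }


case1 = gemm_sizes(100, 512, n=12288, k=4096)  # -bs 100 -in 512
case2 = gemm_sizes(100, 32, n=12288, k=4096)   # -bs 100 -in 32

for name, sizes in (("case 1", case1), ("case 2", case2)):
    print(name, sizes)
```

Under these assumptions the case 1 output buffer is 2,516,582,400 bytes, which is larger than the signed 32-bit maximum, while the case 2 buffer is well below it. Separately, the failed assertion in case 2 reports input.Cols()=22016, which is exactly twice downWeight.Rows()=11008.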

aurora327 commented 9 months ago

For case 1:
./benchdnn --matmul --dt=bf16:bf16:f32 --stag=ab --wtag=ab --dtag=ab --bia_dt=f32 51200x12288:12288x4096
0:PASSED __REPRO: --matmul --dt=bf16:bf16:f32 --stag=ab --wtag=ab --dtag=ab --bia_dt=f32 51200x12288:12288x4096
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 3.15s; fill: 1.34s (43%); compute_ref: 0.70s (22%); compare: 0.22s (7%);
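
One thing that may be worth double-checking in the repro: assuming benchdnn reads a matmul problem descriptor as MxK:KxN, the descriptor used above describes K = 12288 and N = 4096, whereas the failing call reported earlier has N = 12288 and K = 4096. The Python sketch below only parses the descriptor string and compares it with the reported shape; `parse_matmul_desc` is a hypothetical helper written for illustration, not a benchdnn API.

```python
def parse_matmul_desc(desc):
    """Parse an 'MxK:KxN' descriptor into (M, K, N).

    Assumption: the first operand is M x K and the second is K x N.
    """
    src, wei = desc.split(":")
    m, k_src = (int(v) for v in src.split("x"))
    k_wei, n = (int(v) for v in wei.split("x"))
    if k_src != k_wei:
        raise ValueError("inner dimensions differ: %d vs %d" % (k_src, k_wei))
    return m, k_src, n


# Descriptor used in the benchdnn run above.
repro = parse_matmul_desc("51200x12288:12288x4096")

# Shape reported for the failing call, as (M, K, N).
reported = (51200, 4096, 12288)

# Descriptor that would correspond to the reported shape under the same assumption.
matching_desc = "{0}x{1}:{1}x{2}".format(*reported)

print("repro (M, K, N):   ", repro)
print("reported (M, K, N):", reported)
print("matching descriptor:", matching_desc)
```

If that reading of the descriptor is right, the passing benchdnn run exercised a different K and N than the failing call, so a run with the matching descriptor might be a closer reproduction.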