oneapi-src / oneDNN

oneAPI Deep Neural Network Library (oneDNN)
https://uxlfoundation.org
Apache License 2.0

Bad perf for matmul of ba tensors #1667

Open · WilliamTambellini opened 1 year ago

WilliamTambellini commented 1 year ago

Hello oneDNN team, just asking if it is really expected for a ba matmul to be this slow, e.g.:

M=63448
K=640
N=2
tag  time (ms)
ab        18
ba      1790

With benchdnn :

wtambellini@lawtambe3 onednn-3.0/bin (master) $ ONEDNN_VERBOSE=1 OMP_NUM_THREADS=1 ./benchdnn --mode=P --matmul --dt=f32 --stag=ba --wtag=ba --dtag=ba 63448x640:640x2 
onednn_verbose,info,oneDNN v3.0.0 (commit 030eae4fe332eee75f10e05da4e8d7981c1a94b8)
onednn_verbose,info,cpu,runtime:OpenMP,nthr:1
onednn_verbose,info,cpu,isa:Intel AVX2
onednn_verbose,info,gpu,runtime:none
onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
...

Best WT
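
For context, a minimal sketch of this case through the oneDNN v3.x C++ API (a reconstruction of the benchdnn problem above, not code from the report; the variable names are illustrative):

#include <dnnl.hpp>
#include <iostream>
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim M = 63448, K = 640, N = 2;

    // format_tag::ba: the 2D tensor is stored transposed (column-major).
    auto src_md = memory::desc({M, K}, memory::data_type::f32, memory::format_tag::ba);
    auto wei_md = memory::desc({K, N}, memory::data_type::f32, memory::format_tag::ba);
    auto dst_md = memory::desc({M, N}, memory::data_type::f32, memory::format_tag::ba);

    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md);
    // Prints the selected implementation; "ref:any" would match the slow
    // reference kernel seen in the verbose log below.
    std::cout << pd.impl_info_str() << "\n";

    memory src_m(src_md, eng), wei_m(wei_md, eng), dst_m(dst_md, eng);
    matmul(pd).execute(strm, {{DNNL_ARG_SRC, src_m},
                              {DNNL_ARG_WEIGHTS, wei_m},
                              {DNNL_ARG_DST, dst_m}});
    strm.wait();
    return 0;
}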

dzarukin commented 1 year ago

Hi @WilliamTambellini, the benchdnn command is ill-formed, which is why it gives a warning towards the end of the execution. To get the desired problem, the descriptor should be put at the very end of the line, not before the desired tags. Thanks.

WilliamTambellini commented 1 year ago

Thanks @dzarukin. I've fixed the benchdnn command line but can still confirm that the speed of the ba matmul is apparently bad:

$ ONEDNN_VERBOSE=1 OMP_NUM_THREADS=1 ./benchdnn --mode=P --matmul --dt=f32 --stag=ba --wtag=ba --dtag=ba 63448x640:640x2
onednn_verbose,info,oneDNN v3.0.0 (commit 030eae4fe332eee75f10e05da4e8d7981c1a94b8)
onednn_verbose,info,cpu,runtime:OpenMP,nthr:1
onednn_verbose,info,cpu,isa:Intel AVX2
onednn_verbose,info,gpu,runtime:none
onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:ab:f0 dst_f32::blocked:ba:f0,,,63448x640,268.588
onednn_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:ab:f0 dst_f32::blocked:ba:f0,,,640x2,0.00195312
onednn_verbose,exec,cpu,matmul,ref:any,undef,src_f32::blocked:ba:f0 wei_f32::blocked:ba:f0 dst_f32::blocked:ba:f0,,,63448x640:640x2:63448x2,1626.54
onednn_verbose,exec,cpu,matmul,ref:any,undef,src_f32::blocked:ba:f0 wei_f32::blocked:ba:f0 dst_f32::blocked:ba:f0,,,63448x640:640x2:63448x2,1747.15

About 1700 ms for the ba matmul vs 16 ms for the ab matmul?

dzarukin commented 1 year ago

That's caused by the ba format on the destination, which makes the primitive fall back to the reference implementation. The optimized version doesn't support it; I suggest refraining from using it.
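
For anyone else hitting this, one possible workaround (a sketch reusing eng, strm, and the src/wei descriptors from the snippet above; not an officially documented recipe): compute into an ab destination so the optimized kernels stay eligible, then reorder to ba only if that layout is really required downstream. Reordering the small 63448x2 result should be cheap next to a ~1700 ms reference matmul.

// Workaround sketch: ab destination for the matmul, then an ab -> ba reorder.
auto dst_ab_md = memory::desc({M, N}, memory::data_type::f32, memory::format_tag::ab);
auto dst_ba_md = memory::desc({M, N}, memory::data_type::f32, memory::format_tag::ba);

auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_ab_md);
memory dst_ab_m(dst_ab_md, eng), dst_ba_m(dst_ba_md, eng);

matmul(pd).execute(strm, {{DNNL_ARG_SRC, src_m},
                          {DNNL_ARG_WEIGHTS, wei_m},
                          {DNNL_ARG_DST, dst_ab_m}});

// Convert the small 63448x2 result into the required ba layout.
reorder(dst_ab_m, dst_ba_m).execute(strm, dst_ab_m, dst_ba_m);
strm.wait();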

WilliamTambellini commented 1 year ago

Thanks @dzarukin. New results (on an Intel(R) Xeon(R) Platinum 8259CL):

src  wei  dst  time (ms)
ab   ab   ab         20
ba   ba   ba       2400
ba   ba   ab         44
ba   ab   ab         45
ab   ba   ab         22
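
Presumably each row corresponds to a benchdnn run varying only the tags; for instance, the 44 ms row would be reproduced by something like:

OMP_NUM_THREADS=1 ./benchdnn --mode=P --matmul --dt=f32 --stag=ba --wtag=ba --dtag=ab 63448x640:640x2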

A warning that the performance severely depends on the format tags would be appreciated, for instance here: https://oneapi-src.github.io/oneDNN/v3.0/dev_guide_matmul.html

CU

AngryLoki commented 11 months ago

Hello,

Pardon my ignorance, but isn't Aᵀ×Bᵀ == (B×A)ᵀ? If so, the matrices in a ba*ba->ba matmul can be reinterpreted as ab*ab->ab with the arguments swapped (which does not suffer from the fallback to the reference implementation)?
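
If that identity applies here, the all-ba problem above should be expressible as its transpose with plain ab tags and the operands swapped (Cᵀ = Bᵀ×Aᵀ), e.g. using only the flags already shown in this thread (an untested reading, not a confirmed equivalence):

OMP_NUM_THREADS=1 ./benchdnn --mode=P --matmul --dt=f32 --stag=ab --wtag=ab --dtag=ab 2x640:640x63448

Here the 2x640 src is the original ba weights buffer reinterpreted as ab, the 640x63448 weights are the reinterpreted original src, and the resulting 2x63448 ab destination holds the same bytes as the original 63448x2 ba destination.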