intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance on Intel platform
Apache License 2.0

`torch.linalg.eigh` is significantly slower than expected on Max Series GPU #439

Open ogrisel opened 1 year ago

ogrisel commented 1 year ago

Describe the issue

Similarly to #428, I tried torch.linalg.eigh on a Max Series GPU using the Intel Devcloud and packages from the intel conda channel. The performance on XPU is not much better than on CPU:

>>> import intel_extension_for_pytorch
>>> import torch
>>> intel_extension_for_pytorch.__version__
'2.0.110+xpu'
>>> torch.__version__
'2.0.1a0+cxx11.abi'
>>> X = torch.randn(500, 500)
>>> X_xpu = X.to("xpu")
>>> %time C = X.T @ X
CPU times: user 938 ms, sys: 76.8 ms, total: 1.01 s
Wall time: 115 ms
>>> %time C_xpu = X_xpu.T @ X_xpu
CPU times: user 4.37 ms, sys: 4 µs, total: 4.37 ms
Wall time: 4.21 ms

So GEMM is more than 20x faster on the XPU device than on the CPU host (115 ms vs. 4.21 ms wall time).

However, torch.linalg.eigh is not faster when using the XPU, which is quite unexpected given the speed difference for GEMM.

>>> %time _ = torch.linalg.eigh(C)
CPU times: user 2min 30s, sys: 10.2 s, total: 2min 40s
Wall time: 6.89 s
>>> %time _ = torch.linalg.eigh(C_xpu)
CPU times: user 4min 1s, sys: 14.5 s, total: 4min 15s
Wall time: 5.52 s
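
A single `%time` call can mislead on a GPU backend: if kernels are queued asynchronously, the clock may stop before the device has finished, and the first call may include one-off setup costs. The helper below is a minimal sketch of a more careful measurement, assuming a blocking "wait for the device" callable is available. The names `time_call`, `sync`, `repeats` and `warmup` are hypothetical and do not come from the original report.

```python
import time
from statistics import median


def time_call(fn, sync=None, repeats=5, warmup=1):
    """Return the median wall time (seconds) of fn() over `repeats` runs.

    `sync` is an optional zero-argument callable that blocks until the
    device has finished all queued work (hypothetical hook, see note below).
    """

    def run_once():
        if sync is not None:
            sync()  # drain pending work so it is not billed to this run
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the work launched by fn() to complete
        return time.perf_counter() - start

    for _ in range(warmup):
        run_once()  # discard runs that may include one-off setup costs
    return median(run_once() for _ in range(repeats))
```

In the session above this would presumably be called as `time_call(lambda: torch.linalg.eigh(C_xpu), sync=torch.xpu.synchronize)`. That `torch.xpu.synchronize` is the right blocking call for this version of the extension is an assumption, not something verified in this thread. For the CPU tensor, `sync` can simply be left as `None`.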

More information about the runtime environment of this session:

>>> from pprint import pprint
>>> import dpctl
>>> pprint(dpctl.get_devices())
[<dpctl.SyclDevice [backend_type.opencl, device_type.cpu,  Intel(R) Xeon(R) Platinum 8480+] at 0x1472aac521f0>,
 <dpctl.SyclDevice [backend_type.opencl, device_type.accelerator,  Intel(R) FPGA Emulation Device] at 0x1472a80a9ef0>,
 <dpctl.SyclDevice [backend_type.level_zero, device_type.gpu,  Intel(R) Data Center GPU Max 1100] at 0x1472a80a9df0>]
>>> import joblib
>>> joblib.cpu_count(only_physical_cores=True)
112
>>> import threadpoolctl
>>> pprint(threadpoolctl.threadpool_info())
[{'filepath': '/home/u103854/mambaforge/envs/intel/lib/libmkl_rt.so.2',
  'internal_api': 'mkl',
  'num_threads': 112,
  'prefix': 'libmkl_rt',
  'threading_layer': 'intel',
  'user_api': 'blas',
  'version': '2023.2-Product'},
 {'filepath': '/home/u103854/mambaforge/envs/intel/lib/libiomp5.so',
  'internal_api': 'openmp',
  'num_threads': 112,
  'prefix': 'libiomp',
  'user_api': 'openmp',
  'version': None},
 {'filepath': '/home/u103854/mambaforge/envs/intel/lib/libgomp.so.1.0.0',
  'internal_api': 'openmp',
  'num_threads': 112,
  'prefix': 'libgomp',
  'user_api': 'openmp',
  'version': None}]

Furthermore, all those numbers are extremely slow for such a small dataset.

Here is the output of a similar experiment on my local laptop (Apple M1):

>>> import torch
>>> X = torch.randn(500, 500)
>>> %time C = X.T @ X
CPU times: user 247 µs, sys: 718 µs, total: 965 µs
Wall time: 4.5 ms
>>> %time _ = torch.linalg.eigh(C)
CPU times: user 12.3 ms, sys: 6.88 ms, total: 19.2 ms
Wall time: 20.6 ms
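
To put the reported timings side by side, the sketch below recomputes the ratios from the `%time` outputs quoted in this issue. The dictionary keys (`xeon_cpu`, `max_1100_xpu`, `m1_cpu`) are labels made up for this illustration. The numbers come from single runs, so the ratios are only indicative.

```python
# Wall times in seconds, copied from the %time outputs in this issue.
wall = {
    ("gemm", "xeon_cpu"): 115e-3,
    ("gemm", "max_1100_xpu"): 4.21e-3,
    ("gemm", "m1_cpu"): 4.5e-3,
    ("eigh", "xeon_cpu"): 6.89,
    ("eigh", "max_1100_xpu"): 5.52,
    ("eigh", "m1_cpu"): 20.6e-3,
}

# Total host CPU time (user + sys) in seconds for the two eigh calls:
# "2min 40s" on the CPU tensor and "4min 15s" on the XPU tensor.
cpu_total = {
    ("eigh", "xeon_cpu"): 2 * 60 + 40,
    ("eigh", "max_1100_xpu"): 4 * 60 + 15,
}


def speedup(op, slow, fast):
    """How many times faster `fast` is than `slow` for operation `op`."""
    return wall[(op, slow)] / wall[(op, fast)]


def busy_cores(op, device):
    """Average number of host cores kept busy during the call."""
    return cpu_total[(op, device)] / wall[(op, device)]


print(f"GEMM, XPU vs Xeon CPU: {speedup('gemm', 'xeon_cpu', 'max_1100_xpu'):.1f}x")
print(f"eigh, XPU vs Xeon CPU: {speedup('eigh', 'xeon_cpu', 'max_1100_xpu'):.2f}x")
print(f"eigh, M1 vs Xeon CPU:  {speedup('eigh', 'xeon_cpu', 'm1_cpu'):.0f}x")
print(f"eigh on CPU tensor, busy host cores: {busy_cores('eigh', 'xeon_cpu'):.1f}")
print(f"eigh on XPU tensor, busy host cores: {busy_cores('eigh', 'max_1100_xpu'):.1f}")
```

The last two lines divide total host CPU time by wall time. A value well above 1 for the XPU tensor would suggest that much of the `eigh` call is spent on the host rather than on the device, although these figures alone cannot confirm where the time goes.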

ogrisel commented 1 year ago

Here is the output of mamba list for this env:

```
# packages in environment at /home/u103854/mambaforge/envs/intel:
#
# Name                        Version    Build              Channel
_libgcc_mutex                 0.1        conda_forge        intel
_openmp_mutex                 4.5        2_gnu              intel
asttokens                     2.2.1      pyhd8ed1ab_0       conda-forge
backcall                      0.2.0      pyh9f0ad1d_0       conda-forge
backports                     1.0        pyhd8ed1ab_3       conda-forge
backports.functools_lru_cache 1.6.5      pyhd8ed1ab_0       conda-forge
brotli                        1.0.9      h166bdaf_8         intel
brotli-bin                    1.0.9      h166bdaf_8         intel
bzip2                         1.0.8      hb9a14ef_9         intel
ca-certificates               2023.7.22  hbcca054_0         conda-forge
certifi                       2023.7.22  pyhd8ed1ab_0       conda-forge
charset-normalizer            3.1.0      pyhd8ed1ab_0       intel
daal4py                       2023.2.1   py310_intel_32     intel
dal                           2023.2.1   intel_32           intel
decorator                     5.0.9      pyhd3eb1b0_0       intel
dpcpp-cpp-rt                  2023.2.0   intel_49495        intel
dpcpp_cpp_rt                  2023.2.0   intel_49495        intel
dpctl                         0.14.5     py310he78b74f_24   intel
dpnp                          0.12.1     pypi_0             pypi
executing                     1.2.0      pyhd8ed1ab_0       conda-forge
filelock                      3.6.0      pyhd3eb1b0_0       intel
fortran_rt                    2023.2.0   intel_49495        intel
icc_rt                        2023.2.0   intel_49495        intel
idna                          3.4        pyhd8ed1ab_0       intel
impi_rt                       2021.10.0  intel_49371        intel
intel-cmplr-lib-rt            2023.2.0   intel_49495        intel
intel-cmplr-lic-rt            2023.2.0   intel_49495        intel
intel-extension-for-pytorch   2.0.110    py310_xpu_0        intel
intel-fortran-rt              2023.2.0   intel_49495        intel
intel-opencl-rt               2023.2.0   intel_49495        intel
intel-openmp                  2023.2.0   intel_49495        intel
intelpython                   2023.2.0   0                  intel
ipython                       8.14.0     pyh41d4057_0       conda-forge
jedi                          0.19.0     pyhd8ed1ab_0       conda-forge
jinja2                        3.0.1      pyhd3eb1b0_0       intel
joblib                        1.2.0      pyh3f38642_0       intel
lark-parser                   0.9.0      pyh9f0ad1d_0       intel
level-zero                    1.11.0     h00ab1b0_0         intel
libbrotlicommon               1.0.9      h166bdaf_8         intel
libbrotlidec                  1.0.9      h166bdaf_8         intel
libbrotlienc                  1.0.9      h166bdaf_8         intel
libffi                        3.4.2      h7f98852_5         intel
libgcc-ng                     12.2.0     h65d4601_19        intel
libgomp                       12.2.0     h65d4601_19        intel
libnsl                        2.0.0      h7f98852_0         intel
libsqlite                     3.42.0     h2797004_0         intel
libstdcxx-ng                  12.2.0     h46fd767_19        intel
libuuid                       2.38.1     h0b41bf4_0         intel
libuv                         1.40.0     h7b6447c_2         intel
libzlib                       1.2.13     hd590300_5         intel
markupsafe                    2.1.3      py310h2372a71_0    conda-forge
matplotlib-inline             0.1.6      pyhd8ed1ab_0       conda-forge
mkl                           2023.2.0   intel_49495        intel
mkl-dpcpp                     2023.2.0   intel_49495        intel
mkl-service                   2.4.0      py310hae59892_35   intel
mkl_fft                       1.3.6      py310h173b8ae_56   intel
mkl_random                    1.2.2      py310h1595b48_76   intel
mkl_umath                     0.1.1      py310hd987cd3_86   intel
mpi4py                        3.1.4      py310h618b5fa_0    intel
mpmath                        1.3.0      pyhd8ed1ab_0       conda-forge
ncurses                       6.4        hcb278e6_0         intel
networkx                      2.6.2      pyhd3eb1b0_2       intel
numpy                         1.24.3     py310hed7eef7_0    intel
numpy-base                    1.24.3     py310he88ecf9_0    intel
openssl                       3.1.3      hd590300_0         conda-forge
packaging                     23.1       pyhd8ed1ab_0       intel
parso                         0.8.3      pyhd8ed1ab_0       conda-forge
pexpect                       4.8.0      pyh1a96a4e_2       conda-forge
pickleshare                   0.7.5      py_1003            conda-forge
pip                           23.1.2     pyhd8ed1ab_0       intel
platformdirs                  3.6.0      pyhd8ed1ab_0       intel
pooch                         1.7.0      pyha770c72_3       intel
prompt-toolkit                3.0.39     pyha770c72_0       conda-forge
prompt_toolkit                3.0.39     hd8ed1ab_0         conda-forge
psutil                        5.9.5      py310h1fa729e_0    conda-forge
ptyprocess                    0.7.0      pyhd3deb0d_0       conda-forge
pure_eval                     0.2.2      pyhd8ed1ab_0       conda-forge
pygments                      2.16.1     pyhd8ed1ab_0       conda-forge
pysocks                       1.7.1      pyha2e5f31_6       intel
python                        3.10.12    hef7c979_1         intel
python_abi                    3.10       2_cp310            intel
pytorch                       2.0.1      py310_xpu_0        intel
readline                      8.2        h8228510_1         intel
requests                      2.31.0     pyhd8ed1ab_0       intel
scikit-learn                  1.2.2      py310hf7d194e_2    intel
scikit-learn-intelex          2023.2.1   py310_intel_32     intel
scipy                         1.10.1     py310h01e2e1b_0    intel
setuptools                    67.7.2     pyhd8ed1ab_0       intel
six                           1.16.0     pyhd3eb1b0_1       intel
stack_data                    0.6.2      pyhd8ed1ab_0       conda-forge
sympy                         1.12       pyh04b8f61_3       conda-forge
tbb                           2021.10.0  intel_49541        intel
tbb4py                        2021.10.0  py310_intel_49541  intel
threadpoolctl                 3.1.0      pyh8a188c0_0       intel
tk                            8.6.12     h1ccaba5_0         intel
traitlets                     5.9.0      pyhd8ed1ab_0       conda-forge
typing-extensions             4.6.3      hd8ed1ab_0         intel
typing_extensions             4.6.3      pyha770c72_0       intel
tzdata                        2023c      h71feb2d_0         intel
urllib3                       2.0.3      pyhd8ed1ab_0       intel
wcwidth                       0.2.6      pyhd8ed1ab_0       conda-forge
wheel                         0.40.0     pyhd8ed1ab_0       intel
xz                            5.2.8      h5eee18b_0         intel
zlib                          1.2.13     hd590300_5         intel
```

kta-intel commented 1 year ago

Thank you for reporting this, we are investigating the issue

jingxu10 commented 1 year ago

@gujinghui @tye1