bytedeco / javacpp-presets

The missing Java distribution of native C++ libraries

[PyTorch] Training is very slow on Linux. #1504

Closed: haifengl closed this issue 2 months ago

haifengl commented 6 months ago

Training MNIST for 10 epochs (the sample code from your project README) takes more than 500 seconds on Linux (24 cores, Ubuntu 22.04). It takes only about 50 seconds on an old Mac (4 cores). Both run on the CPU (no GPU or MPS).

saudet commented 6 months ago

Try reducing the number of threads used by PyTorch to 6 or 12; see https://stackoverflow.com/questions/76084214/what-is-recommended-number-of-threads-for-pytorch-based-on-available-cpu-cores

HGuillemet commented 6 months ago

It's most probably related to PyTorch not finding OpenBLAS and/or MKL in your path. Have you added mkl-platform-redist to your dependencies? You can also try downloading the official libtorch, adding the directory containing its libraries to your library path, and setting -Dorg.bytedeco.javacpp.pathsFirst: the official binaries are statically linked with MKL.

haifengl commented 6 months ago

Setting OMP_NUM_THREADS=12 on Linux helps a lot: the training speed is then on par with the Mac (4 threads). Without it, torch.get_num_threads() returns 48, so the slowness may be caused by hyper-threading. According to your link, PyTorch sets the number of threads to half the number of vCores. If so, we shouldn't have this issue on Linux, but that is not what happens with the JavaCPP build. Are we missing some build configuration for Linux? Thanks!
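A minimal sketch of the "half of the vCores" rule, for choosing a thread count explicitly instead of relying on the runtime default. It assumes 2-way hyper-threading, because `Runtime.availableProcessors()` reports logical cores and Java has no portable way to query physical cores. The class `ThreadCount` and method `defaultThreads` are hypothetical. An explicit `OMP_NUM_THREADS` value takes precedence, as it did in this issue.

```java
final class ThreadCount {
    // Hypothetical helper: pick an intra-op thread count.
    // ompEnv is the value of OMP_NUM_THREADS (may be null);
    // logicalCores is what availableProcessors() reports (hyper-threads included).
    static int defaultThreads(String ompEnv, int logicalCores) {
        if (ompEnv != null) {
            try {
                int n = Integer.parseInt(ompEnv.trim());
                if (n > 0) {
                    return n; // an explicit, valid setting wins
                }
            } catch (NumberFormatException ignored) {
                // malformed value: fall through to the heuristic
            }
        }
        // Assumption: 2 hardware threads per physical core.
        return Math.max(1, logicalCores / 2);
    }

    public static void main(String[] args) {
        int n = defaultThreads(System.getenv("OMP_NUM_THREADS"),
                               Runtime.getRuntime().availableProcessors());
        System.out.println("threads = " + n);
        // Assumption: with the JavaCPP presets, n could then be passed to
        // org.bytedeco.pytorch.global.torch.set_num_threads(n) before training.
    }
}
```

On the 24-core, 48-thread machine above, with OMP_NUM_THREADS unset, this heuristic yields 24 instead of 48.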

saudet commented 6 months ago

So the default should be 24 on that machine, but that doesn't mean it's going to give good results.

haifengl commented 6 months ago

The default is 48 with the JavaCPP build, which is too high. It should be 24 in this case.

HGuillemet commented 6 months ago

Have you tried with the official libtorch ?

haifengl commented 6 months ago

The official libtorch sets it to 24 by default on my box, and it works well. Why does JavaCPP build libtorch from source? Why not package the precompiled libtorch library from pytorch.org?

HGuillemet commented 6 months ago

See discussion here

HGuillemet commented 5 months ago

Here are the results of running the sample MNIST code on a machine with 32 vCores and 16 physical cores:

| OpenMP library | Default number of threads | Speed |
| --- | --- | --- |
| omp | 32 | Very slow |
| gomp | 32 | Somewhat slow |
| mkl static (official build) | 16 | Fast |

When forcing the number of threads to 16 using OMP_NUM_THREADS or torch.set_num_threads, it's fast in all cases. I'll try to rationalize this in the PR so that torch is linked with gomp on Linux. Also, the fact that the presets preload every OpenMP library they find, possibly leading to several different libraries being loaded at once, surely doesn't help.
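To check whether more than one OpenMP runtime actually ended up in the JVM process, one option on Linux is to scan `/proc/self/maps` for the usual runtime library names. This is a hedged sketch: the name patterns (`libgomp`, `libiomp5`, `libomp`) are assumptions about which runtimes might be present, and the class `OpenMpCheck` is hypothetical. It reports full paths, so two copies of the same runtime from different directories also show up.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

final class OpenMpCheck {
    // Assumed file-name prefixes of OpenMP runtimes (GNU, Intel, LLVM).
    private static final String[] OMP_NAMES = {"libgomp", "libiomp5", "libomp"};

    // Returns the distinct paths of OpenMP runtimes found in /proc/<pid>/maps lines.
    static Set<String> openMpLibraries(List<String> mapsLines) {
        Set<String> found = new TreeSet<>();
        for (String line : mapsLines) {
            int start = line.indexOf('/');
            if (start < 0) {
                continue; // anonymous or special mapping such as [heap]
            }
            String path = line.substring(start).trim();
            String file = path.substring(path.lastIndexOf('/') + 1);
            for (String name : OMP_NAMES) {
                // "libomp." / "libomp-" match the runtime but not e.g. libomptarget.so
                if (file.startsWith(name + ".") || file.startsWith(name + "-")) {
                    found.add(path); // a library maps several segments; the set dedups them
                    break;
                }
            }
        }
        return found;
    }

    public static void main(String[] args) throws IOException {
        // Linux only: every file mapped into this process is listed in /proc/self/maps.
        Set<String> libs = openMpLibraries(Files.readAllLines(Path.of("/proc/self/maps")));
        System.out.println(libs.size() + " OpenMP runtime(s) mapped: " + libs);
        if (libs.size() > 1) {
            System.out.println("Warning: more than one OpenMP runtime is loaded.");
        }
    }
}
```

Presumably the check is only meaningful after the PyTorch presets have been loaded (for example, after the first call into torch), since that is when the preloading happens.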