Open YouJiacheng opened 2 years ago
Ultimately JAX is at the mercy of the algorithms provided by cusolver here. For small matrices (smaller than 32x32), JAX currently uses the batched Jacobi solver that Nvidia provides. For larger matrices, JAX currently iterates over the batch elements sequentially, so you should expect no speedup from vmap.
There are a number of things one could try here.
One would be to try the batched Jacobi solver at larger sizes (https://github.com/google/jax/blob/28842151c6030a951bc389b771a3dcd3d4ca74a7/jaxlib/cusolver.py#L551 and its HLO-only cousin a few lines above); see also https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-syevjbatch. Note that this code is in jaxlib, but it's in the Python part of jaxlib, so you can just edit your local copy to play with it.
Another would be for jaxlib to solve multiple eigendecomposition problems in parallel on multiple CUDA streams. That would only be profitable if you aren't already fully occupying the GPU and CPU.
Thanks for the speedy reply! IIUC, I can change the jaxlib Python code without building jaxlib myself, and make jaxlib use the batched Jacobi solver for large matrices as well.
Yes, you could just edit the (installed) copy of cusolver.py to alter the threshold. Does it help?
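One way to find the file to edit, as a minimal sketch: the helper name `installed_source_path` is hypothetical, and the module name `jaxlib.cusolver` is taken from the link above, so it may not match the layout of other jaxlib releases. The helper only uses the standard library and returns `None` when the module cannot be located.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def installed_source_path(module_name: str) -> Optional[Path]:
    """Return the file an installed module would be loaded from, or None.

    Hypothetical helper for illustration; it does not modify anything.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # Raised when a parent package (e.g. jaxlib) is not installed.
        return None
    if spec is None or spec.origin is None:
        return None
    if spec.origin in ("built-in", "frozen"):
        # No editable source file on disk for these modules.
        return None
    return Path(spec.origin)


if __name__ == "__main__":
    # Assumption: the solver wrapper lives at jaxlib/cusolver.py, as in the
    # commit linked above. Other versions may place it elsewhere.
    path = installed_source_path("jaxlib.cusolver")
    print(path if path is not None else "jaxlib.cusolver not found in this environment")
```

Editing the printed file changes the behaviour of that environment only, and reinstalling jaxlib overwrites the change.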
It helps! 0.90s -> 0.55s. Thank you so much! (But it is still much slower than I expected.)
I also wonder why the CPU version of jax.lax.linalg.eigh + vmap doesn't show a linear speedup over single-core scipy, given that it has >11x the peak CPU usage.
You could send a PR altering the threshold, if you like, although we'd probably need to collect a wider range of timings at different sizes and batch sizes.
The CPU version also just calls a LAPACK function in a loop to handle batches. In fact, it's a LAPACK function we use provided by scipy, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally at least for some of the phases. If we aren't getting enough parallelism, we could consider using multiple threads.
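The "multiple threads" idea can be sketched from user code as below. This is an assumption-laden illustration, not how JAX or jaxlib implements anything: `threaded_eigh` is a hypothetical name, and `numpy.linalg.eigh` stands in for the per-matrix LAPACK call. Whether it actually runs faster depends on the GIL being released during the LAPACK call and on how many threads the underlying BLAS/LAPACK library already uses.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def threaded_eigh(batch, max_workers=4):
    """Solve each symmetric eigenproblem in `batch` ([B, N, N]) on a worker thread.

    Hypothetical helper: one independent eigh call per batch element.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Iterating a [B, N, N] array yields B matrices of shape [N, N].
        results = list(pool.map(np.linalg.eigh, batch))
    eigenvalues = np.stack([w for w, _ in results])
    eigenvectors = np.stack([v for _, v in results])
    return eigenvalues, eigenvectors


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((16, 64, 64))
    sym = (x + np.swapaxes(x, 1, 2)) / 2  # make every matrix symmetric
    w, v = threaded_eigh(sym, max_workers=4)
    print(w.shape, v.shape)
```

For this to help, the per-call thread count of the BLAS/LAPACK backend usually has to be limited as well, otherwise the worker threads compete with the library's own threads.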
We don't have a batched eigh on CPU (as far as I am aware, no-one does on CPU, although some of the algorithms that work well when vectorized on GPU and TPU might work well on CPU also particularly for small matrix sizes, e.g., a vectorized Jacobi solver).
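To make "a vectorized Jacobi solver" concrete, here is a small NumPy sketch of a cyclic Jacobi eigensolver in which every rotation is applied to all matrices of the batch at once. It is an illustration under stated assumptions, not the cuSOLVER algorithm and not anything in JAX: the name `batched_jacobi_eigh` is hypothetical, the sweep count is fixed instead of testing for convergence, and the pivot order is the plain sequential one.

```python
import numpy as np


def batched_jacobi_eigh(a, sweeps=12):
    """Cyclic Jacobi eigensolver applied to a batch of symmetric matrices.

    a: array of shape [B, N, N].
    Returns (eigenvalues [B, N] in ascending order, eigenvectors [B, N, N]).
    """
    a = np.array(a, dtype=np.float64)  # working copy, driven towards diagonal form
    _, n, _ = a.shape
    v = np.broadcast_to(np.eye(n), a.shape).copy()  # accumulated rotations
    for _ in range(sweeps):  # fixed number of sweeps, no convergence test
        for p in range(n - 1):
            for q in range(p + 1, n):
                # Angle that zeroes a[:, p, q], chosen so that |theta| <= pi/4.
                d = a[:, q, q] - a[:, p, p]
                sign = np.where(d >= 0, 1.0, -1.0)
                theta = 0.5 * np.arctan2(2.0 * a[:, p, q] * sign, np.abs(d))
                c = np.cos(theta)[:, None]
                s = np.sin(theta)[:, None]
                # A <- A J: update columns p and q of every matrix.
                col_p, col_q = a[:, :, p].copy(), a[:, :, q].copy()
                a[:, :, p] = c * col_p - s * col_q
                a[:, :, q] = s * col_p + c * col_q
                # A <- J^T A: update rows p and q of every matrix.
                row_p, row_q = a[:, p, :].copy(), a[:, q, :].copy()
                a[:, p, :] = c * row_p - s * row_q
                a[:, q, :] = s * row_p + c * row_q
                # V <- V J: accumulate the eigenvectors.
                v_p, v_q = v[:, :, p].copy(), v[:, :, q].copy()
                v[:, :, p] = c * v_p - s * v_q
                v[:, :, q] = s * v_p + c * v_q
    w = np.diagonal(a, axis1=1, axis2=2)
    order = np.argsort(w, axis=1)
    w = np.take_along_axis(w, order, axis=1)
    v = np.take_along_axis(v, order[:, None, :], axis=2)
    return w, v


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 6, 6))
    sym = (x + np.swapaxes(x, 1, 2)) / 2
    w, _ = batched_jacobi_eigh(sym)
    print("max eigenvalue error:", np.abs(w - np.linalg.eigh(sym)[0]).max())
```

The Python loops run over pivot pairs only, so the cost per rotation is shared across the batch; that is the property that makes this style attractive for many small matrices.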
> In fact, it's a LAPACK function we use provided by scipy, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally at least for some of the phases.
JAX uses multiple cores, while scipy uses only a single core. But JAX with multiple cores gets only a small speedup, at the cost of preventing the user from manually using SPMD/data parallelism.
Can we have PyTorch-like set_num_threads and set_num_interop_threads to control the parallelism?
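import torch
# One intra-op thread per eigh call; up to 24 inter-op threads run calls concurrently.
torch.set_num_threads(1)
torch.set_num_interop_threads(24)
@torch.jit.script
def mt_eigh(x: torch.Tensor):
    # Fork one asynchronous eigh task per slice along the leading axis, then wait for all of them.
    futs = [torch.jit._fork(torch.linalg.eigh, x[i]) for i in range(24)]
    return [torch.jit._wait(fut) for fut in futs]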
I find that this (7.5s for 24*1024*320*320) is 50x faster than JAX on a 24-core CPU (15.6s for 1024*320*320) and 40x faster than naively letting PyTorch use intra-op parallelism with 24 threads (12.4s for 1024*320*320). The latter is actually 1.8x slower than a single thread (6.9s for 1024*320*320) and 2.3x slower than 4 threads (5.4s for 1024*320*320).
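The ratios quoted above follow from normalizing each timing by the amount of work done. A short check, where work is counted in units of one 1024*320*320 batch and the helper name `throughput_ratio` is hypothetical:

```python
def throughput_ratio(work_a, seconds_a, work_b, seconds_b):
    """How many times more work per second run A completes than run B."""
    return (work_a / seconds_a) / (work_b / seconds_b)


if __name__ == "__main__":
    # Timings are the ones reported in the comment above.
    print(round(throughput_ratio(24, 7.5, 1, 15.6), 1))  # forked PyTorch vs JAX on 24 cores
    print(round(throughput_ratio(24, 7.5, 1, 12.4), 1))  # forked PyTorch vs 24 intra-op threads
    print(round(12.4 / 6.9, 1))  # 24 intra-op threads vs a single thread
    print(round(12.4 / 5.4, 1))  # 24 intra-op threads vs 4 threads
```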
GPU: V100-PCIE 16G
CPU: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
jax.lax.linalg.eigh on 1 GPU takes 0.90s for 16 problems. On all CPU cores (top reports 2340% peak CPU usage) it takes 2.44s for 16 problems.
scipy.linalg.eigh on 1 CPU core (top reports 200% peak CPU usage) takes 0.21s for 1 problem.
This means that the GPU has only <4x the throughput, and >11x the CPU usage gives only <1.4x the throughput, even though the problem should be embarrassingly parallel given vmap.
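The bounds in this report can be reproduced from the raw timings. The sketch below takes scipy's 0.21s per problem as the baseline and assumes 16 problems solved one after another cost 16 times that. The function names are hypothetical.

```python
def speedup(baseline_seconds, seconds):
    """Baseline time divided by measured time for the same amount of work."""
    return baseline_seconds / seconds


def parallel_efficiency(achieved_speedup, resource_ratio):
    """Fraction of the extra resources that shows up as extra throughput."""
    return achieved_speedup / resource_ratio


if __name__ == "__main__":
    scipy_16_problems = 16 * 0.21  # assumed sequential cost of 16 problems
    gpu_speedup = speedup(scipy_16_problems, 0.90)
    cpu_speedup = speedup(scipy_16_problems, 2.44)
    cpu_usage_ratio = 2340 / 200  # peak CPU usage reported by top, JAX vs scipy
    print(f"GPU speedup over scipy: {gpu_speedup:.2f}x")
    print(f"JAX CPU speedup over scipy: {cpu_speedup:.2f}x")
    print(f"CPU usage ratio: {cpu_usage_ratio:.1f}x")
    print(f"parallel efficiency: {parallel_efficiency(cpu_speedup, cpu_usage_ratio):.2f}")
```

An efficiency close to 1 would mean the extra cores translate directly into throughput, which is what an embarrassingly parallel batch would ideally give.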