polvalente opened 1 year ago
Although #1424 did move eigh to defn, it's still worth looking into using a new implementation for speed and accuracy
Do you have a basic test I could use to compare them?
@christianjgreen In short, you can just compare the execution time of jax.linalg.eigh vs Nx.LinAlg.eigh on a 200x200 f32 tensor, using EXLA as the default compiler and backend. You'll notice that Nx will barely handle it -- takes 42s on my machine -- while jax handles it just fine -- takes around 43ms on my machine.
SVD will consequently suffer with the same performance drop because SVD uses eigh under the hood.
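For reference, the benchmark shape described above can be sketched in Python. NumPy's eigh stands in here so the snippet runs anywhere; swapping np for jax.numpy (and calling block_until_ready on the result) gives the JAX-side timing, while the Nx side would be the equivalent Nx.LinAlg.eigh call under EXLA:

```python
import time
import numpy as np

# Illustrative stand-in: same call shape as jax.numpy.linalg.eigh /
# Nx.LinAlg.eigh, on the 200x200 f32 input discussed above.
rng = np.random.default_rng(0)
x = rng.standard_normal((200, 200), dtype=np.float32)
a = (x + x.T) / 2  # eigh expects a symmetric (Hermitian) matrix

t0 = time.perf_counter()
w, v = np.linalg.eigh(a)
elapsed = time.perf_counter() - t0
print(f"200x200 f32 eigh: {elapsed * 1000:.1f} ms")
```

The absolute numbers will differ per machine and backend; the point is only that the comparison is a single symmetric 200x200 f32 decomposition timed end to end.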
Thanks for the info! I just finished a high-level Jacobi eigh method that solves a 200x200 matrix in 3 seconds on my machine, compared to ~70 seconds using QR. I thought QR was generally faster on larger matrices, but I don't know much about the algorithm. I was going to turn the Jacobi method into the QDWH-eigh method, but I don't mind making a PR for it in the meantime.
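For readers unfamiliar with the approach: a minimal, unoptimized sketch of the classical cyclic Jacobi iteration in Python/NumPy is below. The actual implementation under discussion is in Elixir/Nx defn; the function name and structure here are purely illustrative:

```python
import numpy as np

def jacobi_eigh(A, tol=1e-10, max_sweeps=100):
    """Eigenvalues/eigenvectors of a symmetric matrix via cyclic Jacobi rotations."""
    A = A.astype(np.float64).copy()
    n = A.shape[0]
    V = np.eye(n)
    for _ in range(max_sweeps):
        # stop once the off-diagonal mass is negligible
        if np.sqrt(np.sum(np.tril(A, -1) ** 2)) < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(A[p, q]) < tol:
                    continue
                # rotation angle chosen to zero out A[p, q]
                theta = 0.5 * np.arctan2(2 * A[p, q], A[q, q] - A[p, p])
                c, s = np.cos(theta), np.sin(theta)
                J = np.eye(n)
                J[p, p] = J[q, q] = c
                J[p, q], J[q, p] = s, -s
                A = J.T @ A @ J
                V = V @ J
    return np.diag(A), V
```

A production version would apply the rotations in place instead of building full matrices, and batch independent rotations per sweep, which is what makes Jacobi attractive on accelerators.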
Update: After adding some optimizations to the QR algorithm, I got it down to ~8 seconds, which is still twice as slow as the Jacobi method; that makes me think there is something else that can be optimized. What would be best for the library owners: starting work on a QDWH-eigh built on the Jacobi method, or trying to optimize the QR code so that it falls within its big-O predictions?
Last update, and sorry for all the pings! Even though QR-based eigh decomposition is supposed to be much faster than Jacobi on large matrices, I can't get it to even match the performance of the Jacobi method, which leads me to believe something is amiss with the way the QR algorithm gets compiled down. I've tried a few things like Wilkinson shifts, deflation, and only checking subdiagonals, but the iteration count still grows too high before converging.
Current testing on my naïve implementation, with the default 1000 iterations and an eps of 1.0e-4, takes 3.2s on my machine vs 88s with the current QR implementation. Adding Wilkinson shifts and other optimizations can bring that down to around 10-30s, but with a noticeable loss of accuracy.
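As a point of comparison, the shift-and-deflate structure being described can be sketched in plain NumPy. This is the textbook iteration on a full symmetric matrix (a production version would first reduce to tridiagonal form, which is where most of QR's asymptotic advantage comes from); names are illustrative, not the Nx implementation:

```python
import numpy as np

def qr_eigvals(A, tol=1e-10, max_iter=5000):
    """Eigenvalues of a symmetric matrix via shifted QR iteration with deflation."""
    H = A.astype(np.float64).copy()
    n = H.shape[0]
    eigs = []
    for _ in range(max_iter):
        if n == 1:
            break
        # Wilkinson shift: eigenvalue of the trailing 2x2 block closest to H[n-1, n-1]
        a, b, c = H[n - 2, n - 2], H[n - 2, n - 1], H[n - 1, n - 1]
        delta = (a - c) / 2.0
        sign = 1.0 if delta >= 0 else -1.0
        denom = delta + sign * np.hypot(delta, b)
        mu = c - (b * b) / denom if denom != 0 else c
        Q, R = np.linalg.qr(H[:n, :n] - mu * np.eye(n))
        H[:n, :n] = R @ Q + mu * np.eye(n)
        # deflate once the last subdiagonal entry is negligible
        if abs(H[n - 1, n - 2]) < tol * (abs(H[n - 2, n - 2]) + abs(H[n - 1, n - 1])):
            eigs.append(H[n - 1, n - 1])
            n -= 1
    eigs.extend(np.diag(H[:n, :n]))
    return np.sort(np.array(eigs))
```

With the Wilkinson shift, convergence of the trailing subdiagonal entry is cubic for symmetric matrices, so each eigenvalue should deflate in a handful of iterations; if the iteration count grows much beyond that, the shift or deflation test is usually the thing to inspect.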
I'll defer to @polvalente and @josevalim for next steps as I'm a complete newbie here.
Currently, we have a custom implementation for Nx.BinaryBackend and call the XLA implementation for eigh in EXLA. However, the XLA implementation seems to suffer from similar issues to the SVD one, in which it ends up being slower and with a different accuracy from the one Jax uses (https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py).
Especially since we already have QDWH implemented in Nx.LinAlg.SVD.qdwh, it seems like a good idea to also reimplement eigh as a defn with optional+custom_grad (like Nx.LinAlg.svd).
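For context on the QDWH-eig idea referenced above: it computes the polar factor of a shifted matrix and uses it to split the spectrum into two invariant subspaces, then recurses on each block. A minimal sketch of one split step is below; for brevity it uses a plain Newton iteration for the polar factor where the real algorithm would use the QDWH iteration (as in Nx.LinAlg.SVD.qdwh), and all names are illustrative:

```python
import numpy as np

def polar_unitary(A, iters=100, tol=1e-12):
    """Polar factor of a nonsingular matrix via Newton's iteration.
    (QDWH-eig would use the QDWH iteration here; Newton keeps the sketch short.)"""
    X = A.astype(np.float64).copy()
    for _ in range(iters):
        Xn = 0.5 * (X + np.linalg.inv(X).T)
        if np.linalg.norm(Xn - X) < tol:
            return Xn
        X = Xn
    return X

def split_spectrum(A, sigma):
    """One divide step: split symmetric A at the shift sigma."""
    n = A.shape[0]
    U = polar_unitary(A - sigma * np.eye(n))  # equals sign(A - sigma*I) for symmetric A
    P = 0.5 * (U + np.eye(n))                 # projector onto eigenvalues > sigma
    k = int(round(np.trace(P)))               # dimension of that eigenspace
    basis, _, _ = np.linalg.svd(P)            # first k left singular vectors span range(P)
    V1, V2 = basis[:, :k], basis[:, k:]
    # the full algorithm recurses on these two smaller symmetric blocks
    return V1.T @ A @ V1, V2.T @ A @ V2
```

Because each split only needs QR factorizations and matrix products (in the QDWH formulation), the method maps well onto XLA, which is presumably why jax's eigh takes this route.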