elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Re-implement Nx.LinAlg.eigh as defn #1027

Open polvalente opened 1 year ago

polvalente commented 1 year ago

Currently, we have a custom implementation for Nx.BinaryBackend and call the XLA implementation for eigh in EXLA. However, the XLA implementation seems to suffer from issues similar to the SVD one: it ends up slower, and with different accuracy, than the one JAX uses (https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py).

Especially since we already have QDWH implemented in Nx.LinAlg.SVD.qdwh, it seems like a good idea to also reimplement eigh as a defn with optional + custom_grad (like Nx.LinAlg.svd).
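For context, QDWH computes the polar decomposition via a dynamically weighted Halley iteration, and QDWH-eig builds an eigendecomposition on top of it by spectral divide-and-conquer. Below is a minimal NumPy sketch of the *unweighted* Halley iteration only, as an illustration of the idea; it is not the Nx code, and the real QDWH in Nx.LinAlg.SVD.qdwh uses dynamically chosen weights for much faster convergence:

```python
import numpy as np

def halley_polar(a, iters=40):
    """Unweighted Halley iteration for the orthogonal polar factor of a.

    Sketch only: QDWH is this iteration with dynamically chosen weights,
    which converges in a handful of steps regardless of conditioning.
    """
    x = a / np.linalg.norm(a, 2)  # scale so singular values lie in (0, 1]
    n = a.shape[0]
    eye = np.eye(n)
    for _ in range(iters):
        g = x.T @ x
        x_new = x @ (3.0 * eye + g) @ np.linalg.inv(eye + 3.0 * g)
        if np.linalg.norm(x_new - x) < 1e-12:
            x = x_new
            break
        x = x_new
    return x
```

The connection to eigh: the polar factor U of A - sigma*I yields the spectral projector (U + I)/2 onto the invariant subspace for eigenvalues above sigma, which is the splitting step that QDWH-eig recurses on.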

polvalente commented 11 months ago

Although #1424 did move eigh to defn, it's still worth looking into using a new implementation for speed and accuracy

christianjgreen commented 5 months ago

> Although #1424 did move eigh to defn, it's still worth looking into using a new implementation for speed and accuracy

Do you have a basic test I could use to compare them?

polvalente commented 5 months ago

@christianjgreen in short, you can just compare the execution time of jax.linalg.eigh vs Nx.LinAlg.eigh using a 200x200 f32 tensor using EXLA as the default compiler and backend. You'll notice that Nx will barely handle it -- takes 42s on my machine -- while jax handles it just fine -- takes around 43ms on my machine.

SVD will consequently suffer with the same performance drop because SVD uses eigh under the hood.
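A sketch of the kind of benchmark being described, with NumPy's LAPACK-backed eigh standing in for the ~40 ms JAX baseline (the JAX side would call jax.numpy.linalg.eigh and block on the result; the Nx side would call Nx.LinAlg.eigh with EXLA set as default compiler and backend). The matrix size and dtype below match the thread; everything else is illustrative:

```python
import time
import numpy as np

# Random 200x200 f32 symmetric matrix, mirroring the test case in the thread.
rng = np.random.default_rng(0)
a = rng.standard_normal((200, 200), dtype=np.float32)
a = (a + a.T) / 2

t0 = time.perf_counter()
w, v = np.linalg.eigh(a)
elapsed = time.perf_counter() - t0
print(f"eigh on 200x200 f32: {elapsed * 1e3:.1f} ms")

# Sanity check: A @ V should equal V @ diag(w).
err = np.max(np.abs(a @ v - v * w))
print("max reconstruction error:", err)
```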

christianjgreen commented 5 months ago

> @christianjgreen in short, you can just compare the execution time of jax.linalg.eigh vs Nx.LinAlg.eigh using a 200x200 f32 tensor using EXLA as the default compiler and backend. You'll notice that Nx will barely handle it -- takes 42s on my machine -- while jax handles it just fine -- takes around 43ms on my machine.
>
> SVD will consequently suffer with the same performance drop because SVD uses eigh under the hood.

Thanks for the info! I just finished a high-level Jacobi eigh method that solves a 200x200 matrix in 3 seconds on my machine, compared to ~70 seconds using QR. I thought QR was generally faster on larger matrices, but I don't know much about the algorithm. I was planning to turn the Jacobi method into the QDWH-eigh method, but I don't mind making a PR for it in the meantime.
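For reference, a classical cyclic Jacobi eigensolver can be sketched in a few lines of NumPy. This is only an illustration of the method under discussion, not the implementation from this thread; the full-matrix products keep the sketch short, whereas a real implementation updates only the affected rows and columns of A and V:

```python
import numpy as np

def jacobi_eigh(a, tol=1e-10, max_sweeps=50):
    """Cyclic Jacobi eigendecomposition of a symmetric matrix."""
    a = np.array(a, dtype=float)
    n = a.shape[0]
    v = np.eye(n)
    for _ in range(max_sweeps):
        # Stop once the off-diagonal mass is negligible.
        if np.linalg.norm(a - np.diag(np.diag(a))) < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(a[p, q]) < 1e-30:
                    continue
                # Rotation that zeroes a[p, q] (Golub & Van Loan, sec. 8.4).
                theta = (a[q, q] - a[p, p]) / (2.0 * a[p, q])
                sign = 1.0 if theta >= 0 else -1.0
                t = sign / (abs(theta) + np.hypot(theta, 1.0))
                c = 1.0 / np.hypot(t, 1.0)
                s = t * c
                j = np.eye(n)
                j[p, p] = c; j[q, q] = c
                j[p, q] = s; j[q, p] = -s
                a = j.T @ a @ j
                v = v @ j
    return np.diag(a), v
```

Each sweep costs O(n^3) (O(n) per rotation in a proper implementation), but convergence is quadratic once the off-diagonal is small, and the rotations are embarrassingly parallel within a sweep, which is part of why Jacobi maps well to XLA-style compilation.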

christianjgreen commented 5 months ago

Update: after adding some optimizations to the QR algorithm, I got it down to ~8 seconds, which is still twice as slow as the Jacobi method, which makes me think there is something else that can be optimized. What would be best for the library owners: starting work on QDWH-eigh with the Jacobi method, or trying to optimize the QR code so that it falls within its big-O predictions?

christianjgreen commented 5 months ago

Last update, and sorry for all the pings! Even though QR-based eigh decomposition is supposed to be much faster than Jacobi on large matrices, I can't get it to even match the Jacobi method's performance, which leads me to believe something is amiss with the way the QR algorithm gets compiled down. I've tried a few things, like Wilkinson shifts, deflation, and checking only the subdiagonals, but the iteration count still grows too high before convergence.

Current testing on my naïve implementation, with the default 1000 iterations and an eps of 1.0e-4, takes 3.2s on my machine vs 88s with the current QR implementation. Adding Wilkinson shifts and other optimizations can bring that down to around 10-30s, but with much lower accuracy.
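The shifted-QR-with-deflation scheme being described can be sketched in NumPy as follows. This is a toy for illustration, not the Nx implementation: a production version would first reduce A to tridiagonal form so each QR step costs O(n) instead of O(n^3), which is where the expected asymptotic advantage over Jacobi comes from:

```python
import numpy as np

def wilkinson_shift(a):
    """Shift from the trailing 2x2 block (cf. Trefethen & Bau, Lecture 29)."""
    d = (a[-2, -2] - a[-1, -1]) / 2.0
    b = a[-1, -2]
    if b == 0.0:
        return a[-1, -1]
    sign = 1.0 if d >= 0 else -1.0
    return a[-1, -1] - sign * b * b / (abs(d) + np.hypot(d, b))

def qr_eigvalsh(a, tol=1e-12, max_iter=500):
    """Eigenvalues of a symmetric matrix via shifted QR with deflation."""
    a = np.array(a, dtype=float)
    eigs = []
    while a.shape[0] > 1:
        n = a.shape[0]
        for _ in range(max_iter):
            # Deflate once the last row has (numerically) decoupled.
            if np.linalg.norm(a[-1, :-1]) < tol * np.linalg.norm(a):
                break
            mu = wilkinson_shift(a)
            q, r = np.linalg.qr(a - mu * np.eye(n))
            a = r @ q + mu * np.eye(n)  # similarity transform, symmetry preserved
        eigs.append(a[-1, -1])
        a = a[:-1, :-1]
    eigs.append(a[0, 0])
    return np.sort(np.array(eigs))
```

With the Wilkinson shift, each eigenvalue typically deflates in a few iterations (convergence is cubic for symmetric matrices), so if the iteration count is growing large before convergence, the shift or the deflation test is a likely suspect.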

I'll defer to @polvalente and @josevalim for next steps as I'm a complete newbie here.