The current implementation of eigh on symmetric matrices takes about 450ms for a 20x20 matrix and 154s for a 100x100 matrix, while the new implementation takes 0.5ms and 11ms respectively.
This is a defn version of the method used by XLA: https://github.com/openxla/xla/blob/main/xla/service/eigh_expander.cc
There is still a todo list and code cleanup/DRYing to do, but I wanted to pitch this before getting too far into the process. While this method has a static submatrix size with no recursion, it can be built on to recreate the recursive blocked-eigh used by JAX. This approach has less complexity and seemed like a nice way to make eigh performant without having to exactly copy the JAX method.
The gist of the method is to break the matrix into four submatrices, apply the Jacobi rotations across all rows and columns each iteration, and then join the results.
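For anyone unfamiliar with Jacobi rotations, here is a minimal pure-Python sketch of the underlying idea. It is not the defn code from this PR and not the XLA implementation: it is the plain sequential (cyclic) Jacobi sweep, which zeroes one off-diagonal entry at a time instead of rotating all rows and columns of the four submatrices at once. The names `jacobi_eigh`, `eps`, and `max_sweeps` are made up for illustration.

```python
import math


def jacobi_eigh(a, eps=1e-12, max_sweeps=50):
    """Cyclic Jacobi sketch for a real symmetric matrix given as a list of lists.

    Returns (eigenvalues, v), where column k of v is the eigenvector
    belonging to eigenvalues[k]. Eigenvalues are not sorted.
    """
    n = len(a)
    a = [row[:] for row in a]  # work on a copy, leave the input untouched
    v = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]

    for _ in range(max_sweeps):
        # Frobenius norm of the off-diagonal part; stop once it is negligible.
        off = math.sqrt(sum(a[i][j] ** 2
                            for i in range(n) for j in range(n) if i != j))
        if off < eps:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if a[p][q] == 0.0:
                    continue
                # Rotation (c, s) chosen so that the new a[p][q] becomes zero.
                theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                # A <- A J: update columns p and q.
                for k in range(n):
                    akp, akq = a[k][p], a[k][q]
                    a[k][p] = c * akp - s * akq
                    a[k][q] = s * akp + c * akq
                # A <- J^T A: update rows p and q.
                for k in range(n):
                    apk, aqk = a[p][k], a[q][k]
                    a[p][k] = c * apk - s * aqk
                    a[q][k] = s * apk + c * aqk
                # V <- V J: accumulate the eigenvectors.
                for k in range(n):
                    vkp, vkq = v[k][p], v[k][q]
                    v[k][p] = c * vkp - s * vkq
                    v[k][q] = s * vkp + c * vkq

    return [a[i][i] for i in range(n)], v
```

Calling `jacobi_eigh([[2.0, 1.0], [1.0, 2.0]])` should give eigenvalues close to 1 and 3 in some order. The blocked variant in this PR applies the same kind of rotation, but to many (p, q) pairs per iteration, which is what makes it suitable for defn.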
Draft commit to introduce the idea.
Todo:
Inherit parent tensor type
Pass in eps as argument
Handle complex numbers (Should be somewhat easy, same implementation as XLA)
Reject malformed matrices
Refactor to be less ugly and clean up syntax
Current issues:
The values returned are not always normalized the same way as in the current implementation, so some tests fail
https://github.com/elixir-nx/nx/issues/1027
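One possible way to make results comparable in tests, assuming the mismatch is the usual sign ambiguity of eigenvectors (if v is an eigenvector, so is -v) rather than ordering or scaling: pick a fixed sign convention and apply it to both outputs before comparing. The sketch below is in Python for illustration only, and `normalize_signs` is a hypothetical helper, not something that exists in Nx.

```python
def normalize_signs(vecs):
    """Flip each eigenvector (column of vecs) so that its
    largest-magnitude component is positive.

    vecs is a list of rows; a new matrix is returned.
    """
    n = len(vecs)
    out = [row[:] for row in vecs]
    for k in range(len(vecs[0])):
        # Row index of the largest-magnitude entry in column k.
        pivot = max(range(n), key=lambda i: abs(out[i][k]))
        if out[pivot][k] < 0:
            for i in range(n):
                out[i][k] = -out[i][k]
    return out
```

With a convention like this applied to both the old and new results, two decompositions that differ only by eigenvector sign would compare equal up to floating-point tolerance.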
Please let me know if this is of any use! <3