Open AlexanderMath opened 1 year ago
Initial implementation takes 76M cycles; the aim is ~1M cycles. Code on this branch: https://github.com/graphcore-research/pyscf-ipu/tree/hessenberg
Note: The algorithm is almost identical to `tesselate_ipu.linalg.qr`; it just multiplies with another H from the other side as well (a two-sided similarity transform).
@balancap Do you have any pointers on hard parts?
@paolot-gc is looking at improving above profile. I'll take a look at https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.eigh_tridiagonal.html and https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/python/ops/linalg/linalg_impl.py#L1232-L1588 during weekend.
Profile of a single iteration @balancap
@balancap @paolot-gc
Context: For `M.shape=(1024,1024)` with `M.T=M` we want `eigh(M)`. We use the classic `hessenberg(M)=tri_diagonal` to turn the problem into `eigh(tri_diagonal)`.
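As a sanity check of this reduction (using scipy as a stand-in, not the IPU kernels): for symmetric `M` the Hessenberg form is tridiagonal, and the orthogonal similarity transform preserves eigenvalues. A small `d=64` is used in place of `d=1024`:

```python
# Illustrative only (scipy stand-in, not the IPU kernels): for symmetric M,
# the Hessenberg form is tridiagonal and eigenvalues are preserved.
import numpy as np
from scipy.linalg import hessenberg

np.random.seed(0)
d = 64  # small stand-in for d=1024
M = np.random.normal(0, 1, (d, d))
M = (M + M.T) / 2

T, Q = hessenberg(M, calc_q=True)  # M = Q @ T @ Q.T

# Upper Hessenberg + symmetric => tridiagonal (up to roundoff).
above_first_superdiag = np.abs(np.triu(T, k=2)).max()
assert above_first_superdiag < 1e-10

# Eigenvalues are unchanged by the orthogonal similarity transform.
assert np.allclose(np.linalg.eigvalsh(M), np.linalg.eigvalsh(T))
```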
Problem: The literature claims `eigvals(tri_diagonal)` are easy and `eigvects(tri_diagonal)` are hard (e.g. `jax.lax.eigh_tridiagonal` only supports eigvals, not eigvects).
Here's an algorithm (credit to fhvilshoj): Compute `eigvals(tri_diagonal)`, which are claimed to be easy. Replicate `tri_diagonal` onto every tile and perform 1024 inverse power iterations in parallel on `tri_diagonal-(eps+eigval[tile_i])*I` for `eps~0`.
Correctness: Inverse power iteration converges to the eigenvector whose eigenvalue has the smallest magnitude. The shift `tri_diagonal-(eps+eigval[tile_i])*I` makes the `tile_i`'th eigenvalue have magnitude `eps~0`.
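A minimal numpy sketch of this per-eigenvalue inverse iteration (illustrative only, not the IPU code; a dense solve stands in for the O(d) tridiagonal solve):

```python
# Illustrative numpy sketch (not the IPU code): inverse power iteration on
# the shifted matrix converges to the eigenvector whose shifted eigenvalue
# has magnitude ~eps. A dense solve stands in for the O(d) tridiagonal solve.
import numpy as np

np.random.seed(0)
d = 128
main = np.random.normal(0, 1, d)
off = np.random.normal(0, 1, d - 1)
T = np.diag(main) + np.diag(off, 1) + np.diag(off, -1)  # symmetric tridiagonal

lam = np.linalg.eigvalsh(T)[d // 2]  # one precomputed eigenvalue (one per tile)
eps = 1e-6
A = T - (lam + eps) * np.eye(d)      # shifted matrix, nearly singular by design

v = np.random.normal(0, 1, d)
for _ in range(5):                   # a handful of iterations suffices here
    v = np.linalg.solve(A, v)
    v /= np.linalg.norm(v)

residual = np.linalg.norm(T @ v - lam * v)
assert residual < 1e-4               # v is (numerically) the eigenvector for lam
```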
Convergence: Since `eigval(inv(A))=1/eigval(A)` we can make `eigval(inv(tri_diagonal-(eps+eigval[tile_i])*I))~1/eps` arbitrarily large. The convergence of power iteration (on this inverse matrix) depends on the gap between the largest and second-largest eigenvalue magnitudes, which we can make arbitrarily large as `eps->0`. This is great in theory; I have no idea what happens in float32.
Memory: Since `d=1024` we get `tri_diagonal.nbytes~12.2kb`.
Time: We can compute `inv(tri_diagonal-c*I)@v` in O(d) operations using Gaussian elimination (the Thomas algorithm for tridiagonal systems).
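A sketch of that O(d) tridiagonal solve via the Thomas algorithm (the function name and layout are illustrative, not from the repo; no pivoting, so it assumes a well-conditioned, e.g. diagonally dominant, system):

```python
# Illustrative Thomas-algorithm solve of T x = v for symmetric tridiagonal T,
# given its diagonal `main` and off-diagonal `off`. O(d) work, no pivoting.
import numpy as np

def thomas_solve(main, off, v):
    d = len(main)
    c = np.empty(d - 1)  # modified superdiagonal from the forward sweep
    g = np.empty(d)      # modified right-hand side
    c[0] = off[0] / main[0]
    g[0] = v[0] / main[0]
    for i in range(1, d):
        denom = main[i] - off[i - 1] * c[i - 1]
        if i < d - 1:
            c[i] = off[i] / denom
        g[i] = (v[i] - off[i - 1] * g[i - 1]) / denom
    x = np.empty(d)      # back-substitution
    x[-1] = g[-1]
    for i in range(d - 2, -1, -1):
        x[i] = g[i] - c[i] * x[i + 1]
    return x

np.random.seed(0)
d = 64
main = np.random.normal(0, 1, d) + 10.0  # diagonally dominant for stability
off = np.random.normal(0, 1, d - 1) * 0.1
T = np.diag(main) + np.diag(off, 1) + np.diag(off, -1)
v = np.random.normal(0, 1, d)
x = thomas_solve(main, off, v)
assert np.allclose(T @ x, v)
```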
The above is called "Simultaneous Iteration" in https://courses.engr.illinois.edu/cs554/fa2015/notes/12_eigenvalue_8up.pdf. You can't make the eigenvalue gap arbitrarily large if `lambda_i = lambda_{i+1}`, so in practice you can't make it arbitrarily large if they are close to equal.
In general, I guess we should re-title this issue "Speed up eigh computation", and the first task is to gather potential implementation strategies, e.g. by just grabbing the slide headings from the lecture above.
As all of these approaches end up with provisos such as "Algorithm is complicated to implement and difficult questions of numerical stability, eigenvector orthogonality, and load balancing must be addressed", it's probably a good idea to see if existing code such as scalapack (or ARPACK for online-computed ERI) has been ported to e.g. numpy.
> You can't make the eigenvalue gap arbitrarily large if `lambda_i = lambda_{i+1}`, so in practice you can't make it arbitrarily large if they are close to equal.
Agree.
> The above is called "Simultaneous Iteration" in https://courses.engr.illinois.edu/cs554/fa2015/notes/12_eigenvalue_8up.pdf.
Do you mean "parallel inverse iteration"? Simultaneous iteration doesn't use the inverse, and it requires normalization (i.e. orthogonalizing the q simultaneous eigenvectors?).
Our current `ipu_eigh` uses the Jacobi algorithm. For other hardware accelerators it is believed that the QR algorithm becomes faster than Jacobi for larger `d>=512`. Since we are targeting `d>=512` we are considering implementing the QR algorithm.

Main blocker: From Alex's (and Paul's?) experience on CPU, a naive QR algorithm empirically [1] needs roughly `~d^1.5` iterations to converge. This makes it hard for the QR algorithm to compete with Jacobi, which from Alex's experience [1] empirically converges in `~d^0.5` iterations. Mature QR algorithm implementations reduce the number of iterations using shift/deflate tricks; Alex has never managed to get these to work. The difficulty could be alleviated if we found a working shift/deflate implementation under an OS license that we could port to IPU.

Tasks: The naive QR algorithm takes `O(d^4)` time; mature implementations reduce this to `O(d^3)` by first computing a Hessenberg decomposition.

Notes:
[1] Using matrices `M=np.random.normal(0,1,(d,d)); M=(M+M.T)/2`. This may be a non-issue for other matrices.