SamDuffield closed this 3 weeks ago.
I'm no longer sure this is possible without matrix-matrix multiplications, as documented very clearly in Appendix A of https://arxiv.org/abs/2208.04933.
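For reference, here is a minimal sketch of the general (dense) case. It assumes a linear recurrence $x_t = A_t x_{t-1} + b_t$ with arbitrary $d \times d$ matrices $A_t$. Each scan element $(A_t, b_t)$ stands for the affine map $x \mapsto A_t x + b_t$, so every combine step has to compose two such maps, and that composition needs a batched matrix-matrix product. The function names are hypothetical and are not taken from any existing codebase.

```python
import jax
import jax.numpy as jnp

def combine_dense(left, right):
    # Each element (A, b) is the affine map x -> A x + b.
    # associative_scan passes leaves with a leading batch axis.
    A_l, b_l = left
    A_r, b_r = right
    # Compose "left, then right": x -> A_r (A_l x + b_l) + b_r.
    A = jnp.matmul(A_r, A_l)  # batched d x d matrix-matrix product, O(d^3) each
    b = jnp.einsum("nij,nj->ni", A_r, b_l) + b_r
    return A, b

def linear_recurrence_parallel(As, bs, x0):
    # Fold x0 into the first offset so the offset of prefix t is exactly x_t.
    bs = bs.at[0].add(As[0] @ x0)
    _, xs = jax.lax.associative_scan(combine_dense, (As, bs))
    return xs

def linear_recurrence_sequential(As, bs, x0):
    # Reference O(T)-step implementation.
    def step(x, inputs):
        A, b = inputs
        x = A @ x + b
        return x, x
    _, xs = jax.lax.scan(step, x0, (As, bs))
    return xs
```

The two functions should agree up to floating-point reordering. The parallel version has O(log T) depth but pays for a matrix-matrix product in every combine.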
Ooo, actually maybe we can, because the eigendecomposition gives something like
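$$ A_i A_j = U \exp(-D (\delta_i + \delta_j)) U^{-1} $$
where $A_i = U \exp(-D \delta_i) U^{-1}$. The inner $U^{-1} U$ cancels, and the diagonal exponentials multiply elementwise.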
So we can avoid the matrix-matrix multiplications.
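A minimal sketch of the eigenbasis version follows. It assumes real eigenvalues, a single eigenvector matrix $U$ shared by every step, and transitions of the form $A_i = U \exp(-D \delta_i) U^{-1}$ with $D = \mathrm{diag}(\text{eigvals})$. After changing basis with $U^{-1}$, each scan element is a vector of decays $\exp(-D\delta_t)$ plus an offset. Combining two elements multiplies the decays, which is exactly $\exp(-D(\delta_i + \delta_j))$, so only elementwise operations are needed. `transition`, `combine_diag` and `recurrence_eigenbasis` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def transition(U, U_inv, eigvals, delta):
    # Assumed parameterisation: A(delta) = U exp(-D delta) U^{-1}.
    return U @ jnp.diag(jnp.exp(-eigvals * delta)) @ U_inv

def combine_diag(left, right):
    # In the eigenbasis an element is (a, c) with a = exp(-eigvals * delta).
    # a_r * a_l = exp(-eigvals * (delta_l + delta_r)): no matrix products.
    a_l, c_l = left
    a_r, c_r = right
    return a_r * a_l, a_r * c_l + c_r

def recurrence_eigenbasis(U, U_inv, eigvals, deltas, bs, x0):
    # Solves x_t = A(delta_t) x_{t-1} + b_t via z = U^{-1} x.
    decays = jnp.exp(-deltas[:, None] * eigvals[None, :])  # (T, d)
    cs = bs @ U_inv.T                                       # rows are U^{-1} b_t
    z0 = U_inv @ x0
    cs = cs.at[0].add(decays[0] * z0)                       # fold in the initial state
    _, zs = jax.lax.associative_scan(combine_diag, (decays, cs))
    return zs @ U.T                                         # rows are x_t = U z_t
```

The only $d \times d$ work left is the change of basis at the start and the end. Both are done once per time step and independently, rather than inside the scan's combine.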
Note: change the handling of random seeds so that `associative_scan=False` and `associative_scan=True` handle them in the same way.
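One way to get identical randomness in both modes, sketched under the assumption that per-step Gaussian noise is the only random input: split the key into one key per time step up front, then consume those keys either inside a sequential `jax.lax.scan` or all at once with `jax.vmap` before an associative scan. `noise_via_scan` and `noise_via_vmap` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp

def noise_via_scan(key, num_steps, dim):
    # Sequential path: one pre-split key is consumed per step inside lax.scan.
    keys = jax.random.split(key, num_steps)
    def step(carry, k):
        return carry, jax.random.normal(k, (dim,))
    _, noise = jax.lax.scan(step, None, keys)
    return noise

def noise_via_vmap(key, num_steps, dim):
    # Parallel path: the same per-step keys, drawn all at once so the
    # resulting noise can be fed to jax.lax.associative_scan.
    keys = jax.random.split(key, num_steps)
    return jax.vmap(lambda k: jax.random.normal(k, (dim,)))(keys)
```

Both helpers derive step $t$'s noise from the same key, so the sampled paths should match across the two modes. A simpler alternative is a single `jax.random.normal(key, (T, d))` draw shared by both code paths.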
We can get O(log(T)) time, assuming O(T) cores, for `sample` and `log_prob` by using `jax.lax.associative_scan`. Relevant TensorFlow code here.
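A minimal sketch of the `sample` side only. It assumes a hypothetical diagonal linear-Gaussian model $x_t = a_t \odot x_{t-1} + s_t \odot \varepsilon_t$ with $\varepsilon_t \sim \mathcal{N}(0, I)$, not the actual model in this repository. The `associative_scan` flag mirrors the one mentioned above. The noise is drawn once before branching, so both branches see the same randomness. Only the parallel branch has O(log T) depth.

```python
import jax
import jax.numpy as jnp

def _combine(left, right):
    # Compose elementwise affine maps x -> a * x + b ("left, then right").
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def sample(key, decays, scales, x0, associative_scan=True):
    """Hypothetical sampler for x_t = decays[t] * x_{t-1} + scales[t] * eps_t.

    decays, scales: arrays of shape (T, d); x0: array of shape (d,).
    """
    # Draw all noise up front so both branches use identical randomness.
    eps = jax.random.normal(key, decays.shape)
    bs = scales * eps
    if associative_scan:
        bs = bs.at[0].add(decays[0] * x0)  # fold the initial state into step 0
        _, xs = jax.lax.associative_scan(_combine, (decays, bs))  # O(log T) depth
    else:
        def step(x, inputs):
            a, b = inputs
            x = a * x + b
            return x, x
        _, xs = jax.lax.scan(step, x0, (decays, bs))  # T sequential steps
    return xs
```

With the same key, the two settings of `associative_scan` should return the same path up to floating-point reordering.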