normal-computing / thermox

Exact OU processes with JAX
Apache License 2.0

Add associative scan #14

Closed. SamDuffield closed this issue 3 weeks ago.

SamDuffield commented 2 months ago

We can get O(log T) time for `sample` and `log_prob`, assuming O(T) cores, by using `jax.lax.associative_scan`.
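A minimal sketch of the idea on a generic affine recurrence $x_k = A_k x_{k-1} + b_k$. Treating an OU sampler this way (transition matrices and noise precomputed per step) is an assumption here, and the function names are hypothetical, not thermox's API. Each scan element is a pair $(A_k, b_k)$ and the associative operator composes two affine maps, so the scan has O(log T) depth. Note that every combine needs a d x d matrix-matrix product.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose two affine maps x -> A x + b (earlier is applied first).

    Both arguments are batched along the leading axis, as
    jax.lax.associative_scan requires.
    """
    A_i, b_i = earlier
    A_j, b_j = later
    # later after earlier: x -> A_j (A_i x + b_i) + b_j
    A = A_j @ A_i  # batched matrix-matrix product, O(d^3) per combine
    b = jnp.einsum("...ij,...j->...i", A_j, b_i) + b_j
    return A, b


def affine_recurrence_parallel(As, bs, x0):
    """Hypothetical helper: x_k = As[k] @ x_{k-1} + bs[k], with x_{-1} = x0."""
    # Fold the initial state into the first offset so that the scanned
    # offsets are exactly the states x_k.
    bs = bs.at[0].add(As[0] @ x0)
    _, xs = jax.lax.associative_scan(combine, (As, bs))
    return xs


def affine_recurrence_sequential(As, bs, x0):
    """Reference O(T)-depth implementation with jax.lax.scan."""
    def step(x, Ab):
        A, b = Ab
        x_new = A @ x + b
        return x_new, x_new
    _, xs = jax.lax.scan(step, x0, (As, bs))
    return xs
```

Both functions should agree up to floating-point reordering; only the parallel one has logarithmic depth.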

Relevant TensorFlow code here

SamDuffield commented 1 month ago

I'm now not sure this is possible without matrix-matrix multiplications, as documented very clearly in Appendix A here: https://arxiv.org/abs/2208.04933

SamDuffield commented 1 month ago

Ooo, actually maybe we can, because the eigendecomposition gives something like

$$ A_i A_j = U \exp(-D(\delta_i + \delta_j)) U^{-1} $$
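A short derivation of this identity, assuming each $A_i$ is the OU transition matrix over a step of length $\delta_i$, i.e. $A_i = \exp(-M\delta_i)$, where the drift matrix $M = U D U^{-1}$ is diagonalisable ($M$ is notation introduced here, not from the thread):

$$
\begin{aligned}
A_i A_j &= U e^{-D\delta_i} U^{-1} \, U e^{-D\delta_j} U^{-1} \\
&= U e^{-D\delta_i} e^{-D\delta_j} U^{-1} \\
&= U e^{-D(\delta_i + \delta_j)} U^{-1}
\end{aligned}
$$

The inner $U^{-1} U = I$ cancels, and diagonal matrices commute, so their exponents simply add.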

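So we can avoid the matrix-matrix multiplications: in the eigenbasis, composing two steps only adds their time increments inside a diagonal exponential.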
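A sketch of the diagonalised scan, assuming the recurrence $x_k = \exp(-M\delta_k) x_{k-1} + b_k$ with a symmetric drift matrix $M = U D U^\top$ (symmetry is an assumption so that `jnp.linalg.eigh` applies and $U^{-1} = U^\top$; a non-symmetric $M$ would need a general, possibly complex, eigendecomposition). Using $A_i A_j = U \exp(-D(\delta_i + \delta_j)) U^{-1}$, each scan element only carries a scalar time step and a vector in the eigenbasis, so the combine is purely elementwise. Names are hypothetical, not thermox's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def make_combine(eigvals):
    """Associative operator on (dt, c) pairs in the eigenbasis of M."""
    def combine(earlier, later):
        dt_i, c_i = earlier
        dt_j, c_j = later
        # exp(-D dt_j) is diagonal, so it acts elementwise on c_i.
        decay_j = jnp.exp(-eigvals * dt_j[..., None])
        # Time increments add; no matrix-matrix product is needed.
        return dt_i + dt_j, decay_j * c_i + c_j
    return combine


def ou_affine_scan_diag(M, dts, bs, x0):
    """Hypothetical helper: x_k = expm(-M dts[k]) x_{k-1} + bs[k], M symmetric."""
    eigvals, U = jnp.linalg.eigh(M)  # M = U diag(eigvals) U^T
    cs = bs @ U                      # row k is U^T b_k
    y0 = U.T @ x0                    # initial state in the eigenbasis
    # Fold the initial state into the first offset.
    cs = cs.at[0].add(jnp.exp(-eigvals * dts[0]) * y0)
    _, ys = jax.lax.associative_scan(make_combine(eigvals), (dts, cs))
    return ys @ U.T                  # row k is U y_k


def ou_affine_sequential(M, dts, bs, x0):
    """Reference implementation using a dense matrix exponential per step."""
    def step(x, inp):
        dt, b = inp
        x_new = expm(-M * dt) @ x + b
        return x_new, x_new
    _, xs = jax.lax.scan(step, x0, (dts, bs))
    return xs
```

The eigendecomposition is computed once, outside the scan, so each combine costs O(d) rather than O(d^3).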

KaelanDt commented 3 weeks ago

Note: change the handling of random seeds so that `associative_scan=False` and `associative_scan=True` handle them in the same way.
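One way to get identical seed handling, sketched under assumptions: draw all of the Gaussian noise from the key once, before branching on the scan method, so both code paths consume exactly the same random numbers. The function `sample_diag_ou` and its arguments are hypothetical (independent 1D OU coordinates $dx = -\lambda x\,dt + \sigma\,dW$), not thermox's `sample` signature or the approach taken in the linked PR.

```python
import jax
import jax.numpy as jnp


def sample_diag_ou(key, dts, x0, lam, sigma, associative_scan=False):
    """Hypothetical sampler for independent OU coordinates dx = -lam x dt + sigma dW."""
    # Draw all noise up front from a single key, so the random numbers
    # do not depend on which scan implementation is used below.
    z = jax.random.normal(key, (dts.shape[0], x0.shape[0]))
    decay = jnp.exp(-lam * dts[:, None])                    # (T, d)
    std = jnp.sqrt(sigma**2 * (1 - decay**2) / (2 * lam))  # exact OU step std
    bs = std * z                                            # (T, d)

    if associative_scan:
        def combine(earlier, later):
            a_i, b_i = earlier
            a_j, b_j = later
            return a_i * a_j, a_j * b_i + b_j
        bs = bs.at[0].add(decay[0] * x0)  # fold in the initial state
        _, xs = jax.lax.associative_scan(combine, (decay, bs))
    else:
        def step(x, inp):
            a, b = inp
            x_new = a * x + b
            return x_new, x_new
        _, xs = jax.lax.scan(step, x0, (decay, bs))
    return xs
```

With the same key, both settings then return the same trajectory up to floating-point reordering.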

KaelanDt commented 3 weeks ago

Solved by https://github.com/normal-computing/thermox/pull/30