normal-computing / thermox

Exact OU processes with JAX
Apache License 2.0
31 stars, 6 forks

Add associative scan #30

Closed: SamDuffield closed this 3 months ago

SamDuffield commented 4 months ago

First attempt at using jax.lax.associative_scan (#14), but it's throwing a matmul contracting-dimensions error and I'm not sure why.
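One common source of contracting-dimension errors with jax.lax.associative_scan (offered as a plausible cause, not a diagnosis of this PR's code) is that the combine function receives batches of elements stacked along a leading axis, not single elements. Below is a minimal sketch for a generic affine recurrence x_{k+1} = M_k x_k + v_k; the helper names combine and affine_prefix are hypothetical and are not part of thermox.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose two batches of affine maps x -> M @ x + v (hypothetical helper).

    jax.lax.associative_scan calls this on *stacked* elements, so M has
    shape (n, d, d) and v has shape (n, d) for some batch size n that
    changes between internal calls.
    """
    M1, v1 = earlier
    M2, v2 = later
    # Apply `earlier` first, then `later`:
    #   x -> M2 @ (M1 @ x + v1) + v2 = (M2 @ M1) @ x + (M2 @ v1 + v2).
    # jnp.matmul broadcasts over the leading batch axis for M2 @ M1, and
    # einsum makes the batched matrix-vector product explicit. Writing
    # jnp.dot(M2, v1) instead would contract the last axis of M2 (size d)
    # with the leading batch axis of v1 (size n), which errors when n != d.
    M = jnp.matmul(M2, M1)
    v = jnp.einsum("nij,nj->ni", M2, v1) + v2
    return M, v


def affine_prefix(Ms, vs, x0):
    """Row k is x_{k+1}, obtained by applying steps 0..k to x0 (hypothetical)."""
    M_cum, v_cum = jax.lax.associative_scan(combine, (Ms, vs))
    return jnp.einsum("nij,j->ni", M_cum, x0) + v_cum
```

Row k of affine_prefix(Ms, vs, x0) should match the sequential loop x = Ms[k] @ x + vs[k] after k + 1 steps, up to floating-point reassociation.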

SamDuffield commented 4 months ago

Update: associative_scan is now working, but something seems to be wrong with the calculations, so I need to check the maths again.

SamDuffield commented 4 months ago

OK, I fixed the maths! It comes at the cost of doubling the number of expm_vp calls; with further thought we might be able to halve it again, although I'm not sure.
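For context, a sketch of the maths involved, assuming the SDE dx = -A(x - b) dt + D^{1/2} dW (this parameterization is an assumption, not stated in the thread). The exact transition over a step of length δ_k is Gaussian and affine in the previous state, and composing two affine steps is associative, which is what lets jax.lax.associative_scan compute all prefix compositions in parallel:

```latex
\begin{aligned}
x_{k+1} &= b + e^{-A\delta_k}(x_k - b) + \varepsilon_k,
  \qquad \varepsilon_k \sim \mathcal{N}\bigl(0, \Sigma(\delta_k)\bigr),
  \qquad \Sigma(\delta) = \int_0^{\delta} e^{-As} D\, e^{-A^\top s}\, \mathrm{d}s, \\
x_{k+1} &= M_k x_k + v_k,
  \qquad M_k = e^{-A\delta_k}, \qquad v_k = (I - M_k)\, b + \varepsilon_k, \\
(M_{k+1}, v_{k+1}) \circ (M_k, v_k) &= \bigl(M_{k+1} M_k,\; M_{k+1} v_k + v_{k+1}\bigr).
\end{aligned}
```

When D = I and A is symmetric, the covariance has the closed form Σ(δ) = (2A)^{-1}(I - e^{-2Aδ}).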

The next step is to add associative_scan for log_prob.
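One property worth keeping in mind for log_prob: by the Markov property, the log density of a path is a sum of one-step Gaussian transition terms, each depending only on a consecutive pair of states, so in the simplified setting below the terms can be evaluated in parallel with jax.vmap. This is a sketch under assumptions (SDE dx = -A(x - b) dt + dW, A symmetric positive definite; ou_transition and ou_log_prob are hypothetical helpers). It does not reproduce thermox's log_prob, and whether an associative scan helps there depends on that implementation's own formulation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm
from jax.scipy.stats import multivariate_normal


def ou_transition(A, delta):
    """Mean map Phi and covariance Sigma of one exact OU step (hypothetical).

    Assumes dx = -A (x - b) dt + dW with A symmetric positive definite,
    so Sigma(delta) = (2A)^{-1} (I - expm(-2 A delta)).
    """
    d = A.shape[0]
    Phi = expm(-A * delta)
    Sigma = jnp.linalg.solve(2.0 * A, jnp.eye(d) - expm(-2.0 * A * delta))
    Sigma = 0.5 * (Sigma + Sigma.T)  # symmetrize against round-off
    return Phi, Sigma


def ou_log_prob(ts, xs, A, b):
    """Log density of xs[1:] given xs[0] (hypothetical helper).

    ts: (n,) increasing times; xs: (n, d) states observed at those times.
    """
    deltas = jnp.diff(ts)

    def step_log_prob(x_prev, x_next, delta):
        Phi, Sigma = ou_transition(A, delta)
        mean = b + Phi @ (x_prev - b)
        return multivariate_normal.logpdf(x_next, mean, Sigma)

    # Each term needs only a consecutive pair of states, so all n - 1 terms
    # are computed in one vectorized call and then summed.
    terms = jax.vmap(step_log_prob)(xs[:-1], xs[1:], deltas)
    return jnp.sum(terms)
```

In one dimension this reduces to a sum of scalar Gaussian log densities with mean b + e^{-aδ}(x_k - b) and variance (1 - e^{-2aδ}) / (2a), which makes a convenient sanity check.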

KaelanDt commented 3 months ago

Finished the speed-up comparison with and without the associative scan, and adapted the handling of random keys in _sample_identity_diffusion. Should be good to review.
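A comparison like this is easiest to trust when both code paths consume exactly the same noise. The sketch below, under the same assumptions as before (dx = -A(x - b) dt + dW, A symmetric positive definite; every function name is hypothetical, not thermox's API), splits one key per step up front, builds the per-step affine maps once, and then runs them through jax.lax.scan and jax.lax.associative_scan so the two outputs can be compared directly before timing.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def _transition(A, delta):
    # Exact OU step (assumed parameterization): mean map Phi and Cholesky
    # factor L of Sigma(delta) = (2A)^{-1} (I - expm(-2 A delta)).
    d = A.shape[0]
    Phi = expm(-A * delta)
    Sigma = jnp.linalg.solve(2.0 * A, jnp.eye(d) - expm(-2.0 * A * delta))
    return Phi, jnp.linalg.cholesky(0.5 * (Sigma + Sigma.T))


def _affine_steps(key, ts, A, b):
    # Step k is the affine map x -> Phis[k] @ x + vs[k].
    n, d = ts.shape[0] - 1, A.shape[0]
    # One key per step, drawn up front, so both samplers see identical noise.
    keys = jax.random.split(key, n)
    zs = jax.vmap(lambda k: jax.random.normal(k, (d,)))(keys)
    Phis, Ls = jax.vmap(lambda dt: _transition(A, dt))(jnp.diff(ts))
    vs = b - jnp.einsum("nij,j->ni", Phis, b) + jnp.einsum("nij,nj->ni", Ls, zs)
    return Phis, vs


@jax.jit
def sample_scan(key, ts, x0, A, b):
    # Sequential baseline: one step per iteration of lax.scan.
    Phis, vs = _affine_steps(key, ts, A, b)

    def body(x, params):
        Phi, v = params
        x_next = Phi @ x + v
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, (Phis, vs))
    return jnp.concatenate([x0[None], xs])


@jax.jit
def sample_associative(key, ts, x0, A, b):
    # Parallel version: prefix compositions of the affine maps.
    Phis, vs = _affine_steps(key, ts, A, b)

    def combine(earlier, later):
        M1, v1 = earlier
        M2, v2 = later
        return M2 @ M1, jnp.einsum("nij,nj->ni", M2, v1) + v2

    M_cum, v_cum = jax.lax.associative_scan(combine, (Phis, vs))
    xs = jnp.einsum("nij,j->ni", M_cum, x0) + v_cum
    return jnp.concatenate([x0[None], xs])


def time_fn(fn, *args, repeats=10):
    # Compile once outside the timed loop, then average wall-clock time.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats
```

Because the noise is fixed by the split keys, time_fn(sample_scan, key, ts, x0, A, b) and time_fn(sample_associative, key, ts, x0, A, b) measure the same computation done two ways; any speed-up from the associative version will depend on the device, the number of steps, and the state dimension.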

SamDuffield commented 3 months ago

Be sure to add the underscore so that sample_identity_diffusion becomes _sample_identity_diffusion, and maybe remove the docstring too.