SamDuffield closed this 3 months ago.
Update: associative_scan is now working, but something seems to be wrong with the calculations, so I need to check the maths again.
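When the results look wrong, one cheap thing to check is associativity. jax.lax.associative_scan will happily run a non-associative combine function and return wrong numbers. The sketch below makes a loud assumption: the scanned elements are affine transitions x -> A x + b. The thread doesn't show the actual operator, and combine and random_elem are hypothetical names. It compares (e1·e2)·e3 against e1·(e2·e3) on random inputs.

```python
import jax
import jax.numpy as jnp

def combine(first, second):
    """Compose two affine maps x -> A x + b, applying `first` then `second`.

    Hypothetical operator: associative_scan calls it on stacked elements
    with a leading batch axis, so einsum is used for the products.
    """
    A1, b1 = first
    A2, b2 = second
    A = jnp.einsum("...ij,...jk->...ik", A2, A1)
    b = jnp.einsum("...ij,...j->...i", A2, b1) + b2
    return A, b

def random_elem(key, d=3):
    # Random (A, b) pair used only to probe the operator.
    k1, k2 = jax.random.split(key)
    return jax.random.normal(k1, (d, d)), jax.random.normal(k2, (d,))

e1, e2, e3 = (random_elem(k) for k in jax.random.split(jax.random.PRNGKey(0), 3))

left = combine(combine(e1, e2), e3)
right = combine(e1, combine(e2, e3))

# Loose tolerances because float32 rounding depends on evaluation order.
is_associative = all(
    bool(jnp.allclose(l, r, rtol=1e-4, atol=1e-4)) for l, r in zip(left, right)
)
print(is_associative)
```

If this prints False for the real operator, the scan output can't be trusted, however plausible it looks.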
OK, I fixed the maths! It comes at the cost of doubling the number of expm_vp calls. With further thought we might be able to halve that again, although I'm not sure.
Next step is to add associative_scan for log_prob.
Finished the speedup comparison with and without the associative scan, and adapted the handling of random keys in _sample_identity_diffusion. Should be good to review.
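Here is a minimal harness for this kind of with/without comparison. It assumes the recurrence has the affine form x_t = A_t x_{t-1} + b_t, and the A_t, b_t are arbitrary random stand-ins rather than the diffusion's real transition parameters. The names sequential, parallel and combine and the sizes T and d are all hypothetical. It times a sequential jax.lax.scan against jax.lax.associative_scan and checks that both give the same trajectory.

```python
import time

import jax
import jax.numpy as jnp

def combine(first, second):
    # Compose affine maps: apply `first`, then `second`. Inputs carry a
    # leading batch axis inside associative_scan, hence einsum.
    A1, b1 = first
    A2, b2 = second
    return (jnp.einsum("...ij,...jk->...ik", A2, A1),
            jnp.einsum("...ij,...j->...i", A2, b1) + b2)

@jax.jit
def sequential(As, bs, x0):
    # One matrix-vector product per step, T steps in a row.
    def step(x, Ab):
        A, b = Ab
        x = A @ x + b
        return x, x
    _, xs = jax.lax.scan(step, x0, (As, bs))
    return xs

@jax.jit
def parallel(As, bs, x0):
    # Prefix-compose all maps, then apply each cumulative map to x0.
    A_cum, b_cum = jax.lax.associative_scan(combine, (As, bs))
    return jnp.einsum("tij,j->ti", A_cum, x0) + b_cum

T, d = 1000, 3
kA, kb, kx = jax.random.split(jax.random.PRNGKey(0), 3)
As = 0.3 * jax.random.normal(kA, (T, d, d))  # scaled down so x_t stays bounded
bs = jax.random.normal(kb, (T, d))
x0 = jax.random.normal(kx, (d,))

for name, fn in {"sequential": sequential, "parallel": parallel}.items():
    fn(As, bs, x0).block_until_ready()  # compile before timing
    start = time.perf_counter()
    fn(As, bs, x0).block_until_ready()
    print(name, time.perf_counter() - start)

xs_seq = sequential(As, bs, x0)
xs_par = parallel(As, bs, x0)
print(jnp.allclose(xs_seq, xs_par, atol=1e-4))
```

On the design trade-off: each combine here multiplies two d×d matrices, while the sequential step only needs a matrix-vector product. The parallel version therefore does more total work in exchange for logarithmic depth. That typically pays off only for long sequences on accelerators, so the measured speedup can differ a lot between CPU and GPU.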
Be sure to add the underscore so that sample_identity_diffusion becomes _sample_identity_diffusion, and maybe remove the docstring too.
First attempt at using jax.lax.associative_scan (#14), but it's throwing a matmul contracting-dimensions error and I'm not sure why.
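One common source of this error with jax.lax.associative_scan is worth checking, though it isn't confirmed as the cause here. The combine function is called on stacked elements that carry an extra leading axis, but it may have been written as if it received single elements. The sketch again assumes hypothetical affine-map elements (A, b), and combine_naive and combine_batched are illustrative names. With A2 of shape (k, d, d) and b1 of shape (k, d), `A2 @ b1` treats b1 as a (k, d) matrix and tries to contract d against k.

```python
import jax
import jax.numpy as jnp

def combine_naive(first, second):
    # Written as if it received single elements: A of shape (d, d), b of shape (d,).
    # Inside associative_scan the inputs are stacked, e.g. A2 is (k, d, d) and
    # b1 is (k, d), so `A2 @ b1` tries to contract d against k and fails.
    A1, b1 = first
    A2, b2 = second
    return A2 @ A1, A2 @ b1 + b2

def combine_batched(first, second):
    # Same maths; `@` is fine for the stacked (k, d, d) matrix products, but the
    # matrix-vector product is batched explicitly with einsum.
    A1, b1 = first
    A2, b2 = second
    return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, b1) + b2

T, d = 8, 3
As = jnp.tile(0.5 * jnp.eye(d), (T, 1, 1))  # every step halves the state
bs = jnp.ones((T, d))

try:
    jax.lax.associative_scan(combine_naive, (As, bs))
    naive_ok = True
except Exception as err:  # JAX rejects the mismatched contracting dimensions
    naive_ok = False
    print("combine_naive failed:", err)

A_cum, b_cum = jax.lax.associative_scan(combine_batched, (As, bs))
print(b_cum[-1])  # each entry is 2 - 1/128, i.e. 1.9921875
```

An equivalent fix is to write the combine for single elements and wrap it in jax.vmap before passing it to the scan.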