probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Implemented marginal log likelihood for parallel Kalman #243

Closed matthew9671 closed 1 year ago

matthew9671 commented 1 year ago

Calculation details: https://www.overleaf.com/read/mwsmdyqycynn

michaeltm365 commented 1 year ago

Looks great, thanks! Maybe you could add the math writeup to a gist (or an arXiv note) and add a link to it?

murphyk commented 1 year ago

The comments from "michael" are actually from Kevin; GitHub has messed up our IDs!

AdrienCorenflos commented 1 year ago

Cool trick! I know for a fact this must have been painful to derive XD Did you check if/when it does better than the stupid "just compute it after the fact" method?

matthew9671 commented 1 year ago

> Cool trick! I know for a fact this must have been painful to derive XD Did you check if/when it does better than the stupid "just compute it after the fact" method?

Could you clarify what you mean by that?

I guess one thing to note is that I've found the log normalizer computation to be numerically very sensitive, so it will give different results for mathematically equivalent formulations. I don't have a ton of experience in numerical stuff like this so I think there definitely could be room for improvement/more experimentation.

AdrienCorenflos commented 1 year ago

I typically compute the filtering mean and cov, then compute the observation predictive mean and cov, and compute the log-likelihood elementwise using vmap, like here https://github.com/EEA-sensors/sqrt-parallel-smoothers/blob/16d9bbd5aa021be5f88277d3c87242b2c9a7bd53/parsmooth/parallel/_filtering.py#L57

I was wondering how efficient the "explicit" method was compared to doing this crude thing
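For readers following along, the "after the fact" approach described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR or from the linked repository: the function name `marginal_loglik_after_the_fact` and the argument names are hypothetical, and it assumes a time-invariant linear Gaussian model with dynamics `F`, `Q`, emissions `H`, `R`, and prior `m0`, `P0`, with the filtered means and covariances already computed (for example by a parallel filter).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def marginal_loglik_after_the_fact(m0, P0, F, Q, H, R,
                                   filtered_means, filtered_covs, ys):
    """Hypothetical helper: sum_t log p(y_t | y_{1:t-1}) from filtered moments.

    Assumes filtered_means has shape (T, d), filtered_covs (T, d, d),
    ys (T, k), and a time-invariant linear Gaussian model.
    """
    # One-step-ahead state predictions: the prior at the first step,
    # then F m_{t-1} and F P_{t-1} F^T + Q for the remaining steps.
    pred_means = jnp.concatenate([m0[None], filtered_means[:-1] @ F.T], axis=0)
    pred_covs = jnp.concatenate(
        [P0[None], F @ filtered_covs[:-1] @ F.T + Q], axis=0)

    def one_step(m, P, y):
        # Observation predictive: N(y | H m, H P H^T + R).
        return mvn.logpdf(y, H @ m, H @ P @ H.T + R)

    # Every time step is independent given the predicted moments,
    # so the per-step terms can be evaluated with vmap and then summed.
    per_step = jax.vmap(one_step)(pred_means, pred_covs, ys)
    return per_step.sum()
```

Once the filtered moments are available, nothing in this computation is sequential, which is why it is a natural baseline for the explicit parallel-scan formulation.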

matthew9671 commented 1 year ago

Interesting. Haven't done this comparison in terms of numerical accuracy and computational time. Definitely interesting to try out since this should have the same computational complexity...

slinderman commented 1 year ago

That’s a good point. You can compute log p(y_t | y_{1:t-1}) in parallel for all t. I’m not sure array sums automatically use a parallel reduction on GPUs, but they certainly could. In any case, I’d be surprised if the final summation was the most costly computation.
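For reference, the decomposition being discussed can be written out as below. The notation is an assumption for illustration (a linear Gaussian model with emission matrix $H$, emission noise covariance $R$, and one-step-ahead predicted state mean $m_{t|t-1}$ and covariance $P_{t|t-1}$); it is not taken from the PR's writeup. For $t = 1$ the conditioning set is empty and the predicted moments are the prior's.

```latex
\log p(y_{1:T}) = \sum_{t=1}^{T} \log p(y_t \mid y_{1:t-1}),
\qquad
p(y_t \mid y_{1:t-1}) = \mathcal{N}\!\left(y_t \mid H m_{t|t-1},\; H P_{t|t-1} H^\top + R\right)
```

Each term depends only on the predicted moments at step $t$, so the terms can be evaluated independently and combined with a single sum.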
