probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

dynamax is weirdly slow in my HMM benchmark #359

Status: Open. gdalle opened this issue 4 months ago.

gdalle commented 4 months ago

Hi, and congrats on the amazing package!

I have developed an HMM library in Julia called HiddenMarkovModels.jl, and I am currently benchmarking it against the Python alternatives (see here for the feature comparison). I want the benchmark to be fair, but I'm a JAX newbie, so could someone point out possible inefficiencies in my dynamax code?

My test case is an HMM with scalar Gaussian emissions and 100 sequences of length 200 each. I'm interested in small-ish models, which is also why I run everything on the CPU. When I time the forward, forward-backward and Viterbi algorithms, dynamax is among the fastest packages. However, I observe a significant slowdown in the EM (Baum-Welch) algorithm, so perhaps something is wrong there (see plots below). The three inference algorithms have been jitted and vmapped over the multiple sequences, but I don't know whether I can do the same with EM learning. Any suggestions are welcome!
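On the question of whether EM can be compiled like the inference routines: the sketch below is a from-scratch, pure-JAX Baum-Welch for an HMM with scalar Gaussian emissions, written only for illustration. It is not dynamax's implementation, and every name in it (`e_step`, `m_step`, `em_step`, `fit_em`, the toy data) is hypothetical. It assumes the sequences are stacked into one array of shape `(num_sequences, T)`, vmaps a scaled forward-backward E-step over sequences, sums the sufficient statistics, and runs the whole iteration loop inside `jax.jit` via `jax.lax.scan`, so compilation happens once per input shape.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def e_step(params, y):
    """Scaled forward-backward for one scalar sequence y of shape (T,)."""
    pi, A, mu, sigma = params
    logb = norm.logpdf(y[:, None], mu[None, :], sigma[None, :])  # (T, K) log-likelihoods
    m = logb.max(axis=1)
    b = jnp.exp(logb - m[:, None])  # rescale each step to avoid underflow

    def fwd(alpha_prev, b_t):
        alpha = (alpha_prev @ A) * b_t
        c = alpha.sum()
        return alpha / c, (alpha / c, c)

    a0 = pi * b[0]
    c0 = a0.sum()
    _, (alphas, cs) = jax.lax.scan(fwd, a0 / c0, b[1:])
    alphas = jnp.concatenate([(a0 / c0)[None], alphas])  # normalized forward messages
    cs = jnp.concatenate([c0[None], cs])  # per-step normalizers
    loglik = jnp.sum(jnp.log(cs) + m)  # add back the rescaling offsets

    def bwd(beta_next, inputs):
        b_next, c_next = inputs
        beta = A @ (b_next * beta_next) / c_next
        return beta, beta

    ones = jnp.ones_like(pi)
    _, betas = jax.lax.scan(bwd, ones, (b[1:], cs[1:]), reverse=True)
    betas = jnp.concatenate([betas, ones[None]])

    gamma = alphas * betas  # posterior state marginals, shape (T, K)
    w = b[1:] * betas[1:] / cs[1:, None]
    xi = jnp.einsum("ti,ij,tj->ij", alphas[:-1], A, w)  # expected transition counts
    stats = dict(init=gamma[0], trans=xi, n=gamma.sum(0),
                 sy=gamma.T @ y, syy=gamma.T @ y**2)
    return stats, loglik


def m_step(stats):
    """Closed-form maximization step from summed sufficient statistics."""
    pi = stats["init"] / stats["init"].sum()
    A = stats["trans"] / stats["trans"].sum(axis=1, keepdims=True)
    mu = stats["sy"] / stats["n"]
    var = stats["syy"] / stats["n"] - mu**2
    return (pi, A, mu, jnp.sqrt(jnp.maximum(var, 1e-6)))


def em_step(params, Y):
    # E-step vmapped over sequences (Y has shape (num_sequences, T)), stats summed
    stats, lls = jax.vmap(e_step, in_axes=(None, 0))(params, Y)
    stats = jax.tree_util.tree_map(lambda s: s.sum(axis=0), stats)
    return m_step(stats), lls.sum()


@partial(jax.jit, static_argnames="num_iters")
def fit_em(params, Y, num_iters):
    # The whole EM loop compiles once; lls[i] is the log-likelihood before update i
    return jax.lax.scan(lambda p, _: em_step(p, Y), params, None, length=num_iters)


# Hypothetical toy data: 100 sequences of length 200 with two well-separated levels
key_z, key_y = jax.random.split(jax.random.PRNGKey(0))
levels = 3.0 * jax.random.bernoulli(key_z, 0.5, (100, 200)).astype(jnp.float32)
Y = levels + jax.random.normal(key_y, (100, 200))

K = 2
params0 = (jnp.ones(K) / K, 0.8 * jnp.eye(K) + 0.2 / K, jnp.array([-1.0, 1.0]), jnp.ones(K))
params, lls = fit_em(params0, Y, num_iters=20)
lls.block_until_ready()
```

Because JAX dispatches asynchronously, a fair timing would call the jitted function once to trigger compilation, then time a second call that ends in `.block_until_ready()`. For dynamax itself, `fit_em` is documented to accept a batch of emissions of shape `(num_sequences, num_timesteps, emission_dim)` and has a `verbose` flag for its progress bar; whether its internally jitted EM step is reused across repeated calls may depend on the version, so it is worth checking whether the measured time includes compilation.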

[Figure: forward algorithm benchmark]

[Figure: Baum-Welch (EM) benchmark]

The benchmark is run from Julia with PythonCall.jl, so don't freak out at the weird syntax. Here are the main bits: