rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Fix performance issue when sampling with `for loop` (#41) #42

Closed rlouf closed 4 years ago

rlouf commented 4 years ago

While sampling with scan is extremely performant, sampling with a for loop (to display a progress bar) is slower than expected. Building the chain is fairly fast, but MCX spends a considerable amount of time converting the chain from the for-loop format into the scan format, which is more appropriate for building the trace.
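
To make the conversion cost concrete, here is a hypothetical sketch (in plain NumPy, with made-up state shapes) of the two chain layouts involved: a for loop yields one `(state, info)` pair per iteration, whereas scan returns arrays whose leading axis is the iteration index, which is the layout the trace builder expects.

```python
import numpy as np

# A for loop accumulates one (state, info) tuple per step:
for_loop_chain = [
    (np.array([0.1, 0.2]), {"accepted": True}),
    (np.array([0.3, 0.4]), {"accepted": False}),
]

# scan instead returns stacked arrays, leading axis = iteration:
scan_chain = (
    np.array([[0.1, 0.2], [0.3, 0.4]]),
    {"accepted": np.array([True, False])},
)
```

Turning the first layout into the second is the post-sampling conversion the issue is about.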

Presently the chain is a list of states `[(state, info)]` and we use multimap and `np.stack` to stack the chain. This is relatively slow, possibly due to memory allocation. However, since we know before sampling how much memory needs to be allocated, we can pre-allocate an array of zeros and fill its rows as we loop. This way the progress bar will also be a better indicator of the "real" sampling time.
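
A minimal sketch of the two strategies, assuming a stand-in kernel that returns a flat array per step (the names and shapes here are illustrative, not MCX's actual API):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_dim = 1000, 3

def kernel_step():
    # Stand-in for one MCMC transition returning a flat state.
    return rng.normal(size=n_dim)

# Current approach: append each state to a list, stack at the end.
# The final stacking pass (plus the multimap over structured states)
# is paid after sampling, so the progress bar under-reports total time.
chain = [kernel_step() for _ in range(n_samples)]
stacked = np.stack(chain)

# Proposed approach: the chain length is known up front, so allocate
# the output once and fill it row by row; each iteration then carries
# its full cost and the progress bar reflects the real sampling time.
samples = np.zeros((n_samples, n_dim))
for i in range(n_samples):
    samples[i] = kernel_step()
```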

Since scan's output is the reference format, its internals are worth investigating.

This PR addresses issue #41.

rlouf commented 4 years ago

Note for future self: the slowdown is due more to multimap than to `np.stack`; stacking ravelled states is much faster. The idea now is to ravel the states and append them to a list, stack the ravelled states, then unravel in a vmap. If this observation is confirmed, our choice of dealing with flat arrays in the core may have compounding payoffs, even though it only removes a few multimaps. We should definitely add a notebook in `design_notes` that compares both methods for future reference.
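
rlouf commented 4 years ago

The ravel-stack-unravel idea can be sketched as follows, here in plain NumPy with a hypothetical two-field state (in the JAX version the final unravelling of the stacked array would be a single `vmap` over the unravel function, e.g. one built with `jax.flatten_util.ravel_pytree`):

```python
import numpy as np

n_dim, n_samples = 2, 5
rng = np.random.default_rng(0)

def ravel(state):
    # Flatten a hypothetical {"position", "momentum"} state into 1-D.
    return np.concatenate([state["position"].ravel(),
                           state["momentum"].ravel()])

def unravel_chain(flat_chain):
    # Inverse mapping applied to the whole stacked chain at once;
    # in JAX this would be vmap(unravel) over the leading axis.
    return {"position": flat_chain[:, :n_dim],
            "momentum": flat_chain[:, n_dim:]}

states = [{"position": rng.normal(size=n_dim),
           "momentum": rng.normal(size=n_dim)}
          for _ in range(n_samples)]

# Ravel each state as it is produced, stack the flat arrays (cheap,
# no per-leaf multimap), then unravel once at the end.
flat_chain = np.stack([ravel(s) for s in states])
trace = unravel_chain(flat_chain)
```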