Closed · rlouf closed 4 years ago
Note for future self: the slowdown is due more to `multimap` than to `np.stack`; stacking ravelled states is much faster. The idea is now to ravel the states and append them to a list, stack the ravelled states, then unravel them in a `vmap`. If that observation is confirmed, our choice of dealing with flat arrays in the core may have compounding payoffs, even though it only removes a few `multimap`s. We should definitely add a notebook in `design_notes` that compares both methods for future reference.
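A minimal sketch of the ravel/stack/unravel idea, using `jax.flatten_util.ravel_pytree`; the state pytrees below are hypothetical stand-ins for MCX's actual state objects:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical states: pytrees with differently shaped leaves
states = [
    {"position": jnp.array([float(i), float(i) + 0.5]), "log_prob": jnp.array(float(i))}
    for i in range(5)
]

# Ravel each state into a flat vector as it is produced; keep the
# unravel closure, which restores the original pytree structure
_, unravel = ravel_pytree(states[0])
flat_chain = [ravel_pytree(state)[0] for state in states]

# Stack the flat vectors, then restore the pytree structure for the
# whole chain in a single vmap
stacked = jnp.stack(flat_chain)      # shape (n_samples, n_dims)
chain = jax.vmap(unravel)(stacked)   # each leaf gains a leading sample axis
```

This avoids calling `multimap` per leaf: the only stacking happens once, on flat arrays, and the tree structure is recovered in one vectorized pass.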
While sampling with `scan` is extremely performant, sampling with a for loop (to display a progress bar) is slower than expected. Building the chain is fairly fast, but MCX spends a considerable amount of time converting the chain from the for-loop format into the `scan` format, which is more appropriate to build the trace.

Presently the chain is a list of states `[(state, info)]` and we use `multimap` and `np.stack`
to stack the chain. This is relatively slow, possibly due to memory allocation. However, since we do know how much memory needs to be allocated before sampling, we can pre-initialize an array with zeros and fill its rows as we loop. This way the progress bar will also be a better indicator of the "real" sampling time.

Since the reference is `scan`'s output, it is worth investigating its internals.

This PR addresses issue #41.
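The pre-allocation idea can be sketched as follows; the sampler step here is a hypothetical stand-in for an MCX kernel, and plain NumPy is used for the host-side buffer since it supports in-place writes:

```python
import numpy as np

n_samples, n_dims = 1_000, 3
rng = np.random.default_rng(0)

def sampler_step(state):
    # Hypothetical stand-in for one transition of an MCX kernel
    return state + rng.normal(size=n_dims)

# Pre-allocate the whole chain: the memory needed is known before sampling
chain = np.zeros((n_samples, n_dims))

state = np.zeros(n_dims)
for i in range(n_samples):  # a progress bar (e.g. tqdm) would wrap this loop
    state = sampler_step(state)
    chain[i] = state  # fill the pre-allocated row instead of appending to a list
```

Since the buffer already has the `(n_samples, n_dims)` layout that `scan` would produce, no post-hoc stacking is needed, and each progress-bar tick corresponds to real sampling work rather than deferred conversion.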