rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Performance issue when sampling without acceleration #41

Closed rlouf closed 4 years ago

rlouf commented 4 years ago

When we sample with accelerate=False, mcx moves the chains forward in a Python for loop so that it can display a progress bar. Unlike with scan, which returns stacked outputs directly, we then need to stack the state and info values ourselves to create a trace. We do so with the following code:

stack = lambda y, *ys: np.stack((y, *ys))
chain = tree_multimap(stack, *chain)

The sampling phase itself is fairly fast; most of the time is spent stacking the values, which makes the whole process slow. We observe a 12x difference in running time compared with scan.

The issue probably comes from the time np.stack takes to allocate new memory. The first thing to try is probably to flatten the (state, info) tuples, initialize arrays of zeros with num_samples as the first dimension, use ops.index_update to write each sample into its slot, and unpack at the end of sampling.
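A minimal sketch of that idea, for illustration only. It uses plain NumPy with in-place assignment instead of ops.index_update, and toy_step and preallocate are hypothetical names, not mcx functions. With JAX arrays, which are immutable, the write inside the loop would have to be a functional index update instead.

```python
import numpy as np


def preallocate(flat_sample, num_samples):
    """Create one zero-filled buffer per leaf, with num_samples as the first dimension."""
    return [
        np.zeros((num_samples,) + np.shape(leaf), dtype=np.asarray(leaf).dtype)
        for leaf in flat_sample
    ]


def toy_step(state):
    """Hypothetical kernel standing in for one sampler step.

    Returns a new (position, log_prob) state and an (accepted,) info tuple.
    """
    position, _ = state
    new_position = position + 1.0
    new_state = (new_position, -0.5 * np.sum(new_position ** 2))
    info = (True,)
    return new_state, info


num_samples = 5
state = (np.zeros(3), 0.0)

buffers = None
for i in range(num_samples):
    state, info = toy_step(state)
    flat = (*state, *info)  # flatten the (state, info) tuples into one sequence
    if buffers is None:
        # Allocate once, using the first sample to infer shapes and dtypes.
        buffers = preallocate(flat, num_samples)
    for buf, leaf in zip(buffers, flat):
        buf[i] = leaf  # write the sample into its slot, no new allocation

# Unpack the flat buffers back into state and info traces.
num_state = len(state)
state_trace = tuple(buffers[:num_state])
info_trace = tuple(buffers[num_state:])
```

Memory is allocated once per leaf rather than once per stacking call, so the cost of building the trace no longer depends on copying every collected sample at the end.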

rlouf commented 4 years ago

Problem solved, closing.