When we sample with `accelerate=False`, mcx moves the chains forward in a Python `for` loop so that it can display a progress bar. Unlike with `scan`, we then need to stack the `state` and `info` values to create a trace. We do so with the following code:
The sampling phase itself is fairly fast; most of the time is spent stacking the values, which makes the whole process slow. We observe a 12x difference in running time compared with `scan`.
The issue probably comes from the time taken to allocate new memory with `np.stack`. The first thing to try is probably to flatten the `(state, info)` tuples, initialize an array of zeros with `num_samples` as the first dimension, use `ops.index_update` to write each sample into it, and unpack at the end of sampling.
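A minimal sketch of the preallocation idea, not mcx's actual implementation. It assumes each step returns a `(state, info)` pair whose leaves are arrays of fixed shape; `init_trace`, `write_sample` and `fake_step` are hypothetical names introduced only for illustration. It uses `buf.at[i].set(...)`, which replaces `ops.index_update` in recent JAX releases, and relies on `jax.tree_util.tree_map` instead of flattening by hand, so no unpacking step is needed at the end.

```python
import jax
import jax.numpy as jnp


def init_trace(example, num_samples):
    """Allocate one zero-filled buffer per leaf, with num_samples as the leading axis."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.zeros(
            (num_samples,) + jnp.shape(leaf), dtype=jnp.asarray(leaf).dtype
        ),
        example,
    )


@jax.jit
def write_sample(trace, i, sample):
    """Return a trace in which row i of every buffer holds the matching leaf of sample."""
    return jax.tree_util.tree_map(
        lambda buf, leaf: buf.at[i].set(leaf), trace, sample
    )


def fake_step(i):
    """Hypothetical stand-in for one kernel step returning (state, info)."""
    state = {"position": jnp.full((4, 2), float(i))}  # 4 chains, 2 dimensions
    info = {"is_accepted": jnp.ones((4,), dtype=bool)}
    return state, info


num_samples = 5
trace = init_trace(fake_step(0), num_samples)

# Plain Python loop, so a progress bar could still be updated here.
for i in range(num_samples):
    trace = write_sample(trace, i, fake_step(i))

# The trace keeps the (state, info) structure of a single sample.
states, infos = trace
```

Whether this is actually faster than stacking at the end would need to be benchmarked: without buffer donation, each jitted update may still copy the buffers.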