blackjax-devs / sampling-book

Tutorials and sampling algorithm comparisons
https://blackjax-devs.github.io/sampling-book

extend_params fix after n_particles removed #62

Open andrewdipper opened 4 months ago

andrewdipper commented 4 months ago

Fixes the examples following the change to extend_params in https://github.com/blackjax-devs/blackjax/pull/694.
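To make the n_particles removal concrete, here is a NumPy sketch that contrasts two ways of preparing per-particle kernel parameters. Both helpers (tile_params and extend_params_sketch) are hypothetical and are not blackjax's API. The sketch also assumes something this PR does not state: that the updated extend_params no longer tiles each parameter n_particles times, and instead adds a size-1 leading axis that broadcasts across particles.

```python
import numpy as np

# Hypothetical helpers, NOT blackjax's API: they only contrast two ways of
# preparing parameters that a per-particle kernel can consume.

def tile_params(n_particles, params):
    """Old-style sketch: copy every parameter once per particle."""
    return {k: np.repeat(np.asarray(v)[None, ...], n_particles, axis=0)
            for k, v in params.items()}

def extend_params_sketch(params):
    """New-style sketch (assumed behaviour): add a size-1 leading axis and
    rely on broadcasting, so the particle count is no longer needed."""
    return {k: np.asarray(v)[None, ...] for k, v in params.items()}

params = {"step_size": 1e-2, "inverse_mass_matrix": np.ones(3)}
old = tile_params(100, params)
new = extend_params_sketch(params)

print(old["inverse_mass_matrix"].shape)  # (100, 3)
print(new["inverse_mass_matrix"].shape)  # (1, 3)
```

Broadcasting the (1, 3) array to (100, 3) reproduces the tiled version, which is why the particle count can be dropped from the helper's arguments in this sketch.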

Additionally, in the TemperedSMC example I changed max_num_doublings from the default of 10 to 6, because the small step size (which I believe is chosen for illustrative purposes) means NUTS regularly hits the max_num_doublings limit. On a GPU the example is extremely slow without this change, and it still takes ~2 min with it. That seems far too slow, but I haven't been able to find an explanation.
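The cap on doublings bounds the work per trajectory. Assuming the standard NUTS doubling scheme, where each doubling adds as many leapfrog steps as the trajectory already contains, d doublings allow at most 2**d - 1 integration steps. The plain-Python sketch below compares that cap with the CPU NUTS step counts listed under "For reference"; reading those counts as integration steps per trajectory is an assumption.

```python
def max_integration_steps(max_num_doublings: int) -> int:
    # Doubling k (k = 0, 1, ...) adds 2**k new leapfrog steps, so d doublings
    # give 1 + 2 + ... + 2**(d - 1) = 2**d - 1 steps in total.
    return 2 ** max_num_doublings - 1

# NUTS step counts from the CPU runs (max_num_doublings=10), keyed by step size.
cpu_nuts_steps = {1e-2: 30, 1e-3: 273, 1e-4: 926}

cap_default = max_integration_steps(10)  # 1023
cap_reduced = max_integration_steps(6)   # 63

for step_size, steps in cpu_nuts_steps.items():
    print(f"step_size={step_size:g}: {steps} steps, "
          f"{steps / cap_default:.0%} of the default cap ({cap_default})")
print(f"With max_num_doublings=6, no trajectory can exceed {cap_reduced} steps.")
```

At step_size=1e-4 the CPU runs already use about 90% of the 1023-step cap, so lowering max_num_doublings to 6 bounds the per-trajectory work at 63 steps.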

For reference (each entry is steps / wall-clock time):

CPU, 10,000 samples, max_num_doublings=10

step_size   HMC            NUTS
1e-2        50 / 1.14 s    30 / 0.964 s
1e-3        50 / 1.14 s    273 / 1.9 s
1e-4        50 / 1.18 s    926 / 4.23 s

GPU, 1,000 samples (10x fewer than on CPU), max_num_doublings=10

step_size   HMC            NUTS
1e-2        50 / 3.31 s    30 / 7.3 s
1e-3        50 / 3.32 s    267 / 63 s
1e-4        50 / 3.31 s    926.4 / 237 s
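Normalizing the measurements above by the number of reported steps makes the GPU slowdown easier to see. This plain-Python sketch only does arithmetic on the reported numbers; it assumes each time covers the reported number of steps and makes no claim about the cause.

```python
# Wall-clock timings from the measurements above:
# (device, sampler, step_size) -> (steps, seconds).
timings = {
    ("CPU", "HMC", 1e-2): (50, 1.14),
    ("CPU", "NUTS", 1e-2): (30, 0.964),
    ("CPU", "HMC", 1e-3): (50, 1.14),
    ("CPU", "NUTS", 1e-3): (273, 1.9),
    ("CPU", "HMC", 1e-4): (50, 1.18),
    ("CPU", "NUTS", 1e-4): (926, 4.23),
    ("GPU", "HMC", 1e-2): (50, 3.31),
    ("GPU", "NUTS", 1e-2): (30, 7.3),
    ("GPU", "HMC", 1e-3): (50, 3.32),
    ("GPU", "NUTS", 1e-3): (267, 63.0),
    ("GPU", "HMC", 1e-4): (50, 3.31),
    ("GPU", "NUTS", 1e-4): (926.4, 237.0),
}

# Seconds of wall-clock time per reported step, for each configuration.
per_step = {key: seconds / steps for key, (steps, seconds) in timings.items()}

for (device, sampler, step_size), s in per_step.items():
    print(f"{device} {sampler:4} step_size={step_size:g}: {1e3 * s:6.1f} ms/step")
```

On the GPU a NUTS step costs roughly 0.24 to 0.26 s at every step size, about 3.5 to 3.9 times a GPU HMC step. On the CPU, the per-step cost of NUTS instead falls as trajectories get longer (from about 32 ms to under 5 ms). Because the GPU runs used 10x fewer samples, the per-sample gap between the devices is larger still.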