Jacob-Stevens-Haas closed this 2 months ago.
I don't think there's much you can do, since presumably much of that time is spent in JAX compilation.
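A quick way to check how much of a run is compilation is to compare two calls of a jitted function with the same input shapes. The first call includes tracing and XLA compilation. The second call reuses the compiled executable. This is a minimal sketch: `log_density` is a hypothetical stand-in for a model's log density and is not part of numpyro.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def log_density(beta, X, y):
    # Hypothetical stand-in for a model's log density:
    # Gaussian likelihood with unit noise for a linear regression.
    resid = y - X @ beta
    return -0.5 * jnp.sum(resid ** 2)


def timed_call(*args):
    # block_until_ready() waits for the asynchronous computation to finish,
    # so the measured time covers the actual work.
    start = time.perf_counter()
    out = log_density(*args).block_until_ready()
    return out, time.perf_counter() - start


X = jnp.ones((10, 2))
y = jnp.zeros(10)
beta = jnp.zeros(2)

value, first = timed_call(beta, X, y)  # includes tracing + XLA compilation
_, second = timed_call(beta, X, y)     # same shapes: reuses the compiled executable
print(f"first call: {first:.4f}s, second call: {second:.4f}s")
```

For a tiny model, the gap between the two timings is roughly the fixed per-process compilation overhead. Sampler settings cannot remove that overhead.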
Closed because this is more specific to JAX. Please check out https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html. As far as I know, it supports TPU and GPU for now and will support CPU in the near future.
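The cache linked above is enabled through a few config updates. This is a minimal sketch, assuming a recent JAX release where these option names exist. Older releases used a different API, which is now deprecated.

The cache directory is an arbitrary example. Placing these lines somewhere that runs before any compilation, such as a test suite's `conftest.py`, is an assumption about the project layout. The cache only helps on later runs: the first run in a fresh cache still compiles everything.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Example cache location; any writable directory should work.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "jax_cache")

jax.config.update("jax_compilation_cache_dir", CACHE_DIR)
# Cache every compiled program regardless of size or compile time,
# so the small programs used in unit tests are cached too
# (the defaults skip programs that compile quickly).
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def double(x):
    # Compiled on first call. With the cache enabled, a later run in a fresh
    # process can load the compiled executable from CACHE_DIR instead.
    return 2 * x


print(double(jnp.arange(3.0)))
```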
Thanks! I'll check it out
In a package that builds regression problems with a variety of solvers, we recently added numpyro and a regularized horseshoe prior. However, these tests take longer than those for any other method, and I'm trying to speed them up.

Even a small test (`num_warmup=1`, `num_samples=4`, data shape `(10, 2)`) takes around 5 seconds. That is substantially longer than our tests for discrete optimizations with gurobipy. Fixing the geometry would probably help in this specific case.

In general, though, are there any kwargs that will let NUTS/MCMC find a very quick (and bad) solution? I mean something like setting `max_iter=1` in a gradient descent algorithm. I tried setting `max_tree_depth` to 1 and `target_accept_prob` to 0.1, but that didn't change the timing appreciably. This would be useful in tests like "make sure that we can pickle our models after fitting" and "make sure that a fitted model has certain attributes".

In profiling, the test spends about 41% of its time initializing: 10% in `hmc:186` (initializing the kernel) and 31% in `util:303` (`find_valid_initial_params`). Another 44% goes to `util:266` (`fori_collect`), which I assume is the actual sampling.
Totally understand if it's not available, though! It's a great package as is.
Example code: