flowMC assumes the likelihood function has the following signature:

```python
def logp(pts: jnp.array, data: PyTree) -> float:
    ...
```

where `pts` is the parameters you want to sample, and `data` is a pytree containing auxiliary data you don't have to sample over.
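For example, a toy log density with this signature might look like the following (a sketch; the Gaussian is purely illustrative):

```python
import jax.numpy as jnp

def logp(pts: jnp.ndarray, data: dict) -> float:
    # pts: (n_dim,) array of parameters for a single chain.
    # data: auxiliary pytree, here a dict holding observations.
    resid = data["obs"] - pts
    return -0.5 * jnp.sum(resid ** 2)
```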
So in order to get this to work, I think you need to modify your likelihood, defined around

```python
def log_prob(pts):
    print("HEY!", pts)
    return bx_model.log_density(pts)
```

in two ways:
1. If the parameters you want to sample are `avg_effect`, `avg_stddev`, and `school_effects`, which in total are 80 parameters, then your input should be a `(80,)` jax array, and within the likelihood you need to restructure the parameters into a struct tuple in order to pass them to your model. I am not familiar with struct tuples from tfp, so I don't know how to do that, but that should be a good start.
2. `log_prob(init)` returns 8 numbers, which makes me think you actually have fewer parameters than 80. Ideally `bx_model.log_density(pts)` should return a single scalar. flowMC handles the vmapping under the hood, so the likelihood should be a function of a single chain instead of all the chains.

To be more concrete, the init points should have shape `(n_chains, n_dim)`, where `n_chains` is the number of chains and `n_dim` is the dimension of the parameters you want to sample over.
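For instance (a sketch with made-up numbers):

```python
import jax

n_chains, n_dim = 8, 10  # illustrative values only
rng_key = jax.random.PRNGKey(0)
# One row per chain, one column per parameter.
initial_position = jax.random.normal(rng_key, (n_chains, n_dim))
```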
Let me know whether this helps resolve the problem. If it works in the end, would you mind if I link this example on our doc page so others can take a look at it as well?
P.S. Would you mind pointing me to the LearnBayesStat episode?
Thanks for the pointers!
I was able to get my example running -- I had to do some funny things to flatten (and unflatten) my state, which I would guess hurts performance, particularly on accelerators where `reshape` is not free.
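Concretely, the flattening is essentially this, via `jax.flatten_util.ravel_pytree` (a sketch; the pytree here is a stand-in for the actual model state, and `bx_model` is the bayeux model from the colab):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Stand-in structured state: two scalars plus an (8,)-vector.
state = {"avg_effect": jnp.array(0.0),
         "avg_stddev": jnp.array(1.0),
         "school_effects": jnp.zeros(8)}

flat, unflatten = ravel_pytree(state)  # flat has shape (10,)

def log_prob_flat(pts):
    # Rebuild the structure the model expects from the flat vector.
    return bx_model.log_density(unflatten(pts))
```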
Performance is also terrible-ish, which I assume is user error -- here is the example with 12 chains:
I updated the colab with the code that actually runs. I'll keep looking at this tomorrow, but would continue to appreciate suggestions for improving performance if you have any: `bayeux` tries to set generally sensible defaults (which are also surfaced to the user). I would guess that I need to adjust these defaults to get better sampling! https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing
For what it is worth, I included samples from `numpyro`'s NUTS sampler at the bottom -- a well-tuned sampler does 7 HMC steps with step size 0.6.
Update: I ran with HMC (which secretly requires a `condition_matrix` argument -- `np.eye(n_dim)` is a sensible default) using the `numpyro` tuning arguments and added more epochs:
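Roughly what that setup looks like (assuming the `HMC` local sampler API in the flowMC version I'm on -- argument names may differ across releases; `log_density` is the closed-over log density from earlier):

```python
import jax.numpy as jnp
from flowMC.sampler.HMC import HMC

# numpyro's tuned values: 7 leapfrog steps at step size 0.6.
local_sampler = HMC(
    log_density,
    True,  # jit the sampler
    {"step_size": 0.6, "n_leapfrog": 7, "condition_matrix": jnp.eye(n_dim)},
)
```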
Does the fact that I am ignoring the `data` argument matter here? `bayeux` assumes that the user will just close over data, like `log_density = functools.partial(log_prob, data=data)`.
Also, the podcast is here: https://open.spotify.com/episode/1wRsmH8xXTpO8JOWajgWpL?si=df762a337c7d4361 (the website https://learnbayesstats.com/ has not been updated with @marylou-gabrie's episode yet). She did an excellent job describing the algorithm, and I'm hoping `bayeux` will let users of other PPLs give flowMC a try. Also, I personally hate benchmarks for MCMC and appreciated Dr. Gabrié's nuance there! I'll certainly let you know if/when this merges over there!
Ok, got a little nerd-sniped by this, but it looks like setting more local steps and more loops gets reasonable performance:

```python
n_local_steps = 200,
n_global_steps = 50,
n_loop_production = 4,
```

which compares reasonably well with `numpyro`.
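For reference, the full setup with those settings looks roughly like this (based on flowMC's quickstart-style API at the time; the `MALA` step size and flow hyperparameters are placeholders, and `log_density` closes over the data as above):

```python
import jax
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.utils.PRNG_keys import initialize_rng_keys

n_chains, n_dim = 12, 10
rng_key_set = initialize_rng_keys(n_chains, seed=42)

local_sampler = MALA(log_density, True, {"step_size": 0.1})
nf_model = MaskedCouplingRQSpline(n_dim, 4, [64, 64], 8, jax.random.PRNGKey(10))

nf_sampler = Sampler(
    n_dim,
    rng_key_set,
    None,  # data -- already closed over in log_density
    local_sampler,
    nf_model,
    n_local_steps=200,
    n_global_steps=50,
    n_loop_production=4,
    n_chains=n_chains,
)
```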
I'm going off the podcast here, but I suppose this is a problem that is well suited to HMC, rather than some wild statistical-mechanics problem with symmetries to deal with? I'll try to cook up one of those tomorrow, along with a PR to make this a little more ergonomic.
I played with the notebook a bit, and here is an updated version https://colab.research.google.com/gist/kazewong/033c89e548ef59b3ceb649dcf2ffe9e5/bayeux_and_flowmc.ipynb
Here are some of the changes:
* Increased `n_chains` to 24. Just playing around with the number; in general, the more the better, and the runtime usually doesn't change on a GPU until one starts to saturate the compute or memory bandwidth of the GPU.
* Added a quick diagnostic: `print(nf_sampler.get_sampler_state(training=True)['loss_vals'].min(), global_accs.mean())` (spelled out after this list). The local sampler acceptance shows whether the local sampler is reasonable; anywhere between 0.2-0.8 is more or less okay. The global sampler acceptance shows how well the flow has been trained to approximate the target: the higher the better. In the original notebook it was about 0.02; now it is 0.44.
* Increased `n_loop_training` and `n_loop_production`; they basically control how many times the sampler alternates between the local sampler and the global sampler. The larger they are, the longer the runtime, since you are asking for more samples, but that helps with training and produces more samples. I should make this clearer in the tuning guide (writing one soon!)
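Spelling out that diagnostic line from the second bullet (assuming the sampler-state keys used in the notebook; `local_accs`/`global_accs` are the acceptance arrays flowMC tracks):

```python
# Training-phase summary from the flowMC sampler.
train_state = nf_sampler.get_sampler_state(training=True)
print("best training loss:", train_state["loss_vals"].min())
print("local acceptance:", train_state["local_accs"].mean())   # ~0.2-0.8 is okay
print("global acceptance:", train_state["global_accs"].mean()) # higher is better
```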
With all these changes, I think the flowMC result is more reasonable. Now there is one more problem that is actually interesting; I've had it in the back of my mind but never really finished it.

With everything else being the same as shown in the notebook, this is what combining all the chains looks like when I put `n_loop_production=10`:
This is when I use `n_loop_production=30`:
You can see the stripes are more evenly distributed in the larger `n_loop_production` case. The performance of the flow is the same, since the training phase is identical for the two; only the length of the production phase changed.
The reason for this is probably that the samples produced by the local sampler are correlated, while the ones from the global sampler are uncorrelated (or far less correlated). Since `n_local_steps` and `n_global_steps` are the same, the global sampler will fly the chain around for 50 steps per loop, then the local sampler will jitter the chain locally for 50 steps per loop. This probably causes the extra cluster of points around each stripe, and as we increase `n_loop_production`, the global sampler brings the chains to more places, creating more densely sampled stripes.
To solve this, I think there is actual work to do. Basically, we want to keep only effective samples from the local sampler instead of every sample, which should give a much smoother posterior and remove the stripe artifacts.
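Until then, one rough workaround is to thin the production chains after the fact (a sketch, assuming the `chains` key in the production state; the stride here is arbitrary):

```python
# Production chains have shape (n_chains, n_steps, n_dim).
chains = nf_sampler.get_sampler_state(training=False)["chains"]
# Keep every 10th draw to cut local-sampler autocorrelation; ideally
# the stride would come from a measured autocorrelation time.
thinned = chains[:, ::10, :]
```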
Last remark: I agree this is a problem HMC can probably solve quite well. flowMC adds the extra layer of a normalizing flow to deal with bad geometries, such as multimodality or really stretched-out local correlations (like a donut). This problem is rather unimodal and smooth, so HMC shouldn't have a hard time with it.
Please let me know if you have more problems regarding this; I am also happy to help turn this into an example so other users can follow the logic behind this discussion.
Thanks again -- working on the PR now.
It should do even slightly better than above, since bayeux has machinery to transform the support of models to all of R^n -- right now flowMC has no idea `school_effects` must be positive (other than the `nan` log density), while `numpyro` has the advantage of having that parameter transformed by `softplus`.
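(For context, the transformation is the standard change of variables; for a positive parameter it amounts to something like the sketch below, where `constrained_log_density` is a stand-in, not bayeux's actual internals.)

```python
import jax.numpy as jnp

def unconstrained_log_density(z):
    # Map z in R to a positive value.
    x = jnp.exp(z)
    # Add the log-Jacobian of the transform: log|dx/dz| = z.
    return constrained_log_density(x) + z
```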
https://github.com/jax-ml/bayeux/pull/23 is out now -- if you have time for comments/suggestions, please do! There's a fair amount of abstraction going on, and it may be easier to play around with it after it merges, then open one or more issues! I'll follow up with a colab using `bayeux`.
A few notes from doing this -- lmk if you'd like these to be separate issues:

* Requiring a `params` or `kwargs` argument makes it difficult for static checkers to make sure I am using the right arguments. I sort of assume that if `flowMC` gets updated, `bayeux` will break -- it seems like you could expand the signature of the nf_models and local_samplers?
* Related, a few of these keyword arguments are similar, but slightly different (`n_layer` vs `n_layers`, and `hidden_size` vs `n_hidden`).
* The random_key_set seems like it is probably an anti-pattern -- in particular, I would like to pass in a jax PRNG key and have everything "just work" (or even pass such a key to the helper function instead of an int). It seems like maybe `Sampler` has enough information to handle the key splitting itself? (See the sketch at the end of this comment.)

https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing is updated with `flowMC`, including a multimodal distribution towards the end where it seems to do better than numpyro (for some value of "better"!).
I'll probably add an example notebook based on that soon. First I have to fix the bug that doesn't allow setting keyword arguments!
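On the key-splitting note above, here is the kind of thing I mean -- a hypothetical helper, not flowMC's current API:

```python
import jax

def split_keys(rng_key, n_chains):
    # Derive per-chain init keys plus training/sampling keys
    # from a single user-supplied PRNG key.
    init_key, train_key, sample_key = jax.random.split(rng_key, 3)
    chain_keys = jax.random.split(init_key, n_chains)
    return chain_keys, train_key, sample_key

keys = split_keys(jax.random.PRNGKey(0), n_chains=24)
```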
Re: `params` in the different local_samplers? Currently that is used to maintain a somewhat more unified API across the different local_samplers. When I was making this code a year ago I wasn't paying much attention to typing, so it is rather unsatisfactory as it currently stands. It seems a bit tricky to me how to handle different params with a static checker while providing a unified API to flowMC, since different local samplers might have different numbers of params, coming in different shapes and types. There might be a solution lying in some examples of equinox; I will see what I can do about this. The minimum will probably be a run-time check during initialization.
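For instance, something like this minimal run-time check (illustrative only; the required-parameter sets here are placeholders, not flowMC's actual ones):

```python
# Placeholder required-parameter sets per local sampler.
REQUIRED_PARAMS = {
    "MALA": {"step_size"},
    "HMC": {"step_size", "n_leapfrog", "condition_matrix"},
}

def check_params(sampler_name: str, params: dict) -> None:
    # Fail fast at initialization rather than deep inside a jitted kernel.
    missing = REQUIRED_PARAMS[sampler_name] - params.keys()
    if missing:
        raise ValueError(f"{sampler_name} missing params: {sorted(missing)}")
```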
Hey! I heard about this library on the LearnBayesStat podcast, and was trying to integrate it with https://jax-ml.github.io/bayeux/. It seems like it should be easy, since both work with just a log density and an initial point. However, I am getting a somewhat cryptic error from a `jnp.stack` inside `flowMC`. I would guess it has to do with the fact that the initial state has shape [(), (), (8,)], and so 8 chains have shape [(8,), (8,), (8, 8)], and there is some problem because the last entry has more dimensions than the other two.
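Here is a minimal reproduction of what I think is going on (my guess at the failure mode, not the exact flowMC call):

```python
import jax.numpy as jnp

# 8 chains of a state with component shapes [(), (), (8,)] become
# arrays of shape (8,), (8,), and (8, 8) -- which cannot be stacked.
a = jnp.zeros((8,))
b = jnp.zeros((8,))
c = jnp.zeros((8, 8))
jnp.stack([a, b, c])  # ValueError: all input arrays must have the same shape
```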
Here's the colab I've been trying to get this to work in: https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing
Any help would be much appreciated!