rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Problem running main example #19

Closed ericmjl closed 4 years ago

ericmjl commented 4 years ago

Hi @rlouf!

I tried out the example in the README and encountered some issues.

Apart from some syntax errors (which were not too difficult to fix), I ran into problems with sampling.

Here's a link to my notebook: https://gist.github.com/ericmjl/f80c76bd13c65b903078881e9485c7c4

What baffled me was that I expected the sampler to yield a fuzzy caterpillar trace, but instead it yielded a single number sampled over and over.

rlouf commented 4 years ago

Hi @ericmjl,

Thank you for raising an issue and being adventurous enough to test drive before the initial release!

First, did you run this on GPU? The number of samples/s is high, and if you ran it on CPU this may be an issue with how I generate random numbers (although it has always worked on my end).

Also, the tqdm glitch makes me think that I haven't pushed the latest changes to master yet.

I will look into it as soon as possible and get back to you. Once this is fixed, I'd be happy to have your feedback on how the API feels.

ericmjl commented 4 years ago

@rlouf thanks for getting back!

I ran this on CPU, so random number generation might indeed be the issue. I'm sure you know this, but just checking: you do remember to split the PRNGKey, right?
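For context on that question, here is a minimal sketch using plain jax.random (not mcx's internals). JAX's RNG is stateless, so reusing one PRNGKey returns exactly the same draw every time. That is one way a sampler could produce the constant trace described in the issue. Splitting the key before each draw avoids it.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Bug pattern: the same key is reused at every step, so every "random"
# momentum draw is identical and the chain cannot explore.
reused = jnp.stack([jax.random.normal(key, (2,)) for _ in range(5)])

# Correct pattern: split the key at every step, consume only the subkey,
# and carry the other half forward to the next iteration.
draws = []
for _ in range(5):
    key, subkey = jax.random.split(key)
    draws.append(jax.random.normal(subkey, (2,)))
split = jnp.stack(draws)

print(bool(jnp.all(reused == reused[0])))  # True: every row is identical
print(bool(jnp.all(split == split[0])))    # False: rows differ
```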

rlouf commented 4 years ago

Yes, I do split the keys. I quickly double-checked the code on master and I am not sure where the issue could be; I suspect the sampling.py file. A quick way to check would be to look at the proposed_state field in the HMCInfo struct returned by the kernel (not exposed publicly). If the proposed states are always the same, it’s indeed an issue with the RNG.
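The check suggested above fits in a small helper. This sketch assumes you have already collected the proposed positions into a NumPy array of shape (num_steps, dim). How to pull them out of HMCInfo is not shown, and proposals_are_constant is a hypothetical name, not an mcx function.

```python
import numpy as np

def proposals_are_constant(proposed_positions, atol=0.0):
    """Return True if every proposal equals the first one (within atol).

    Assumes `proposed_positions` has shape (num_steps, dim) and was
    collected from the sampler's per-step info (hypothetical helper).
    """
    proposed = np.asarray(proposed_positions, dtype=float)
    return bool(np.all(np.abs(proposed - proposed[0]) <= atol))

# Synthetic examples: varied proposals vs. the same proposal repeated.
healthy = np.random.default_rng(0).normal(size=(100, 2))
stuck = np.tile([0.3, -1.2], (100, 1))

print(proposals_are_constant(healthy))  # False
print(proposals_are_constant(stuck))    # True
```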

Otherwise, it may also be an issue with the number of integration steps or the step size being set too high. HMCInfo also contains an is_divergent field (exposed in trace['info']['is_divergent']). I’m finishing the implementation of the Stan warmup at the moment, so if that’s the case it should be less of a problem later.
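The divergence flags can be summarized per chain. This sketch assumes trace['info']['is_divergent'] holds a boolean array of shape (num_chains, num_samples); that shape is an assumption, and divergence_rate is a hypothetical helper.

```python
import numpy as np

def divergence_rate(is_divergent):
    """Fraction of divergent transitions per chain.

    Assumes a boolean array of shape (num_chains, num_samples), such as
    the values stored under trace['info']['is_divergent'].
    """
    return np.asarray(is_divergent, dtype=float).mean(axis=-1)

# Synthetic example: chain 1 diverges on every transition.
is_divergent = np.array([
    [False, False, True, False],
    [True, True, True, True],
])
print(divergence_rate(is_divergent))  # 0.25 for chain 0, 1.0 for chain 1
```

A chain whose rate is close to 1 is almost never moving, which matches the flat traces discussed in this thread.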

I’ll dive deeper later today or tomorrow and let you know.

rlouf commented 4 years ago

@ericmjl I found the source of the bug, and it is dumber than I thought. For some reason I removed the minus sign in front of the log-likelihood when building the potential. The rationale was probably that this is HMC-specific (the potential is the negative log-likelihood) and shouldn't live in sampling.py. Time to add some tests for this :)
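Here is the sign convention in question, sketched in NumPy for a standard normal target (illustrative, not mcx's sampling.py). The potential energy is the negative log-density, so its gradient points away from high-probability regions. The momentum update p ← p - (ε/2)∇U(x) therefore pulls the trajectory back toward them. Dropping the minus sign flips that update.

```python
import numpy as np

def log_density(x):
    # Unnormalized log-density of a standard normal.
    return -0.5 * np.sum(x ** 2)

def potential(x):
    # HMC's potential energy is the *negative* log-density.
    return -log_density(x)

def grad_potential(x):
    # Analytic gradient of U(x) = 0.5 * ||x||^2.
    return x

x = np.array([2.0, -1.0])
step_size = 0.1
p = np.zeros_like(x)

# Leapfrog momentum half-step with the correct potential.
p_correct = p - 0.5 * step_size * grad_potential(x)
# Forgetting the minus sign amounts to using U(x) = log p(x) instead.
p_wrong = p - 0.5 * step_size * (-grad_potential(x))

print(p_correct)  # points back toward the mode at 0
print(p_wrong)    # points away from the mode
```

A regression test can assert exactly this: the potential is lowest at the mode, and the first momentum update points toward it.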

Anyway, I fixed it. I'll do some cleanup and push the changes later. The number of samples/s shouldn't change after the fix; in fact, on my machine it is still very large.

rlouf commented 4 years ago

Fixed in the last commit. As you will see, it is still as fast. Thank you for raising the issue!

ericmjl commented 4 years ago

Thanks @rlouf! That was speedy.

I still think there might be something I'm missing here, though. In an updated version of the notebook with new visual diagnostics, I noticed that only one of the four MCMC chains sampled near the "true" value; the other three were flat. I am quite confident that I updated to the latest version of master correctly. Can you reproduce this with the notebook on GitHub Gist?
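Chains that stay flat and disagree with each other are what the Gelman-Rubin diagnostic (R-hat) is designed to catch. The sketch below is a minimal NumPy version of the classic, non-split R-hat for a single scalar parameter. Libraries such as ArviZ use a split, rank-normalized variant, so treat this one as illustrative. The example chains are synthetic.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) R-hat for one scalar parameter.

    chains: array of shape (num_chains, num_samples).
    Values close to 1 suggest the chains agree; values well above 1
    suggest they explore different regions or are stuck.
    """
    chains = np.asarray(chains, dtype=float)
    _, num_samples = chains.shape
    within = chains.var(axis=1, ddof=1).mean()               # W
    between = num_samples * chains.mean(axis=1).var(ddof=1)  # B
    pooled = (num_samples - 1) / num_samples * within + between / num_samples
    return float(np.sqrt(pooled / within))

rng = np.random.default_rng(0)

# Four chains that all explore the same distribution.
mixed = rng.normal(loc=1.0, scale=0.1, size=(4, 1_000))

# One chain moving around the "true" value, three flat chains elsewhere.
stuck = np.vstack([
    rng.normal(loc=1.0, scale=0.1, size=1_000),
    np.full(1_000, -0.4),
    np.full(1_000, 2.0),
    np.full(1_000, 0.7),
])

print(gelman_rubin(mixed))  # close to 1
print(gelman_rubin(stuck))  # far above 1
```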

rlouf commented 4 years ago

I’ll check later today. It might have something to do with the HMC parameters, namely the step size and the number of steps being too large; I did change them while I was debugging.

Also, I don’t know if you saw this (it’s not documented yet), but calling “sample.run()” again continues the chain from where it previously stopped, so you don’t need to start from scratch if you want more samples from the chain.

Note that the iterative sampler “generate” is a good debugging tool, as it returns more information than “sample”.
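A toy random-walk Metropolis sampler on a standard normal can illustrate the two behaviours described above: a run that resumes where it stopped, and a step-by-step generator that exposes more information. This is a hypothetical NumPy sketch. RandomWalkSampler and its methods are illustrative names that only mimic the described behaviour and are not mcx's API.

```python
import numpy as np

class RandomWalkSampler:
    """Toy random-walk Metropolis sampler for a 1-D standard normal.

    Hypothetical illustration only: mimics a resumable `run` and a
    verbose `generate`, but is not mcx's API.
    """

    def __init__(self, seed=0, step=1.0, init=0.0):
        self.rng = np.random.default_rng(seed)
        self.step = step
        self.position = init  # chain state persists between calls

    @staticmethod
    def _log_density(x):
        return -0.5 * x ** 2

    def _one_step(self):
        proposal = self.position + self.step * self.rng.normal()
        log_ratio = self._log_density(proposal) - self._log_density(self.position)
        accepted = np.log(self.rng.uniform()) < log_ratio
        if accepted:
            self.position = proposal
        return {"position": self.position, "proposal": proposal,
                "accepted": bool(accepted)}

    def generate(self, num_steps):
        # Yield the full per-step info, which is handy for debugging.
        for _ in range(num_steps):
            yield self._one_step()

    def run(self, num_samples):
        # Return only positions, starting from wherever the last call stopped.
        return np.array([info["position"] for info in self.generate(num_samples)])

sampler = RandomWalkSampler(seed=0)
first = sampler.run(1_000)
more = sampler.run(1_000)  # continues the same chain

for info in sampler.generate(3):
    print(info)
```

Because the sampler keeps both its position and its random generator between calls, two runs of 1,000 samples produce exactly the same chain as a single run of 2,000.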

rlouf commented 4 years ago

There was a small glitch in the core (too long to explain), which was an easy fix, but there was also an issue with the parameters of the HMC kernel. The following kernel should work:

import jax.numpy as np  # assumed: np is jax.numpy here
import mcx

kernel = mcx.HMC(
    step_size=0.00001,
    num_integration_steps=90,
    mass_matrix_sqrt=np.array([1.1, 1.2]),
    inverse_mass_matrix=np.array([1., 1.]),
)
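To make the role of these parameters concrete, here is a minimal NumPy sketch of a leapfrog trajectory with a diagonal mass matrix. step_size sets the size of each update and num_integration_steps the number of updates. The inverse mass matrix scales the position update, and the square root of the mass matrix is used to draw the momentum. This follows the standard HMC formulation; that mcx uses mass_matrix_sqrt and inverse_mass_matrix in exactly this way is an assumption. The standard normal target and the mass values are illustrative.

```python
import numpy as np

def leapfrog(x, p, grad_potential, step_size, num_integration_steps, inverse_mass_matrix):
    """Leapfrog integration of Hamilton's equations with a diagonal mass matrix."""
    p = p - 0.5 * step_size * grad_potential(x)            # half step for momentum
    for i in range(num_integration_steps):
        x = x + step_size * inverse_mass_matrix * p         # full step for position
        if i < num_integration_steps - 1:
            p = p - step_size * grad_potential(x)           # full step for momentum
    p = p - 0.5 * step_size * grad_potential(x)            # final half step
    return x, p

def hamiltonian(x, p, potential, inverse_mass_matrix):
    kinetic = 0.5 * np.sum(inverse_mass_matrix * p ** 2)
    return potential(x) + kinetic

def potential(x):
    # Standard normal target: U(x) = -log p(x) up to a constant.
    return 0.5 * np.sum(x ** 2)

def grad_potential(x):
    return x

rng = np.random.default_rng(0)
mass_matrix_sqrt = np.array([1.0, 2.0])             # illustrative values
inverse_mass_matrix = 1.0 / mass_matrix_sqrt ** 2
x0 = np.array([1.0, -0.5])
p0 = mass_matrix_sqrt * rng.normal(size=2)          # momentum ~ N(0, M)

x1, p1 = leapfrog(x0, p0, grad_potential, step_size=0.1,
                  num_integration_steps=90, inverse_mass_matrix=inverse_mass_matrix)
energy_error = (hamiltonian(x1, p1, potential, inverse_mass_matrix)
                - hamiltonian(x0, p0, potential, inverse_mass_matrix))
print(abs(energy_error))  # small when step_size is well below the stability limit
```

The total trajectory length is step_size * num_integration_steps. Shrinking the step size while keeping the number of steps fixed therefore also shortens how far each proposal travels.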

Why such a small step size?

I currently pick initial points by drawing from the prior distribution, so chains can start from very different points in the state space. In this case, the chain with index 1 starts from

coefs = -0.36427957
scale = 0.09018712

which leads to very high values of the potential and its gradients. As a result, unless you are lucky and draw a very small momentum, the integrator overshoots and the trajectory diverges. Decreasing the step size keeps each integration step small relative to these large gradients, which allows the chain to move out of this region.
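A one-dimensional toy model shows why the step size has to shrink. For a quadratic potential U(x) = 0.5 * k * x² with unit mass, leapfrog is stable only when step_size * sqrt(k) < 2; above that, the energy error grows exponentially with each step. The sketch assumes, as a stand-in for the real posterior, that the curvature scales like 1 / scale² with scale = 0.09, as for a Normal likelihood's location parameter. It is not mcx code.

```python
import numpy as np

def leapfrog_energy_error(k, step_size, num_integration_steps, x0=1.0, p0=1.0):
    """Energy error after a leapfrog trajectory on U(x) = 0.5 * k * x**2, unit mass."""
    def grad(x):
        return k * x

    def energy(x, p):
        return 0.5 * k * x ** 2 + 0.5 * p ** 2

    x, p = x0, p0
    p -= 0.5 * step_size * grad(x)
    for i in range(num_integration_steps):
        x += step_size * p
        if i < num_integration_steps - 1:
            p -= step_size * grad(x)
    p -= 0.5 * step_size * grad(x)
    return energy(x, p) - energy(x0, p0)

# Toy stand-in: curvature assumed to scale like 1 / scale**2 with scale = 0.09.
k = 1.0 / 0.09 ** 2  # about 123, so stability needs step_size < 2 / sqrt(k), about 0.18

for step_size in [0.5, 0.1, 0.01]:
    print(step_size, leapfrog_energy_error(k, step_size, num_integration_steps=20))
```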

How to avoid this in the future?

I don't know how this is done in other libraries. Naively I see two options:

  1. Start every chain from (roughly) the same point.
  2. Run one warmup per chain (does not take much longer thanks to vmap).

I prefer option (2): intuitively, it would make bad mixing easier to detect.
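Option (2) can be approximated cheaply by tuning a step size separately for each chain from its own starting point. The sketch below uses the step-size initialization heuristic from the NUTS paper (Hoffman & Gelman, 2014). Starting from 1, it doubles or halves the step size until the acceptance probability of a single leapfrog step crosses 0.5. It is not Stan's full warmup, it assumes a unit mass matrix, and it uses a toy quartic potential whose curvature grows away from the origin. In JAX, the Python loop over chains would typically be replaced by jax.vmap.

```python
import numpy as np

def leapfrog_step(x, p, grad_potential, step_size):
    # One leapfrog step with a unit mass matrix.
    p = p - 0.5 * step_size * grad_potential(x)
    x = x + step_size * p
    p = p - 0.5 * step_size * grad_potential(x)
    return x, p

def find_reasonable_step_size(x, potential, grad_potential, rng, max_iter=100):
    """Step-size heuristic in the spirit of Hoffman & Gelman (2014), unit mass."""
    def log_joint(x, p):
        value = -potential(x) - 0.5 * np.sum(p ** 2)
        return value if np.isfinite(value) else -np.inf  # treat overflow as rejection

    step_size = 1.0
    p = rng.normal(size=np.shape(x))
    x_new, p_new = leapfrog_step(x, p, grad_potential, step_size)
    log_ratio = log_joint(x_new, p_new) - log_joint(x, p)
    # Grow the step if the first step is accepted with probability > 0.5, else shrink.
    direction = 1.0 if log_ratio > np.log(0.5) else -1.0
    for _ in range(max_iter):
        if direction * log_ratio <= -direction * np.log(2.0):
            break  # acceptance probability has crossed 0.5
        step_size *= 2.0 ** direction
        x_new, p_new = leapfrog_step(x, p, grad_potential, step_size)
        log_ratio = log_joint(x_new, p_new) - log_joint(x, p)
    return step_size

def potential(x):
    # Toy potential whose curvature (3 * x**2) grows away from the origin.
    return 0.25 * np.sum(x ** 4)

def grad_potential(x):
    return x ** 3

rng = np.random.default_rng(0)
initial_positions = [np.array([0.5]), np.array([10.0])]  # hypothetical chain starts
step_sizes = [find_reasonable_step_size(x0, potential, grad_potential, rng)
              for x0 in initial_positions]
print(step_sizes)  # the chain starting far out gets a much smaller step size
```

Running the adaptation per chain keeps a badly initialized chain from forcing a tiny step size on all the others. Chains that still disagree after warmup then stand out in the diagnostics.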

rlouf commented 4 years ago

Will close once I’ve added a test that reproduces the original issue, unless there is something else?

I will open a different issue for the warmup.

rlouf commented 4 years ago

Added a test in #22, closing.