pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

HMM model using HMC/NUTS is slow #1511

Closed. neerajprad closed this issue 3 years ago.

neerajprad commented 6 years ago

Since we started using einsum to evaluate the joint log density of the trace with discrete parameters, we can do discrete site enumeration across many more models in HMC/NUTS without going OOM.
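For context on what enumeration means here, below is a minimal sketch that sums a discrete site out of the joint log density instead of sampling it. It assumes a toy two-component Gaussian mixture: the possible values of the discrete site are broadcast along a new leftmost dimension, and that dimension is then reduced with logsumexp. This only illustrates the idea and is not Pyro's einsum-based implementation.

```python
import torch
from torch.distributions import Categorical, Normal

# Toy mixture: z ~ Categorical(weights), x ~ Normal(locs[z], 1.0).
weights = torch.tensor([0.3, 0.7])
locs = torch.tensor([-1.0, 2.0])
x = torch.tensor([0.5, 1.5, -0.2])  # observed data, shape (N,)

# Enumerate z along a new leftmost dim: shape (K, 1) broadcasts against x of shape (N,).
z = torch.arange(len(weights)).unsqueeze(-1)          # (K, 1)
log_pz = Categorical(probs=weights).log_prob(z)       # (K, 1)
log_px_given_z = Normal(locs[z], 1.0).log_prob(x)     # (K, N)

# Sum out z for each data point, then add up the independent data points.
log_px = torch.logsumexp(log_pz + log_px_given_z, dim=0)  # (N,)
total = log_px.sum()
print(total)
```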

For HMMs, however, NUTS / HMC is extremely slow to the point of being unusable beyond a few time steps, say 10. Refer to this test for profiling.

There are a few issues that I noticed:

[Screenshot from 2018-11-01]
fehiepsi commented 6 years ago

@neerajprad I tested with pytorch 0.4.0 and didn't catch the slowness. As I mentioned in #1487, trace + Gamma/Dirichlet + pytorch 1.0rc gives wrong grads in the backward pass, and that bug is not related to HMC/NUTS. I don't know what more I can do about that issue, so I'm setting it aside to implement statistics such as effective number of samples / Gelman-Rubin convergence diagnostic instead.

fehiepsi commented 6 years ago

About initialization, Stan does not use MAP. Previously, PyMC3 used MAP, but now they don't by default. There might be experimental reasons for their decisions (which I can't trace back). In my opinion, allowing users to set starting points is enough. From my experience, this is extremely useful for some models: when I got stuck with random initialization, I set the initial values to the mean, and then things went smoothly. These starting points can come from intuition, from the prior means, or from MAP.

neerajprad commented 6 years ago

@fehiepsi - This is a known issue that we would like to address or at least conclude that we cannot run HMM type models with the current constraints. I just cc'd you as an FYI - you shouldn't feel compelled to work on this! :)

> I tested with pytorch 0.4.0 and didn't catch the slowness.

The HMM test is only sampling 10 traces, but if you run it for longer than 10 time steps, you will find the issue of step sizes getting very small and making extremely slow progress. This is without using JIT. My hunch is that a bad initialization, combined with transforms that warp the potential energy surface, makes the trajectories extremely unstable, so we keep lowering the step size and progress becomes extremely slow. This is just a guess, though, and needs to be investigated further.
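To make the step-size mechanism concrete, below is a minimal sketch of the dual-averaging step-size adaptation described in the NUTS paper (Hoffman & Gelman), which Stan-style samplers run during warmup. The constants (target acceptance 0.8, gamma=0.05, t0=10, kappa=0.75) are the commonly cited defaults, and the `adapt_step_size` helper is hypothetical, not Pyro's API. It shows how persistently low acceptance probabilities, such as those produced by unstable trajectories, keep pushing the step size down.

```python
import math


def adapt_step_size(accept_probs, init_step_size=1.0, target=0.8,
                    gamma=0.05, t0=10.0, kappa=0.75):
    """Dual-averaging step-size adaptation (hypothetical helper).

    accept_probs: acceptance probability observed at each warmup iteration.
    Returns (last step size, averaged step size used after warmup).
    """
    mu = math.log(10.0 * init_step_size)  # shrinkage point for the log step size
    h_bar = 0.0                           # running average of (target - accept_prob)
    log_eps = math.log(init_step_size)
    log_eps_bar = 0.0                     # iterate-averaged log step size
    for m, alpha in enumerate(accept_probs, start=1):
        eta = 1.0 / (m + t0)
        h_bar = (1.0 - eta) * h_bar + eta * (target - alpha)
        # Acceptance below target makes h_bar positive, which shrinks the step size.
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        w = m ** (-kappa)
        log_eps_bar = w * log_eps + (1.0 - w) * log_eps_bar
    return math.exp(log_eps), math.exp(log_eps_bar)


# Persistently low acceptance (unstable trajectories) drives the step size down...
print(adapt_step_size([0.1] * 100))
# ...while persistently high acceptance lets it grow.
print(adapt_step_size([1.0] * 100))
```

The sketch omits the restarts and clipping a production sampler would apply; it is only meant to show the direction of the feedback loop.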

> Stan does not use MAP. Previously, PyMC3 used MAP, but now they don't by default.

Even if it is not available by default, I am interested in exploring whether initializing with the MAP estimate does better on these kinds of models. If so, it would be useful to provide an optional kwarg initialize_with_map=False to the HMC/NUTS kernels.
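Below is a minimal sketch of what MAP initialization could look like. It assumes the model exposes a potential-energy function U(q) = -log p(q) over unconstrained parameters, and it finds a starting point by minimizing that function with `torch.optim.Adam`. The `find_map_init` helper and the toy Gaussian potential are hypothetical. `initialize_with_map` is only a proposed kwarg, and none of this is an existing Pyro API.

```python
import torch


def find_map_init(potential_fn, init, num_steps=500, lr=0.05):
    """Minimize the potential energy U(q) = -log p(q) to get a MAP starting point.

    potential_fn maps an unconstrained tensor to a scalar potential energy.
    """
    q = init.clone().requires_grad_(True)
    optim = torch.optim.Adam([q], lr=lr)
    for _ in range(num_steps):
        optim.zero_grad()
        loss = potential_fn(q)
        loss.backward()
        optim.step()
    return q.detach()


# Toy potential: an isotropic Gaussian centred at (1, -2), so the MAP is known.
target_mean = torch.tensor([1.0, -2.0])


def potential(q):
    return 0.5 * ((q - target_mean) ** 2).sum()


q0 = find_map_init(potential, torch.zeros(2))
print(q0)  # should be close to target_mean
```

The result depends on `lr` and `num_steps`. This is exactly the sensitivity raised later in the thread as an argument against MAP initialization.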

> implement statistics such as effective number of samples / Gelman-Rubin convergence diagnostic instead.

That will be really useful! You should also check out arviz (which implements traceplot and diagnostics like gelman rubin), and this PR https://github.com/arviz-devs/arviz/pull/309 by @ColCarroll, which extends support for Pyro.

fehiepsi commented 6 years ago

> The HMM test is only sampling 10 traces, but if you run it for longer than 10 time steps, you will find the issue of step sizes getting very small and making extremely slow progress.

You are right that this might be another issue. I ran the test with 100 num_samples and 100 warmup_steps in pytorch 1.0rc and pytorch 0.4. Pytorch 1.0rc is a bit slower than pytorch 0.4, and einsum is slower than not using einsum. But I didn't observe the very small step_size problem: the step_size is around 0.0001-0.0004 in all the tests.

For MAP, I'd discourage using it with HMC. I have implemented dozens of models on various small datasets, and unless I specified good initial values, MAP gave very bad answers even after I tried different learning rates and num_steps. For example, take a simple linear regression with a Gaussian likelihood: y ~ Normal(ax + b, sigma). When the scale of a is large, we have to set the learning rate to a large value unless we run SVI for tens of thousands of steps. But with a large learning rate, sigma tends to drift to a very large value! At the end of MAP, I can't get the answer I need. So the performance depends heavily on the initial values for MAP, the learning rate, the number of steps, etc.

I haven't faced problems with the initial trace (other than the nan issue, which we have addressed), so I can't say much. To avoid extreme values in the initial trace (due to random initialization), Stan's approach to initialization might be helpful: it initializes values uniformly at random in the interval (-2, 2) of the unconstrained space.
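Below is a minimal sketch of the Stan-style initialization described above. It assumes each latent site is described by a `torch.distributions.constraints` object: the sketch draws the unconstrained value uniformly from (-2, 2), then maps it into the site's support with `torch.distributions.biject_to`. The `stan_style_init` helper and the example sites are hypothetical.

```python
import torch
from torch.distributions import biject_to, constraints


def stan_style_init(support, unconstrained_shape, radius=2.0):
    """Draw an initial value uniformly in (-radius, radius) on the unconstrained
    space, then map it into the site's support with the matching bijection."""
    u = torch.empty(unconstrained_shape).uniform_(-radius, radius)
    return biject_to(support)(u)


torch.manual_seed(0)
# A positive site (e.g. a Gamma-distributed scale): exp maps (-2, 2) to (e^-2, e^2).
scale = stan_style_init(constraints.positive, (3,))
# A 4-state probability vector (e.g. a Dirichlet row of an HMM transition matrix):
# the stick-breaking bijection maps a 3-dim unconstrained vector onto the simplex.
probs = stan_style_init(constraints.simplex, (3,))
print(scale, probs)
```

Because every site starts within a bounded box in unconstrained space, this avoids the extreme initial values that unbounded random draws can produce.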

> You should also check out arviz (which implements traceplot and diagnostics like gelman rubin), and this PR arviz-devs/arviz#309 by @ColCarroll, which extends support for Pyro.

Thanks for your suggestion!!! I didn't know about it. So if arviz already supports these diagnostics, should we still implement them? It might be better to explore how to combine things with arviz instead. What do you think? Edit: I took a look at their implementation. All calculations depend on numpy, which makes me a little uncomfortable. I will implement these diagnostics in pytorch instead.
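Because the plan is to implement the diagnostics in PyTorch, below is a minimal sketch of the classic (non-split) Gelman-Rubin statistic for a scalar quantity. It assumes samples are stored with a leading chain dimension, that is, shape (num_chains, num_samples). The `gelman_rubin` name is hypothetical. Newer tools such as Stan and ArviZ report a split-chain variant, so their numbers will differ slightly.

```python
import torch


def gelman_rubin(samples):
    """Classic potential scale reduction factor R-hat.

    samples: tensor of shape (num_chains, num_samples) for a scalar quantity.
    """
    num_chains, n = samples.shape
    chain_means = samples.mean(dim=1)
    w = samples.var(dim=1).mean()          # mean within-chain variance (unbiased)
    b = n * chain_means.var()              # between-chain variance
    var_plus = (n - 1) / n * w + b / n     # pooled estimate of the posterior variance
    return torch.sqrt(var_plus / w)


torch.manual_seed(0)
mixed = torch.randn(4, 1000)                                     # chains agree
stuck = torch.randn(4, 1000) + torch.arange(4.0).unsqueeze(-1)   # chains disagree
print(gelman_rubin(mixed), gelman_rubin(stuck))
```

Values close to 1 indicate that the chains agree. Values well above 1, as for `stuck`, indicate that the chains have not mixed.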

fritzo commented 6 years ago

@neerajprad could you send me a .prof file and/or the steps to reproduce your profile? I'd like to inspect the profiling numbers in contract_to_tensor().

neerajprad commented 6 years ago

> @neerajprad could you send me a .prof file and/or the steps to reproduce your profile? I'd like to inspect the profiling numbers in contract_to_tensor().

Thanks @fritzo, I will send you the profiling script and .prof file shortly.

> I haven't faced problems with the initial trace (other than the nan issue, which we have addressed), so I can't say much. To avoid extreme values in the initial trace (due to random initialization), Stan's approach to initialization might be helpful: it initializes values uniformly at random in the interval (-2, 2) of the unconstrained space.

Thanks for all the suggestions, @fehiepsi. I will play around with different initializations first to see if it improves the performance.

> I didn't know about it. So if arviz already supports these diagnostics, should we still implement them? It might be better to explore how to combine things with arviz instead. What do you think?

I think if the integration is straightforward and arviz has all the diagnostics you were looking to implement, we could just suggest that users go with that (and even add it to our examples). If you find the integration lacking in any way and have ideas on what can be improved, feel free to open an issue to discuss! I think you might need to change the interface a bit (to preserve chain information). I am not too worried about numpy, because that conversion will happen at the end of inference (unless you'd like to provide some online diagnostics), and converting a CPU tensor to numpy has low overhead.

neerajprad commented 6 years ago

@fritzo - You can run the profiler using:

python -m cProfile -o hmm.prof tests/perf/prof_hmc_hmm.py

on the prof-hmm branch. Attaching the .prof file. I have turned off step size adaptation so as not to take too much time. Most of the time is actually just taken by einsum so I am not sure if there is much room for optimization here.

hmm.tar.gz
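For anyone inspecting the attached profile, below is a minimal sketch of reading a `.prof` file with the standard-library `pstats` module. The script profiles a stand-in function so that it runs on its own. `toy_workload` and the temporary file path are hypothetical. To inspect the real profile, pass `hmm.prof` to `pstats.Stats` instead.

```python
import cProfile
import os
import pstats
import tempfile


def toy_workload():
    # Stand-in for the HMM model; the real profile comes from tests/perf/prof_hmc_hmm.py.
    return sum(i * i for i in range(200_000))


# Programmatic equivalent of `python -m cProfile -o <file> script.py` for one call.
prof_path = os.path.join(tempfile.mkdtemp(), "toy.prof")
profiler = cProfile.Profile()
profiler.runcall(toy_workload)
profiler.dump_stats(prof_path)

# Load the .prof file and list the 10 most expensive calls by cumulative time.
stats = pstats.Stats(prof_path)
stats.strip_dirs().sort_stats("cumulative").print_stats(10)
```

On the real profile, `stats.print_callers("einsum")` lists the functions that call einsum, which helps attribute its cost to individual call sites.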

ColCarroll commented 6 years ago

The only thing I think you need to worry about in using ArviZ is that we are writing the library with no regard for Python 2. In particular, we use matplotlib 3.x, which is Python 3 only, and the rest of the Python data science infrastructure seems to be phasing out Python 2 support over the next year, so we did not want to start a new project with that technical debt. I understand this may hurt adoption in legacy stacks!

Beyond that, please tag one of us here or open an issue if we can help with the integration at all. We have found xarray and netcdf to be very natural ways of storing inference data.

neerajprad commented 6 years ago

@ColCarroll - Thanks for offering to help with the integration! Regarding the python 2 incompatibility, there is already another feature (CUDA with parallel MCMC chains) that isn't supported on Python 2. Given that Python 2 will not be supported in a year or so, I think it is fine if certain (non-critical) features are only available in Python 3 going forward, but this is worth discussing internally for sure.

EDIT: This however means that we cannot have arviz as an official dependency until we drop support for python 2.

fritzo commented 6 years ago

@ColCarroll we plan for Pyro to support Python 2 as long as PyTorch supports Python 2.

fehiepsi commented 5 years ago

@neerajprad Could you point me to the profiling test that is slow? It seems that the file tests/perf/prof_hmc_hmm.py is not available in the dev branch.

neerajprad commented 5 years ago

@fehiepsi - I updated the example in the prof-hmm branch. It should be in tests/perf/prof_hmc_hmm.py. I don't think there are any immediate TODOs for this one, and this is more of an enhancement issue than a perf issue. One thing we could experiment with in the future is JITing the grad step itself (once PyTorch supports it).

fehiepsi commented 5 years ago

@neerajprad Totally agree! I just profiled with both jit and nojit (hmm_2.zip). Most of the time is spent computing _potential_energy and its grad, so the slowness is not related to HMC/NUTS itself.
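The profile points at the potential energy for a simple reason: each leapfrog step in HMC/NUTS needs one gradient of the potential energy. A trajectory of L steps therefore costs L + 1 forward/backward passes through the model's log density. Below is a minimal sketch of that accounting. It assumes a toy Gaussian potential in place of the HMM's `_potential_energy`, and the `potential_energy`, `potential_grad`, and `leapfrog` names are hypothetical.

```python
import torch

grad_calls = 0


def potential_energy(q):
    # Stand-in for the model's negative joint log density (a standard Gaussian here).
    return 0.5 * (q ** 2).sum()


def potential_grad(q):
    global grad_calls
    grad_calls += 1
    q = q.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(potential_energy(q), q)
    return grad


def leapfrog(q, p, step_size, num_steps):
    # Velocity-Verlet scheme: the gradient at the end of one step is reused at the
    # start of the next, so num_steps steps need num_steps + 1 gradient evaluations.
    grad = potential_grad(q)
    for _ in range(num_steps):
        p = p - 0.5 * step_size * grad
        q = q + step_size * p
        grad = potential_grad(q)
        p = p - 0.5 * step_size * grad
    return q, p


q0, p0 = torch.tensor([1.0, -0.5]), torch.tensor([0.3, 0.8])
q1, p1 = leapfrog(q0, p0, step_size=0.1, num_steps=20)
print(grad_calls)  # 21 gradient evaluations for a 20-step trajectory
```

Any per-call overhead in computing the potential energy (trace replay, sumproduct, packing tensors) is therefore multiplied by the number of leapfrog steps in every iteration.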

It surprises me that the distributions' log_prob only takes 40s of the total 250s spent computing trace_log_prob. A lot of time is spent in sumproduct (140s) and pack_tensors (40s). I guess this is expected? ~In addition, get_trace spends a lot of time in process_message and post_process_message (more than 100s). I believe sampling for this small model needs only 1s (of those 100s).~

fritzo commented 5 years ago

> It surprises me that the distributions' log_prob only takes 40s

In the HMM example, the distribution log_prob computation is merely a gather, i.e. a memcopy; all the actual computation is done by sumproduct when combining log_prob tensors from multiple sites, i.e. matmul and einsum.
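Below is a minimal sketch of that split, using a small discrete HMM with made-up random parameters. The per-site "log_prob" is just indexing into a log-probability table (a gather). All of the arithmetic happens when the per-step factors are contracted with a log-space matrix product, which is the forward algorithm. This illustrates the computation only and is not Pyro's sumproduct code. `log_matmul` is a hypothetical helper.

```python
import itertools

import torch


def log_matmul(a, b):
    # (i, j) x (j, k) -> (i, k) in log space: logsumexp_j a[i, j] + b[j, k].
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)


torch.manual_seed(0)
num_states, num_obs_values = 3, 4
log_init = torch.randn(num_states).log_softmax(-1)                   # log p(z_0)
log_trans = torch.randn(num_states, num_states).log_softmax(-1)      # log p(z_t | z_{t-1})
log_emit = torch.randn(num_states, num_obs_values).log_softmax(-1)   # log p(x_t | z_t)
obs = torch.tensor([0, 3, 1, 1, 2])

# "log_prob" of every observation under every state: a gather, no arithmetic.
log_lik = log_emit[:, obs].T  # shape (T, num_states)

# Forward algorithm: all of the real work is the log-space matrix product.
alpha = log_init + log_lik[0]
for t in range(1, len(obs)):
    alpha = log_matmul(alpha.unsqueeze(0), log_trans).squeeze(0) + log_lik[t]
log_marginal = torch.logsumexp(alpha, dim=-1)

# Sanity check against brute-force enumeration over all 3**5 hidden paths.
paths = []
for zs in itertools.product(range(num_states), repeat=len(obs)):
    lp = log_init[zs[0]] + log_emit[zs[0], obs[0]]
    for t in range(1, len(obs)):
        lp = lp + log_trans[zs[t - 1], zs[t]] + log_emit[zs[t], obs[t]]
    paths.append(lp)
brute_force = torch.logsumexp(torch.stack(paths), dim=0)
print(log_marginal, brute_force)
```

The forward pass costs O(T * K^2), while brute force costs O(K^T). This is why the contraction, not the gather, dominates the profile.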

fehiepsi commented 5 years ago

@fritzo That makes sense! So the slowness is expected for models with discrete variables.

fehiepsi commented 5 years ago

I've put a profiling write-up here for reference: https://gist.github.com/fehiepsi/75dfbea31b993f165f51524776185be6

For the same model, Pyro took 66s while Stan took 33s. The inference results are quite comparable.

But the important point is that it took 32s to compile the Stan model and only 1s to sample! So compilation plays an important role here. Hopefully the PyTorch JIT will improve in the future. :)

eb8680 commented 5 years ago

~@fehiepsi the difference may be related to this PyTorch issue about einsum performance: https://github.com/pytorch/pytorch/issues/10661~ oops, there's no enumeration happening in supervised_hmm, so this might not be a problem

elbamos commented 5 years ago

I'm jumping in here because I've also seen serious performance issues with NUTS/HMC.

In my testing, performance starts out adequate and then drops precipitously after around 20 iterations. I have observed that, as this is occurring, the step size is increasing.

To me, the interesting thing is that the performance isn't constant; it declines over time. That suggests the issue is not limited to the time it takes to calculate the energy, which shouldn't vary much from iteration to iteration.

I have two suspicions. The first is that the HMC/NUTS implementation is trying too hard to increase the step size, so it ends up producing lots and lots of divergences. The second is memory fragmentation caused by the very large number of tensors that are created as intermediate steps and then retained through the gradient calculation.
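One way to test the first suspicion is to count divergent transitions. Stan flags a trajectory as divergent when the Hamiltonian increases by more than a large threshold (1000). Below is a minimal sketch using a 1-D Gaussian potential with an analytic gradient and leapfrog integration. The `leapfrog_1d`, `hamiltonian`, and `is_divergent` helpers are hypothetical. For a Gaussian with standard deviation sigma, leapfrog becomes unstable once the step size exceeds 2 * sigma, so a step size tuned for a wide direction of the posterior can diverge along a narrow one.

```python
import math


def leapfrog_1d(q, p, step_size, num_steps, sigma):
    # Potential U(q) = q^2 / (2 sigma^2), so dU/dq = q / sigma^2.
    grad = q / sigma ** 2
    for _ in range(num_steps):
        p -= 0.5 * step_size * grad
        q += step_size * p
        grad = q / sigma ** 2
        p -= 0.5 * step_size * grad
    return q, p


def hamiltonian(q, p, sigma):
    return q ** 2 / (2 * sigma ** 2) + 0.5 * p ** 2


def is_divergent(q, p, step_size, num_steps, sigma, max_delta_energy=1000.0):
    q1, p1 = leapfrog_1d(q, p, step_size, num_steps, sigma)
    delta = hamiltonian(q1, p1, sigma) - hamiltonian(q, p, sigma)
    return not math.isfinite(delta) or delta > max_delta_energy


sigma = 0.1  # a narrow direction of the posterior
print(is_divergent(1.0, 0.5, step_size=0.05, num_steps=50, sigma=sigma))  # below 2 * sigma
print(is_divergent(1.0, 0.5, step_size=0.25, num_steps=50, sigma=sigma))  # above 2 * sigma
```

If step-size adaptation overshoots, the fraction of divergent trajectories should rise as the step size grows. Logging that fraction per iteration would help tell this suspicion apart from the memory-fragmentation one.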

fehiepsi commented 3 years ago

I believe the slowness is expected when running Pyro MCMC on Markov models, so I would like to close this issue. We can point users to NumPyro if speed is needed.