Closed neerajprad closed 3 years ago
@neerajprad I tested with pytorch 0.4.0 and didn't catch the slowness. As I mentioned in #1487, trace + Gamma/Dirichlet + pytorch 1.0rc will give wrong grads in backward pass. And that bug is not related to HMC/NUTS. I don't know what I can do more with that issue so I skip it and implement statistics such as effective number of samples / Gelman-Rubin convergence diagnostic instead.
About initialization, Stan does not use MAP. Previously, PyMC3 used MAP, but now they don't by default. There might be experimental reasons for their decisions (which I can't track down). In my opinion, allowing users to set starting points is enough (from my experience, it is extremely useful for some models: when I got stuck with random initialization, I set the initial values to the mean, and then things went smoothly). These starting points can come from intuition, from the mean of the priors, or from MAP.
@fehiepsi - This is a known issue that we would like to address or at least conclude that we cannot run HMM type models with the current constraints. I just cc'd you as an FYI - you shouldn't feel compelled to work on this! :)
I tested with pytorch 0.4.0 and didn't catch the slowness.
The HMM test is only sampling 10 traces, but if you run it for longer than 10 time steps, you will find the issue of step sizes getting very small and making extremely slow progress. This is without using JIT. My hunch was that this could be the case with a bad initialization, and with transforms warping the potential energy surface in a way that trajectories are extremely unstable, and we keep lowering the step size, making progress extremely slow. This is just a guess though and needs to be investigated further.
Stan does not use MAP. Previously, PyMC3 used MAP, but now they don't by default.
Even if it is not available by default, I am interested in exploring if initializing with the MAP estimate does better on these kinds of models. If so, it will be useful to provide an optional kwarg initialize_with_map=False to the HMC/NUTS kernels.
implement statistics such as effective number of samples / Gelman-Rubin convergence diagnostic instead.
That will be really useful! You should also check out arviz (which implements traceplot and diagnostics like gelman rubin), and this PR https://github.com/arviz-devs/arviz/pull/309 by @ColCarroll, which extends support for Pyro.
The HMM test is only sampling 10 traces, but if you run it for longer than 10 time steps, you will find the issue of step sizes getting very small and making extremely slow progress.
You are right that this might be another issue. I ran the test with 100 num_samples and 100 warmup_steps in pytorch 1.0rc and pytorch 0.4. Pytorch 1.0rc is a bit slower than pytorch 0.4, and einsum is slower than no einsum. But I didn't observe the very small step_size problem: the step_size is around 0.0001-0.0004 in all the tests.
For MAP, I discourage using it with HMC. I have implemented dozens of models on various small datasets. Unless I specified good initial values, MAP gave very bad answers even though I tried different learning rates and num_steps for MAP. For example, take a simple linear regression with Gaussian likelihood: y = Normal(ax + b, sigma). When the scale of a is large, we have to set the learning rate to a large value unless we run SVI for tens of thousands of steps. But with a large learning rate, sigma will tend to move to a very large value! At the end of MAP, I can't get the answer I need. So the performance depends heavily on the initial values for MAP, the learning rate, the number of steps, ...
I don't face problems (other than the nan issue which we have addressed) with initial trace so I can't say much. To avoid extreme values of initial trace (due to random initialization), Stan approach for initialization might be helpful. They initialize values randomly in the interval (-2, 2) of unconstrained space.
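To make the Stan-style initialization concrete, here is a minimal sketch in plain Python (not Pyro or Stan code). It draws each parameter uniformly from (-2, 2) in unconstrained space and then maps it into its support. The helper names `to_constrained` and `init_uniform`, and the string labels for supports, are hypothetical and only for illustration; the parameter names a, b, sigma are borrowed from the linear regression example above.

```python
import math
import random


def to_constrained(u, support):
    """Map an unconstrained real u into the named support (hypothetical helper)."""
    if support == "real":
        return u
    if support == "positive":
        return math.exp(u)  # exp maps the real line onto (0, inf)
    if support == "unit_interval":
        return 1.0 / (1.0 + math.exp(-u))  # sigmoid maps onto (0, 1)
    raise ValueError(f"unknown support: {support}")


def init_uniform(supports, radius=2.0, rng=None):
    """Draw each site uniformly in [-radius, radius] in unconstrained space,
    then transform it into the site's support."""
    rng = rng or random.Random()
    init = {}
    for name, support in supports.items():
        u = rng.uniform(-radius, radius)
        init[name] = to_constrained(u, support)
    return init


# Assumed supports for the regression y = Normal(ax + b, sigma).
supports = {"a": "real", "b": "real", "sigma": "positive"}
init = init_uniform(supports, rng=random.Random(0))
print(init)
```

With radius 2, a positive parameter such as sigma always starts between exp(-2) and exp(2), which rules out the extreme initial values that a draw from a wide prior can produce.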
You should also check out arviz (which implements traceplot and diagnostics like gelman rubin), and this PR arviz-devs/arviz#309 by @ColCarroll, which extends support for Pyro.
Thanks for your suggestion!!! I don't know about it. So if arviz already supports these diagnostics, then should we implement it? It might be better to explore how to combine things with arviz instead. What do you think? Edit: I took a look at their implementation. All calculation depends on numpy, which is a little bit uncomfortable to me. I will implement these diagnostics in pytorch instead.
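For reference, the Gelman-Rubin statistic mentioned above is short enough to sketch. The version below is plain Python using the standard library, not the arviz or Pyro implementation, and it omits refinements such as chain splitting. The function name `gelman_rubin` and the synthetic chains are assumptions for illustration; a PyTorch version would replace the list comprehensions with tensor reductions over the draw dimension.

```python
import random
from statistics import mean, variance


def gelman_rubin(chains):
    """Potential scale reduction factor for m equal-length chains of floats."""
    m = len(chains)
    n = len(chains[0])
    if m < 2 or n < 2 or any(len(c) != n for c in chains):
        raise ValueError("need at least 2 chains of equal length >= 2")
    chain_means = [mean(c) for c in chains]
    within = mean([variance(c) for c in chains])  # W: mean within-chain variance
    between_over_n = variance(chain_means)        # B / n: variance of chain means
    var_plus = (n - 1) / n * within + between_over_n
    return (var_plus / within) ** 0.5


rng = random.Random(0)
# Four chains drawn from the same distribution: should look converged.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# The same chains shifted apart: mimics chains stuck in different regions.
stuck = [[x + 3.0 * k for x in chain] for k, chain in enumerate(mixed)]

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(rhat_mixed, rhat_stuck)
```

Values close to 1 suggest the chains agree; the shifted chains give a value well above 1 because the variance between chain means dominates the variance within chains.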
@neerajprad could you send me a .prof file and/or the steps to reproduce your profile? I'd like to inspect the profiling numbers in contract_to_tensor().
@neerajprad could you send me a .prof file and/or the steps to reproduce your profile? I'd like to inspect the profiling numbers in contract_to_tensor().
Thanks @fritzo, I will send you the profiling script and .prof file shortly.
I don't face problems (other than the nan issue which we have addressed) with initial trace so I can't say much. To avoid extreme values of initial trace (due to random initialization), Stan approach for initialization might be helpful. They initialize values randomly in the interval (-2, 2) of unconstrained space.
Thanks for all the suggestions, @fehiepsi. I will play around with different initializations first to see if it improves the performance.
I don't know about it. So if arviz already supports these diagnostics, then should we implement it? It might be better to explore how to combine things with arviz instead. What do you think?
I think if the integration is straightforward and arviz has all the diagnostics you were looking to implement, we could just suggest users to go with that (and even add it to our example). If you find the integration lacking in any way, and have ideas on what can be improved, feel free to open an issue to discuss! I think you might need to change the interface a bit (to preserve chain information). I am not too worried about numpy because that conversion will happen at the end of inference (unless you'd like to provide some online diagnostics) and converting a cpu tensor to numpy is low overhead.
@fritzo - You can run the profiler using:
python -m cProfile -o hmm.prof tests/perf/prof_hmc_hmm.py
on the prof-hmm branch. Attaching the .prof file. I have turned off step size adaptation so as not to take too much time. Most of the time is actually just taken by einsum so I am not sure if there is much room for optimization here.
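A .prof file produced this way can be inspected with the standard-library pstats module. The sketch below profiles a toy function so that it runs standalone; `toy_workload` and the file name toy.prof are placeholders, not part of the prof-hmm branch. With the real file you would pass the path to hmm.prof to `pstats.Stats` instead.

```python
import cProfile
import io
import os
import pstats
import tempfile


def toy_workload():
    """Stand-in for the HMM sampling loop (not the real test)."""
    return sum(i * i for i in range(100_000))


with tempfile.TemporaryDirectory() as tmp:
    prof_path = os.path.join(tmp, "toy.prof")

    # Same kind of file that `python -m cProfile -o ...` writes.
    profiler = cProfile.Profile()
    profiler.runcall(toy_workload)
    profiler.dump_stats(prof_path)

    # Load the file and print the 10 entries with the largest cumulative time.
    out = io.StringIO()
    stats = pstats.Stats(prof_path, stream=out)
    stats.strip_dirs().sort_stats("cumulative").print_stats(10)
    report = out.getvalue()

print(report)
```

Sorting by cumulative time is a reasonable first view when looking for a single dominant call such as einsum; sorting by "tottime" instead shows where time is spent excluding callees.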
The only thing I think to worry about in using ArviZ is that we are writing the library with no regard for Python 2. In particular, we use matplotlib 3.x, which is Python3 only, and the rest of the python data science infrastructure seems to be phasing python2 support out over the next year, so we did not want to start a new project with that technical debt. I understand this may hurt adoption in legacy stacks!
Beyond that, please tag one of us here or open an issue if we can help with the integration at all. We have found xarray and netcdf to be very natural ways of storing inference data.
@ColCarroll - Thanks for offering to help with the integration! Regarding the python 2 incompatibility, there is already another feature (CUDA with parallel MCMC chains) that isn't supported on Python 2. Given that Python 2 will not be supported in a year or so, I think it is fine if certain (non-critical) features are only available in Python 3 going forward, but this is worth discussing internally for sure.
EDIT: This however means that we cannot have arviz as an official dependency until we drop support for python 2.
@ColCarroll we plan for Pyro to support Python 2 as long as PyTorch supports Python 2.
@neerajprad Could you point me to the profiling test which is slow? It seems that the file tests/perf/prof_hmc_hmm.py is not available in the dev branch.
@fehiepsi - I updated the example in the prof-hmm branch. It should be in tests/perf/prof_hmc_hmm.py. I don't think there are any immediate TODOs for this one, and this is more of an enhancement issue than a perf issue. Some things we can experiment with in the future would be JITing the grad step itself (once PyTorch supports it).
@neerajprad Totally agree! I just profiled with both jit and nojit (hmm_2.zip). Most of the time is spent computing _potential_energy and its grad, so the slowness is not related to HMC/NUTS.
It is surprising to me that the distributions' log_prob takes just 40s of the total 250s to compute trace_log_prob. Lots of time is spent on sumproduct (140s) and pack_tensors (40s). I guess this is expected? ~In addition, get_trace spends a lot of time on process_message and post_process_message (more than 100s). I believe to get samples for this small model, we just need 1s (in 100s) for sampling.~
It is surprised to me that distributions' log_prob just take 40s
In the HMM example the distribution log_prob computation is merely a gather, i.e. memcopy; all actual computation is done by sumproduct when combining log_prob tensors from multiple sites, i.e. matmul and einsum.
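To illustrate where the work goes, here is a toy sketch in plain Python with made-up numbers. It is not Pyro's sumproduct, and it works in probability space rather than log space for brevity. Looking up the emission probability of an observation is a table lookup (the "gather"), while summing out the hidden states is a chain of vector-matrix products. The function names and the parameters `init`, `trans`, `emit` are hypothetical.

```python
import itertools
import math


def likelihood_bruteforce(init, trans, emit, obs):
    """Sum the joint probability over every hidden state path (K**T terms)."""
    K = len(init)
    total = 0.0
    for path in itertools.product(range(K), repeat=len(obs)):
        p = init[path[0]] * emit[path[0]][obs[0]]
        for t in range(1, len(obs)):
            p *= trans[path[t - 1]][path[t]] * emit[path[t]][obs[t]]
        total += p
    return total


def likelihood_forward(init, trans, emit, obs):
    """Same quantity via the forward recursion (T vector-matrix products)."""
    K = len(init)
    # emit[k][y] is a pure table lookup: the "gather" step.
    alpha = [init[k] * emit[k][obs[0]] for k in range(K)]
    for y in obs[1:]:
        # Vector-matrix product with trans, then elementwise multiply by emissions.
        alpha = [
            sum(alpha[j] * trans[j][k] for j in range(K)) * emit[k][y]
            for k in range(K)
        ]
    return sum(alpha)


init = [0.6, 0.4]                  # initial state probabilities
trans = [[0.7, 0.3], [0.2, 0.8]]   # trans[j][k] = P(next=k | current=j)
emit = [[0.9, 0.1], [0.3, 0.7]]    # emit[k][y] = P(obs=y | state=k)
obs = [0, 1, 1, 0, 1]

brute = likelihood_bruteforce(init, trans, emit, obs)
forward = likelihood_forward(init, trans, emit, obs)
print(brute, forward)
```

Both functions return the same marginal likelihood, but the recursion does a number of multiplications linear in the sequence length, and those multiplications are exactly the matmul-style contractions that dominate the profile.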
@fritzo That makes sense! So the slowness is expected for models with discrete variables.
I put here a profiling work https://gist.github.com/fehiepsi/75dfbea31b993f165f51524776185be6 for reference.
For the same model, Pyro took 66s while Stan took 33s. The inference results are quite comparable.
But the important point is: it took 32s to compile the Stan model, and only 1s for sampling! So compilation plays an important role here. Hope that PyTorch JIT will be improved in the future. :)
~@fehiepsi the difference may be related to this PyTorch issue about einsum performance: https://github.com/pytorch/pytorch/issues/10661~ oops, there's no enumeration happening in supervised_hmm, so this might not be a problem
I'm jumping in here because I've also seen serious performance issues with NUTS/HMC.
In my testing, performance starts out adequate, and then begins to drop precipitously after around 20 iterations. I have observed that, as this occurs, the step size is increasing.
To me, the interesting thing is that the performance isn't constant. It declines over time. That suggests to me that the issue is not limited to the time it takes to calculate the energy, which shouldn't vary that much from iteration to iteration.
I have two suspicions: The first is that the HMC/NUTS implementation is trying too hard to increase the step size, so that it ends up producing lots and lots of divergences. The second is that this has to do with memory fragmentation because of the very large number of tensors that are created as intermediate steps and then retained through gradient calculation.
I believe the slowness is expected when running Pyro MCMC on Markov models, so I would like to close this issue. We can point users to NumPyro if speed is needed.
Since we started using einsum to evaluate the joint log density of the trace with discrete parameters, we can do discrete site enumeration across many more models in HMC/NUTS without going OOM.
For HMMs, however, NUTS / HMC is extremely slow to the point of being unusable beyond a few time steps, say 10. Refer to this test for profiling.
There are a few issues that I noticed:
_get_trace) also takes more than 3s. While comparatively small, I believe this can be optimized if we assume our models to be static by assuming a different data structure inside HMC, so that we do not need to run the model each time.