pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[bug] bart.py example consistently fails with cholesky error with default arguments #3001

Closed abeppu closed 2 years ago

abeppu commented 2 years ago

I'm just attempting to get started using pyro, and was trying to work through examples when I found that the example listed on the example page under "Multivariate Forecasting" reliably fails for me.

I see that there is a related open issue which is quite old https://github.com/pyro-ppl/pyro/issues/2017 which has some suggestions for the broader problem which seemed promising, but the PR following the approach discussed (https://github.com/pyro-ppl/pyro/pull/2019) was never approved.

It's not impossible that this is due to some environmental factor. However, I created a clean virtualenv to explore pyro. If there's some environmental contributing factor which I am not aware of, please document it, or even better, add a helper method to health-check a given environment.

If you can confirm/reproduce the failure, I would respectfully suggest that either

Issue Description

examples/contrib/forecast/bart.py fails with cholesky error when run with default params (no args)

Note that this example attempts to use the backtest method, which trains a model several times over different time windows. The first several such windows succeed.

The error appears as follows:

Traceback (most recent call last):
  File "examples/contrib/forecast/bart.py", line 180, in <module>
    main(args)
  File "examples/contrib/forecast/bart.py", line 156, in main
    forecaster_options=forecaster_options,
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/evaluate.py", line 205, in backtest
    batch_size=batch_size,
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 361, in __call__
    return super().__call__(data, covariates, num_samples, batch_size)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 390, in forward
    return self.model(data, covariates)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/nn/module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 185, in forward
    self.model(zero_data, covariates)
  File "examples/contrib/forecast/bart.py", line 121, in model
    self.predict(noise_model, prediction)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 157, in predict
    noise = pyro.sample("residual", noise_dist)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/primitives.py", line 163, in sample
    apply_stack(msg)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 218, in apply_stack
    default_process_message(msg)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 179, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/distributions/torch_distribution.py", line 49, in __call__
    if self.has_rsample
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/distributions/hmm.py", line 584, in rsample
    z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/distributions/hmm.py", line 144, in _sequential_gaussian_filter_sample
    contracted = joint.marginalize(left=state_dim)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/ops/gaussian.py", line 244, in marginalize
    P_b = cholesky(P_bb)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/ops/tensor_utils.py", line 399, in cholesky
    return torch.linalg.cholesky(x)
RuntimeError: torch.linalg.cholesky: (Batch element 255): The factorization could not be completed because the input is not positive-definite (the leading minor of order 2 is not positive-definite).
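To make the last line of the traceback concrete, here is a minimal pure-Python sketch of a Cholesky factorization that fails in the same way. It is an illustration only, not the code path PyTorch or Pyro uses. The function name `cholesky_lower` and the example matrices are made up for this note. The factorization proceeds column by column and stops at the first leading minor whose pivot is not strictly positive, which is what "the leading minor of order 2 is not positive-definite" refers to.

```python
import math


def cholesky_lower(a):
    """Return the lower-triangular factor L with L @ L^T == a.

    Illustrative helper (hypothetical name). Raises ValueError naming the
    first leading minor that is not positive-definite.
    """
    n = len(a)
    lower = [[0.0] * n for _ in range(n)]
    for j in range(n):
        # Pivot: diagonal entry minus the squares already placed in row j.
        pivot = a[j][j] - sum(lower[j][k] ** 2 for k in range(j))
        if pivot <= 0.0:
            raise ValueError(
                f"the leading minor of order {j + 1} is not positive-definite"
            )
        lower[j][j] = math.sqrt(pivot)
        # Fill in the rest of column j below the diagonal.
        for i in range(j + 1, n):
            partial = sum(lower[i][k] * lower[j][k] for k in range(j))
            lower[i][j] = (a[i][j] - partial) / lower[j][j]
    return lower


# A positive-definite matrix factorizes without error.
good = [[4.0, 2.0], [2.0, 3.0]]
factor = cholesky_lower(good)

# This symmetric matrix has a positive first entry but a negative
# determinant (1*1 - 2*2 = -3), so the second pivot is negative.
bad = [[1.0, 2.0], [2.0, 1.0]]
try:
    cholesky_lower(bad)
except ValueError as err:
    message = str(err)
```

In the failing run above, the matrix handed to `torch.linalg.cholesky` is a precision block built from the current parameter values, so a failure of this kind suggests that optimization has pushed the parameters into a numerically bad region rather than that the input data is malformed.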

Environment

OS: macOS Big Sur (11.5.2) (Intel)
Python version: 3.7.9
PyTorch version: 1.10.1
Pyro version: 1.8.0

Note, I get the same behavior on linux in docker.

Code Snippet

I copy-pasted the example from here: https://pyro.ai/examples/forecast_simple.html / https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/forecast/bart.py and simply ran:

python bart.py

martinjankowiak commented 2 years ago

hello @abeppu, thanks for the detailed report. you can avoid this error by using a smaller learning rate:

python bart.py -lr 0.01

i put up a PR (#3002) to use a more conservative default learning rate.

generally speaking, these kinds of numerical issues are par for the course for complex optimization problems. we do not run optimization to completion in continuous integration, since doing so would be prohibitively expensive. in practice, training any complex model will require the user to explore optimization hyperparameters to find the right setting for the problem at hand.
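as a rough sketch of what that exploration can look like, the helper below tries a list of learning rates from largest to smallest and keeps the first run that finishes without a RuntimeError (the exception type in the traceback above). everything here is hypothetical and for illustration only: `first_stable_run` and `toy_train` are not part of pyro, and `toy_train` is a plain gradient descent on f(x) = x^2 that stands in for a real training run, diverging when the learning rate is too large.

```python
def first_stable_run(train, learning_rates):
    """Try each learning rate in order and return (lr, result, failures)
    for the first run that completes without raising RuntimeError.

    `train` is any callable taking a learning rate (hypothetical interface).
    """
    failures = {}
    for lr in learning_rates:
        try:
            return lr, train(lr), failures
        except RuntimeError as err:
            # Record the failure and fall back to the next, smaller rate.
            failures[lr] = str(err)
    raise RuntimeError(f"all learning rates failed: {failures}")


def toy_train(lr, steps=100):
    """Stand-in for a real training run: gradient descent on f(x) = x**2.

    The update x <- x - lr * 2x blows up when lr > 1, which we report as a
    RuntimeError to mimic a numerical failure during optimization.
    """
    x = 1.0
    for _ in range(steps):
        x -= lr * 2.0 * x
        if abs(x) > 1e6:
            raise RuntimeError(f"optimization diverged with lr={lr}")
    return x


best_lr, final_x, failures = first_stable_run(toy_train, [1.5, 1.1, 0.4])
```

for a real model, `train` would wrap the actual fitting call (for bart.py, the run that the -lr flag controls), and it is usually worth comparing the final loss across the rates that succeed rather than stopping at the first one.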