pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0
8.49k stars, 982 forks

Around halfway through SVI for a Dirichlet process mixture model, throws a: RuntimeError: Invalid index in gather at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:657 #2108

Open m-k-S opened 4 years ago

m-k-S commented 4 years ago

Issue Description

Hi, I am trying to train a Dirichlet process (Poisson) mixture model in Pyro using the Chinese restaurant process (CRP) formulation [code and data are below]. For inference, I am using SVI. SVI will run fine for ~20-30 seconds on my machine (about 400 iterations), and then throws the following error: RuntimeError: Invalid index in gather at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:657. This occurs regardless of the number of iterations I set.

Environment

I am running this in Google Colab, using Python 3.6, pyro-ppl-0.5.1, torch 1.3.0.

Code Snippet

Dataset is found here: http://www.sidc.be/silso/INFO/snytotcsv.php

The code is:

```py
import torch
!pip install pyro-ppl
import numpy as np
import pandas as pd
import pyro.distributions as dist
import pyro
from torch.autograd import Variable
from pyro.infer.autoguide import *
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, TracePredictive, EmpiricalMarginal
from torch.distributions import constraints
from pyro import poutine

df = pd.read_csv('sunspot.csv', sep=';', names=['time', 'sunspot.year'], usecols=[0, 1])
data = torch.tensor(df['sunspot.year'].values, dtype=torch.float32)

def crp_model(data):
    alpha0 = pyro.sample('alpha', dist.Gamma(2, 0.5))
    cluster_rates = {}
    crp_counts = []
    for i in range(len(data)):
        weights = torch.tensor(crp_counts + [alpha0], dtype=torch.float32)
        weights /= weights.sum()
        crp_weights = pyro.param("weights_{}".format(i), Variable(weights), constraint=constraints.simplex)
        # print (crp_weights)
        zi = pyro.sample("z_{}".format(i), dist.Categorical(crp_weights))
        zi = zi.item()
        if zi >= len(crp_counts):
            crp_counts.append(1)
        else:
            crp_counts[zi] += 1
        if zi not in cluster_rates.keys():
            cluster_rates[zi] = pyro.sample("lambda_{}".format(zi), dist.Uniform(0, 200))
        lambda_i = cluster_rates[zi]
        pyro.sample("obs_{}".format(i), dist.Poisson(lambda_i), obs=data[i])

def guide(data):
    a_q = pyro.param('a', torch.tensor(2.0), constraint=constraints.positive)
    b_q = pyro.param('b', torch.tensor(0.5), constraint=constraints.positive)
    ub_Unif = pyro.param('ub', torch.tensor(200), constraint=constraints.positive)
    lb_Unif = pyro.param('lb', torch.tensor(0), constraint=constraints.positive)
    alpha_q = pyro.sample('alpha', dist.Gamma(a_q, b_q))
    cluster_rates_q = {}
    crp_counts_q = []
    for i in range(len(data)):
        # sample from a CRP
        weights_q = torch.tensor(crp_counts_q + [alpha_q], dtype=torch.float32)
        weights_q /= weights_q.sum()
        # crp_weights_q = pyro.param("weights_{}".format(i), Variable(weights_q), constraint=constraints.simplex)
        print (weights_q)
        zi_q = pyro.sample("z_{}".format(i), dist.Categorical(weights_q))
        zi_q = zi_q.item()
        if zi_q >= len(crp_counts_q):
            crp_counts_q.append(1)
        else:
            crp_counts_q[zi_q] += 1
        if zi_q not in cluster_rates_q.keys():
            cluster_rates_q[zi_q] = pyro.sample("lambda_{}".format(zi_q), dist.Uniform(lb_Unif, ub_Unif))

optim = Adam({"lr": 0.05})
svi = SVI(crp_model, guide, optim, loss=TraceEnum_ELBO(), num_samples=1000)

def train(num_iterations):
    pyro.clear_param_store()
    for j in range(num_iterations):
        loss = svi.step(data)
        if j % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

train(1000)

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))
```

fritzo commented 4 years ago

Can you run %debug in the following cell, step up with u until the culprit tensors are in locals, and then inspect them?

m-k-S commented 4 years ago

Sure:

```
/usr/local/lib/python3.6/dist-packages/torch/distributions/categorical.py in log_prob(self, value)
    114         value, log_pmf = torch.broadcast_tensors(value, self.logits)
    115         value = value[..., :1]
--> 116         return log_pmf.gather(-1, value).squeeze(-1)
    117 
    118     def entropy(self):

RuntimeError: Invalid index in gather at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:657

%debug

ipdb> value
tensor([2])
ipdb> u
> /usr/local/lib/python3.6/dist-packages/pyro/distributions/torch.py(28)log_prob()
     26                 assert logits.size(-1 - value.dim()) == 1
     27                 return logits.transpose(-1 - value.dim(), -1).squeeze(-1)
---> 28         return super(Categorical, self).log_prob(value)
     29 
     30     def enumerate_support(self, expand=True):

ipdb> value
tensor(2)
ipdb> u
> /usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py(218)compute_log_prob()
    216                 shapes = self.format_shapes(last_site=site["name"])
    217                 raise ValueError("Error while computing log_prob at site '{}':\n{}\n{}"
--> 218                                  .format(name, exc_value, shapes)).with_traceback(traceback)
    219             site["unscaled_log_prob"] = log_p
    220             log_p = scale_and_mask(log_p, site["scale"], site["mask"])

ipdb> site
{'type': 'sample', 'name': 'z_2', 'fn': Categorical(probs: torch.Size([2]), logits: torch.Size([2])), 'is_observed': False, 'args': (), 'kwargs': {}, 'value': tensor(2), 'infer': {'_dim_to_id': {}}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}
ipdb> u
> /usr/local/lib/python3.6/dist-packages/pyro/infer/enum.py(49)get_importance_trace()
     47         model_trace = prune_subsample_sites(model_trace)
     48 
---> 49     model_trace.compute_log_prob()
     50     guide_trace.compute_score_parts()
     51     if is_validation_enabled():

ipdb> u
> /usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py(271)_get_trace()
    269         """
    270         model_trace, guide_trace = get_importance_trace(
--> 271             "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
    272 
    273         if is_validation_enabled():

ipdb> u
> /usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py(317)_get_traces()
    315             q.put(poutine.Trace())
    316             while not q.empty():
--> 317                 yield self._get_trace(model, guide, *args, **kwargs)
    318 
    319     def loss(self, model, guide, *args, **kwargs):

ipdb> u
*** Oldest frame
```

Could it be something to do with tensor([2]) versus tensor(2)? Is this my fault?

fritzo commented 4 years ago

Try pyro.enable_validation(True)?
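
For context, here is a minimal sketch of what validation changes. It assumes only `torch.distributions`, and the per-distribution `validate_args=True` flag stands in for Pyro's global switch (my understanding is that the global switch sets the same default). The two-outcome weights mirror the `z_2` site in the traceback above. With validation on, an out-of-range index is reported as a support error instead of a low-level `gather` failure:

```python
import torch
from torch.distributions import Categorical

# Two weights, so the only valid outcomes are 0 and 1 (as at site 'z_2' above).
probs = torch.tensor([0.5, 0.5])
d = Categorical(probs=probs, validate_args=True)

# A value inside the support scores normally.
in_range = d.log_prob(torch.tensor(1))

# A value outside the support is rejected by the validation check.
try:
    d.log_prob(torch.tensor(2))
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(in_range, error_message)
```

The error message then names the support, which makes it easier to see which sample site received an impossible value.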

m-k-S commented 4 years ago

Now I'm getting "The value argument must be within the support" after several dozen iterations.

```
> /usr/local/lib/python3.6/dist-packages/torch/distributions/distribution.py(253)_validate_sample()
    251 
    252         if not self.support.check(value).all():
--> 253             raise ValueError('The value argument must be within the support')
    254 
    255     def _get_checked_instance(self, cls, _instance=None):

ipdb> value
tensor(8.3000)
ipdb> u
> /usr/local/lib/python3.6/dist-packages/torch/distributions/poisson.py(61)log_prob()
     59     def log_prob(self, value):
     60         if self._validate_args:
---> 61             self._validate_sample(value)
     62         rate, value = broadcast_all(self.rate, value)
     63         return (rate.log() * value) - rate - (value + 1).lgamma()

ipdb> value
tensor(8.3000)
ipdb> u
> /usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py(218)compute_log_prob()
    216                 shapes = self.format_shapes(last_site=site["name"])
    217                 raise ValueError("Error while computing log_prob at site '{}':\n{}\n{}"
--> 218                                  .format(name, exc_value, shapes)).with_traceback(traceback)
    219             site["unscaled_log_prob"] = log_p
    220             log_p = scale_and_mask(log_p, site["scale"], site["mask"])
```

fritzo commented 4 years ago

Hmm, Poisson.support == constraints.nonnegative_integer which does not contain 8.3
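
A short sketch of that support check, assuming `torch.distributions` only. Rounding the observations is one possible workaround and is an assumption here, not something proposed in this thread; whether it suits the sunspot data is a modelling decision. The rate of 10.0 is an arbitrary illustrative value.

```python
import torch
from torch.distributions import Poisson, constraints

value = torch.tensor(8.3)  # the offending observation from the traceback

# Poisson is supported on the non-negative integers, so 8.3 fails the check.
ok_raw = bool(constraints.nonnegative_integer.check(value))

# Assumed workaround: round observations to whole counts before using them as obs.
rounded = value.round()
ok_rounded = bool(constraints.nonnegative_integer.check(rounded))

# With an integer-valued observation the validated log_prob goes through.
log_p = Poisson(torch.tensor(10.0), validate_args=True).log_prob(rounded)
print(ok_raw, ok_rounded, log_p)
```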

(BTW would you mind enclosing long listings in <details> ... </details>?)

fritzo commented 4 years ago

One issue I notice is that you are optimizing Uniform(a,b) in the guide, but the model is fixed to Uniform(0,200); if the guide ever sets b > 200 there will be an error. You could use say Exponential(1/200) in the model instead?
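
To make the support mismatch concrete, here is a hedged sketch using `torch.distributions` only. The value 250.0 is a hypothetical draw from a guide whose upper bound has drifted above 200; it is not taken from the thread.

```python
import torch
from torch.distributions import Exponential, Uniform

# Hypothetical guide sample after the learned upper bound has grown past 200.
guide_sample = torch.tensor(250.0)

# Model prior fixed to Uniform(0, 200): the sample lies outside its support.
fixed_prior = Uniform(0.0, 200.0, validate_args=True)
try:
    fixed_prior.log_prob(guide_sample)
    uniform_ok = True
except ValueError:
    uniform_ok = False

# Exponential(1/200) has mean 200 and assigns density to every positive value.
wide_prior = Exponential(torch.tensor(1.0 / 200.0), validate_args=True)
exp_log_prob = wide_prior.log_prob(guide_sample)
print(uniform_ok, exp_log_prob)
```

The design point is that the model's prior should have support at least as wide as anything the guide can propose.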

fritzo commented 4 years ago

Also note that the model is not being enumerated, so you might as well use Trace_ELBO.

m-k-S commented 4 years ago

Got it, thanks for all the tips. It's now saying that the value must be within the support for the mixture assignments (no longer from the Poisson distribution):

```
/usr/local/lib/python3.6/dist-packages/torch/distributions/distribution.py(253)_validate_sample()
    251 
    252         if not self.support.check(value).all():
--> 253             raise ValueError('The value argument must be within the support')
    254 
    255     def _get_checked_instance(self, cls, _instance=None):

ipdb> u
> /usr/local/lib/python3.6/dist-packages/torch/distributions/categorical.py(112)log_prob()
    110     def log_prob(self, value):
    111         if self._validate_args:
--> 112             self._validate_sample(value)
    113         value = value.long().unsqueeze(-1)
    114         value, log_pmf = torch.broadcast_tensors(value, self.logits)

ipdb> value
tensor(3)
```

jinxmirror13 commented 4 years ago

I also have this error, but from the original tutorial running on my local laptop...

Environment

Python 3.7, pyro-ppl==1.3.0, pyro-api==0.1.1, jupyter-client==5.2.4, jupyter-console==6.0.0

Errors

After running: train(n_iter)

...

ValueError: The value argument must be within the support

During handling of the above exception, another exception occurred:

...

ValueError: Error while computing log_prob at site 'obs':
The value argument must be within the support
Trace Shapes:      
 Param Sites:      
Sample Sites:      
    beta dist  19 |
        value  19 |
     log_prob  19 |
  lambda dist  20 |
        value  20 |
     log_prob  20 |
       z dist 320 |
        value 320 |
     log_prob 320 |
     obs dist 320 |
        value 320 |

I have no idea what these errors mean or how to fix them (I am also new to Pyro and DPMMs)...

fritzo commented 4 years ago

@jinxmirror13 the "value argument must be within the support" error is usually due to NaNs. You could try inserting assert not torch.isnan(x).any() for various tensors x in your code. I often do that for the loss returned from svi.step() as well: assert not math.isnan(loss).
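
A minimal sketch of such a loss check, using only the standard library. `check_loss` is a hypothetical helper name, and the loss values fed to it below are made up for illustration; in a real training loop the value would come from `svi.step(...)`.

```python
import math

def check_loss(step, loss):
    """Raise as soon as the loss stops being a finite number."""
    if math.isnan(loss) or math.isinf(loss):
        raise FloatingPointError("loss became {} at step {}".format(loss, step))
    return loss

# Made-up loss sequence standing in for successive svi.step(data) results.
fake_losses = [12.5, 11.9, float("nan"), 11.0]

first_bad_step = None
for step, loss in enumerate(fake_losses):
    try:
        check_loss(step, loss)
    except FloatingPointError:
        first_bad_step = step  # remember where training first went wrong
        break

print(first_bad_step)
```

Stopping at the first bad step keeps the parameter store in the state that produced the failure, which is usually easier to inspect than the state many iterations later.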

gewirtz commented 3 years ago

Hi fritzo, any advice on what to do if, hundreds of iterations in, part of a Pyro param is NaN? You say above that you insert assertions, but if the assertions fail, what then? I made a post on the Pyro forum (https://forum.pyro.ai/t/svi-step-update-only-specified-parameters-minibatch-elbo-scaling/2211) about this issue, but I am still having it, and more googling led me here.

Thanks!

martinjankowiak commented 3 years ago

@gewirtz

after each svi.step() you can try something like

for name, param in pyro.get_param_store().named_parameters():
    bad = torch.isnan(param).sum().item() + torch.isinf(param).sum().item()
    if bad > 0:
        print(name, param.shape, bad)

to try to identify where nans first appear

fritzo commented 3 years ago

Answered on the forum for better visibility 😄