pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[Bug] Dimension augmenting with Predictive in Africa Tutorial #3075

Open: maulberto3 opened this issue 2 years ago

maulberto3 commented 2 years ago

Issue Description

In Jupyter, running the following cell several times:

pred = infer.Predictive(model, guide=auto_guide, num_samples=2)
svi_samples = pred(is_cont_africa, ruggedness, log_gdp)
# Note: this reassigns log_gdp, which is also an input on the line above
log_gdp = svi_samples['obs']
print(log_gdp.shape)

makes the shape gain an extra leading dimension of size 2 on every run:

torch.Size([2, 170])
torch.Size([2, 2, 170])
torch.Size([2, 2, 2, 170])
torch.Size([2, 2, 2, 2, 170])...

Environment

Code Snippet

Already provided above.

Hope this helps.

maulberto3 commented 2 years ago

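Found the cause. The cell assigns the samples back to log_gdp, which is also passed to pred as the observed data, so every re-run feeds the previous samples back in and adds another leading dimension. Solution: store the samples under a different name.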
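The feedback loop can be mimicked without Pyro. This is a minimal sketch assuming, as the reported shapes suggest, that when real data is passed in, the 'obs' entry returned by Predictive is that data stacked once per sample. fake_predictive_obs is a hypothetical stand-in, not a Pyro function, and the data is random placeholder input with the tutorial's 170 rows.

```python
import torch

def fake_predictive_obs(obs: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Hypothetical stand-in for svi_samples['obs'] when data is passed in:
    the observed tensor repeated once per sample along a new leading dim."""
    return torch.stack([obs] * num_samples)

# Placeholder data with 170 rows (random, not the tutorial dataset).
data = torch.randn(170)

# Buggy pattern: the output is assigned back to the input name, so each
# re-run of the "cell" stacks on top of the previous result.
log_gdp = data
shapes = []
for _ in range(4):
    log_gdp = fake_predictive_obs(log_gdp, num_samples=2)
    shapes.append(tuple(log_gdp.shape))
print(shapes)
# [(2, 170), (2, 2, 170), (2, 2, 2, 170), (2, 2, 2, 2, 170)]

# Renamed output (the fix from this thread): the input never changes,
# so the shape is the same on every re-run.
log_gdp = data
for _ in range(4):
    log_gdp_ = fake_predictive_obs(log_gdp, num_samples=2)
print(tuple(log_gdp_.shape))
# (2, 170)
```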

Working example:

pred = infer.Predictive(model, guide=auto_guide, num_samples=20)
svi_samples = pred(is_cont_africa, ruggedness, log_gdp)
# New name, so log_gdp (the observed data) is left untouched between runs
log_gdp_ = svi_samples['obs']
log_gdp_.shape

Hope this is still useful.