pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Bug in dist.Categorical(prob) and sampling #3377

Closed: lf464347567 closed this issue 4 months ago

lf464347567 commented 5 months ago

When I use Pyro to fit a Bayesian model, I run into a bug with dist.Categorical(prob). The code is shown below:

[Screenshot of the code]

The size of c keeps changing as the loop runs. The printed messages are:

Iteration 0:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 0: 4
Sample size: torch.Size([])

Iteration 1:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 1: 3
Sample size: torch.Size([])

Iteration 2:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 2: 3
Sample size: torch.Size([])

Iteration 0:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 0: tensor([0, 1, 2, 3, 4])
Sample size: torch.Size([5])

Iteration 1:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 1: tensor([[0], [1], [2], [3], [4]])
Sample size: torch.Size([5, 1])

Iteration 2:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 2: tensor([[[0]],

    [[1]],

    [[2]],

    [[3]],

    [[4]]])

Sample size: torch.Size([5, 1, 1])

Iteration 0:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 0: tensor([0, 1, 2, 3, 4])
Sample size: torch.Size([5])

Iteration 1:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 1: tensor([[0], [1], [2], [3], [4]])
Sample size: torch.Size([5, 1])

Iteration 2:
Input probabilities tensor: tensor([0.0000, 0.0000, 0.1000, 0.2000, 0.7000])
Sampled index 2: tensor([[[0]],

    [[1]],

    [[2]],

    [[3]],

    [[4]]])

Sample size: torch.Size([5, 1, 1])

MCMC Samples: {}

Why is c.size() dynamic, i.e. why does it change from one iteration to the next?

fritzo commented 4 months ago

See https://pyro.ai/examples/enumeration.html