HEmile / storchastic

Stochastic Automatic Differentiation library for PyTorch.
GNU General Public License v3.0
178 stars, 5 forks

Examples with Bernoulli distributions failing #86

Closed: csmith49 closed this issue 3 years ago

csmith49 commented 3 years ago

Issue

Running either of the examples that use Bernoulli distributions (examples/bernoulli_toy.py and examples/bernoulli_grad_var.py) raises a ValueError:

ValueError: Input arguments must all be instances of numbers.Number or torch.tensor.
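The message comes from the argument check in torch.distributions.utils.broadcast_all, which Bernoulli.log_prob calls (see the last two frames of the traceback below). In storch's case the value passed to log_prob is the object carrying a .distribution attribute (storch/sampling/expect.py, line 82), which is presumably storch's own tensor type rather than a plain torch.Tensor. The sketch below assumes a PyTorch install and uses a made-up HypotheticalWrapper class, not storch's actual tensor type, to show that broadcast_all rejects any object that is neither a torch.Tensor nor a numbers.Number. The exact wording of the error varies between PyTorch versions.

```python
import torch
from torch.distributions.utils import broadcast_all


class HypotheticalWrapper:
    """Hypothetical stand-in for a library tensor wrapper: it holds a
    torch.Tensor but is not a torch.Tensor, not a numbers.Number, and
    does not implement __torch_function__."""

    def __init__(self, data):
        self._tensor = data


def broadcast_fails(value, logits):
    """Return True if broadcast_all rejects `value` with a ValueError."""
    try:
        broadcast_all(logits, value)
    except ValueError:
        return True
    return False


logits = torch.zeros(4)                    # same shape as Bernoulli(logits: [4])
samples = torch.tensor([[0., 0., 0., 1.],  # two rows of an enumeration
                        [0., 0., 1., 0.]])

print(broadcast_fails(HypotheticalWrapper(samples), logits))  # True
print(broadcast_fails(samples, logits))                       # False
```

Passing the underlying torch.Tensor instead of the wrapper lets broadcast_all succeed, since both arguments then broadcast to shape [2, 4].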

Environment

Running Python 3.8.5 with Pyro 1.5.0 and PyTorch 1.7.0.

Traceback

Running python3 examples/bernoulli_toy.py from the root of the repo, with better_exceptions enabled, produces:

Traceback (most recent call last):
  File "examples/bernoulli_toy.py", line 29, in <module>
    experiment(Expect("x"))
    │          └ <class 'storch.method.method.Expect'>
    └ <function experiment at 0x7fbf47b3a0d0>
  File "examples/bernoulli_toy.py", line 19, in experiment
    x = method(b)
        │      └ Bernoulli(probs: torch.Size([4]), logits: torch.Size([4]))
        └ Expect(
  (sampling_method): Enumerate()
)
  File ".../python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
             │             │        └ {}
             │             └ (Bernoulli(probs: torch.Size([4]), logits: torch.Size([4])),)
             └ Expect(
  (sampling_method): Enumerate()
)
  File ".../python3.8/site-packages/storch/method/method.py", line 52, in forward
    return self.sample(distr)
           │           └ Bernoulli(probs: torch.Size([4]), logits: torch.Size([4]))
           └ Expect(
  (sampling_method): Enumerate()
)
  File ".../python3.8/site-packages/storch/method/method.py", line 114, in sample
    batch_weighting = self.sampling_method.plate_weighting(s_tensor, plate)
                      │                                    │         └ ('x', 16, tensor(0.0625))
                      │                                    └ tensor([[0., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 1.],
        [0., 1., 0., 0.]...
                      └ Expect(
  (sampling_method): Enumerate()
)
  File ".../python3.8/site-packages/storch/sampling/expect.py", line 82, in plate_weighting
    log_probs = tensor.distribution.log_prob(tensor)
                │                            └ tensor([[0., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 1.],
        [0., 1., 0., 0.]...
                └ tensor([[0., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 1.],
        [0., 1., 0., 0.]...
  File ".../python3.8/site-packages/torch/distributions/bernoulli.py", line 94, in log_prob
    logits, value = broadcast_all(self.logits, value)
            │       │             │            └ tensor([[0., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 1.],
        [0., 1., 0., 0.]...
            │       │             └ Bernoulli(probs: torch.Size([4]), logits: torch.Size([4]))
            │       └ <function broadcast_all at 0x7fbe57163430>
            └ tensor([[0., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 1.],
        [0., 1., 0., 0.]...
  File ".../python3.8/site-packages/torch/distributions/utils.py", line 24, in broadcast_all
    raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
ValueError: Input arguments must all be instances of numbers.Number or torch.tensor.
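The 16 x 4 tensor in the trace is the output of Enumerate(): all 2^4 = 16 joint outcomes of the four Bernoulli variables, listed in binary counting order, with a plate weight of 1/16 = 0.0625. plate_weighting then calls log_prob on these outcomes, presumably so that Expect can weight each outcome by its actual probability and compute an exact expectation. The dependency-free sketch below illustrates that idea under those assumptions; the function names and probabilities are hypothetical and are not storch's API.

```python
import itertools
import math


def enumerate_outcomes(n):
    """All 2**n binary vectors, in the same order as the trace:
    (0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), ..."""
    return list(itertools.product((0, 1), repeat=n))


def outcome_log_prob(x, probs):
    """Log-probability of binary vector x under independent Bernoullis."""
    return sum(math.log(p) if xi == 1 else math.log(1.0 - p)
               for xi, p in zip(x, probs))


def exact_expectation(f, probs):
    """E[f(x)] as a sum over every outcome, weighted by its probability."""
    total = 0.0
    for x in enumerate_outcomes(len(probs)):
        total += math.exp(outcome_log_prob(x, probs)) * f(x)
    return total


probs = [0.1, 0.4, 0.5, 0.8]               # hypothetical parameters
outcomes = enumerate_outcomes(len(probs))
print(len(outcomes), 1 / len(outcomes))    # 16 0.0625
# The expected number of ones is sum(probs), about 1.8 up to float rounding.
print(exact_expectation(sum, probs))
```

Enumeration is exact but its cost grows as 2^n, so it is only practical for a handful of discrete variables like the four in this example.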
HEmile commented 3 years ago

Thanks for the report! I hadn't run the code with PyTorch 1.7.0 yet, and it indeed didn't work.