HEmile / storchastic

Stochastic Automatic Differentiation library for PyTorch.
GNU General Public License v3.0

RELAX/REBAR missing mc_sample #88

Closed by csmith49 3 years ago

csmith49 commented 3 years ago

The RELAX and REBAR gradient estimators don't appear to be usable on Bernoulli random variables: some operation expects the estimator to have an mc_sample attribute, which is missing.

I can reproduce this result with the following:

import storch
import torch
from torch.distributions import Bernoulli
from storch.method import RELAX

torch.manual_seed(0)

p = torch.tensor(0.5, requires_grad=True)
d = Bernoulli(p)
sample = RELAX("sample")(d)
storch.add_cost(sample, "cost")
storch.backward()

This produces the following traceback:

Traceback (most recent call last):
  File "...", line 13, in <module>
    sample = RELAX("sample")(d)
  File ".../lib/python3.8/site-packages/storch/method/relax.py", line 210, in __init__
    super().__init__(plate_name, sampling_method.set_mc_sample(self.mc_sample))
  File ".../lib/python3.8/site-packages/torch/nn/modules/module.py", line 771, in __getattr__
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'RELAX' object has no attribute 'mc_sample'
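
Reading the traceback, the failure seems to happen because self.mc_sample is looked up inside RELAX.__init__ while the arguments to super().__init__ are still being built, and the Module.__getattr__ fallback raises when the name is not found. The following is a minimal pure-Python sketch of that pattern, not storchastic's actual code: Module, Method, BrokenEstimator and WorkingEstimator are hypothetical stand-ins, and defining mc_sample as a method is only one possible way to make the lookup succeed.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module's attribute fallback."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup has already failed.
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )


class Method(Module):
    """Hypothetical base class that stores what the subclass passes in."""

    def __init__(self, plate_name, sampler):
        self.plate_name = plate_name
        self.sampler = sampler


class BrokenEstimator(Method):
    def __init__(self, plate_name):
        # self.mc_sample is read here, but nothing defines it anywhere.
        super().__init__(plate_name, self.mc_sample)


class WorkingEstimator(Method):
    def __init__(self, plate_name):
        # The same lookup now succeeds because mc_sample exists on the class.
        super().__init__(plate_name, self.mc_sample)

    def mc_sample(self):
        return "sampled"


def error_message(cls):
    """Return the AttributeError text raised by cls("sample"), or None."""
    try:
        cls("sample")
    except AttributeError as err:
        return str(err)
    return None


print(error_message(BrokenEstimator))
print(error_message(WorkingEstimator))
```

The first call reports the same kind of message as the traceback above (with the stand-in class name), while the second constructs the object without raising.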

HEmile commented 3 years ago

Hi, I fixed this on the master branch, but errors persist there. The code you provided does run on the dice branch: https://github.com/HEmile/storchastic/tree/dice. That said, RELAX is a rather challenging and bug-prone estimator to implement, and I'm not convinced the implementation is completely correct right now.

HEmile commented 3 years ago

The example you provided runs on the master branch with sample = RELAX("sample", in_dim=1)(d).
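
For reference, here is the original reproduction script with only that one line changed. This is a sketch based on the comment above rather than something verified against a particular release. The meaning of in_dim is not spelled out in this thread, and the result may differ between branches and versions. The run_relax_example wrapper and the HAVE_DEPS guard are additions for this sketch so that it skips cleanly when storch or torch is not installed.

```python
import importlib.util


def run_relax_example():
    """Reproduction script from this issue, with in_dim=1 passed to RELAX."""
    import torch
    from torch.distributions import Bernoulli

    import storch
    from storch.method import RELAX

    torch.manual_seed(0)

    p = torch.tensor(0.5, requires_grad=True)
    d = Bernoulli(p)
    # Only this line differs from the original script.
    sample = RELAX("sample", in_dim=1)(d)
    storch.add_cost(sample, "cost")
    storch.backward()


# Guard added for this sketch: True only if both packages can be found.
HAVE_DEPS = all(
    importlib.util.find_spec(name) is not None for name in ("torch", "storch")
)

if __name__ == "__main__" and HAVE_DEPS:
    run_relax_example()
    print("ran without raising")
```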