```
__ test_dirichlet_bernoulli[JitTraceEnum_ELBO-False] __
Elbo = , vectorized = False
@pytest.mark.parametrize("vectorized", [False, True])
@pytest.mark.parametrize(
"Elbo",
[
TraceEnum_ELBO,
JitTraceEnum_ELBO,
],
)
def test_dirichlet_bernoulli(Elbo, vectorized):
pyro.clear_param_store()
data = torch.tensor([1.0] * 6 + [0.0] * 4)
def model1(data):
concentration0 = constant([10.0, 10.0])
f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
for i in pyro.plate("plate", len(data)):
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
def model2(data):
concentration0 = constant([10.0, 10.0])
f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
pyro.sample(
"obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data
)
model = model2 if vectorized else model1
def guide(data):
concentration_q = pyro.param(
"concentration_q", constant([15.0, 15.0]), constraint=constraints.positive
)
pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))
elbo = Elbo(
num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True
)
optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
svi = SVI(model, guide, optim, elbo)
for step in range(40):
> svi.step(data)
tests/infer/test_jit.py:462:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:564: in loss_and_grads
differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:561: in differentiable_loss
return self._differentiable_loss(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self =
args = (tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]),)
kwargs = {'_guide_id': 140690819139104, '_model_id': 140690812334288}
key = (1, (('_guide_id', 140690819139104), ('_model_id', 140690812334288)))
unconstrained_params = [tensor([2.7072, 2.7090], requires_grad=True)]
params_and_args = [tensor([2.7072, 2.7090], requires_grad=True), tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.])]
param_capture =
def __call__(self, *args, **kwargs):
key = _hashable_args_kwargs(args, kwargs)
# if first time
if key not in self.compiled:
# param capture
with poutine.block():
with poutine.trace(param_only=True) as first_param_capture:
self.fn(*args, **kwargs)
self._param_names = list(set(first_param_capture.trace.nodes.keys()))
unconstrained_params = tuple(
pyro.param(name).unconstrained() for name in self._param_names
)
params_and_args = unconstrained_params + args
weakself = weakref.ref(self)
def compiled(*params_and_args):
self = weakself()
unconstrained_params = params_and_args[: len(self._param_names)]
args = params_and_args[len(self._param_names) :]
constrained_params = {}
for name, unconstrained_param in zip(
self._param_names, unconstrained_params
):
constrained_param = pyro.param(
name
) # assume param has been initialized
assert constrained_param.unconstrained() is unconstrained_param
constrained_params[name] = constrained_param
return poutine.replay(self.fn, params=constrained_params)(
*args, **kwargs
)
if self.ignore_warnings:
compiled = ignore_jit_warnings()(compiled)
with pyro.validation_enabled(False):
time_compilation = self.jit_options.pop("time_compilation", False)
with optional(timed(), time_compilation) as t:
self.compiled[key] = torch.jit.trace(
compiled, params_and_args, **self.jit_options
)
if time_compilation:
self.compile_time = t.elapsed
else:
unconstrained_params = [
# FIXME this does unnecessary transform work
pyro.param(name).unconstrained() for name in self._param_names
]
params_and_args = unconstrained_params + list(args)
with poutine.block(hide=self._param_names):
with poutine.trace(param_only=True) as param_capture:
> ret = self.compiled[key](*params_and_args)
E RuntimeError: The following operation failed in the TorchScript interpreter.
E Traceback of TorchScript (most recent call last):
E RuntimeError: Unsupported value kind: Tensor
pyro/ops/jit.py:121: RuntimeError
```
It looks like a constant tensor inserted into the traced graph is failing the JIT's `insertableTensor` check because it requires grad. I've spent a couple of hours debugging but haven't been able to isolate the error.
This issue tracks new errors in the PyTorch 1.10 JIT, e.g. the `test_dirichlet_bernoulli[JitTraceEnum_ELBO-False]` failure shown in the traceback above.