pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

PyTorch 1.10 throws new jit errors #2965

Open fritzo opened 2 years ago

fritzo commented 2 years ago

This issue tracks new errors in the PyTorch 1.10 jit, e.g.

pytest tests/infer/test_jit.py::test_dirichlet_bernoulli -k Jit -vx --runxfail
```
__ test_dirichlet_bernoulli[JitTraceEnum_ELBO-False] __

Elbo = , vectorized = False

    @pytest.mark.parametrize("vectorized", [False, True])
    @pytest.mark.parametrize(
        "Elbo",
        [
            TraceEnum_ELBO,
            JitTraceEnum_ELBO,
        ],
    )
    def test_dirichlet_bernoulli(Elbo, vectorized):
        pyro.clear_param_store()
        data = torch.tensor([1.0] * 6 + [0.0] * 4)

        def model1(data):
            concentration0 = constant([10.0, 10.0])
            f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
            for i in pyro.plate("plate", len(data)):
                pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

        def model2(data):
            concentration0 = constant([10.0, 10.0])
            f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
            pyro.sample(
                "obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data
            )

        model = model2 if vectorized else model1

        def guide(data):
            concentration_q = pyro.param(
                "concentration_q",
                constant([15.0, 15.0]),
                constraint=constraints.positive,
            )
            pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))

        elbo = Elbo(
            num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True
        )
        optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
        svi = SVI(model, guide, optim, elbo)
        for step in range(40):
>           svi.step(data)

tests/infer/test_jit.py:462:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:564: in loss_and_grads
    differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:561: in differentiable_loss
    return self._differentiable_loss(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self =
args = (tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]),)
kwargs = {'_guide_id': 140690819139104, '_model_id': 140690812334288}
key = (1, (('_guide_id', 140690819139104), ('_model_id', 140690812334288)))
unconstrained_params = [tensor([2.7072, 2.7090], requires_grad=True)]
params_and_args = [tensor([2.7072, 2.7090], requires_grad=True), tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.])]
param_capture =

    def __call__(self, *args, **kwargs):
        key = _hashable_args_kwargs(args, kwargs)

        # if first time
        if key not in self.compiled:
            # param capture
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            self._param_names = list(set(first_param_capture.trace.nodes.keys()))
            unconstrained_params = tuple(
                pyro.param(name).unconstrained() for name in self._param_names
            )
            params_and_args = unconstrained_params + args
            weakself = weakref.ref(self)

            def compiled(*params_and_args):
                self = weakself()
                unconstrained_params = params_and_args[: len(self._param_names)]
                args = params_and_args[len(self._param_names) :]
                constrained_params = {}
                for name, unconstrained_param in zip(
                    self._param_names, unconstrained_params
                ):
                    constrained_param = pyro.param(
                        name
                    )  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param
                return poutine.replay(self.fn, params=constrained_params)(
                    *args, **kwargs
                )

            if self.ignore_warnings:
                compiled = ignore_jit_warnings()(compiled)
            with pyro.validation_enabled(False):
                time_compilation = self.jit_options.pop("time_compilation", False)
                with optional(timed(), time_compilation) as t:
                    self.compiled[key] = torch.jit.trace(
                        compiled, params_and_args, **self.jit_options
                    )
                if time_compilation:
                    self.compile_time = t.elapsed
        else:
            unconstrained_params = [
                # FIXME this does unnecessary transform work
                pyro.param(name).unconstrained()
                for name in self._param_names
            ]
            params_and_args = unconstrained_params + list(args)

        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
>               ret = self.compiled[key](*params_and_args)
E               RuntimeError: The following operation failed in the TorchScript interpreter.
E               Traceback of TorchScript (most recent call last):
E               RuntimeError: Unsupported value kind: Tensor

pyro/ops/jit.py:121: RuntimeError
```

It looks like some tensor being inserted as a graph constant is failing the jit's insertableTensor check because it requires grad. I've spent a couple of hours debugging but haven't been able to isolate the error.
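For anyone trying to narrow this down outside of Pyro, here is a minimal, hypothetical probe (not taken from the issue, and not verified to reproduce the failure): it traces a function that closes over a tensor with `requires_grad=True`, so the tracer has to record that tensor in the graph rather than treat it as an input, which on some versions goes through the same constant-insertion path.

```python
# Hypothetical standalone probe, not from the issue: trace a function that
# closes over a grad-requiring tensor so the tracer must embed it in the
# graph; depending on the PyTorch version this may hit the insertableTensor
# constant check or merely emit a tracer warning.
import torch

captured = torch.randn(2, requires_grad=True)  # closed over, NOT a trace input


def fn(x):
    return x * captured  # tracer must record `captured` in the graph


try:
    traced = torch.jit.trace(fn, (torch.randn(2),))
    print(traced.graph)  # inspect how `captured` was recorded
    print(traced(torch.randn(2)))
except RuntimeError as e:
    print("jit rejected the captured tensor:", e)
```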

Linux-cpp-lisp commented 2 years ago

I can reproduce this with a completely different model, and it only happens on PyTorch 1.10.

For me it also only happens when I `torch.jit.freeze` the model.
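For context (this setup is illustrative, not the commenter's actual model): `torch.jit.freeze` takes an eval-mode ScriptModule and inlines its parameters and attributes into the graph as constants, so freezing exercises the same constant-insertion machinery discussed above.

```python
# Illustrative sketch of how freezing is typically applied; the module here
# is a stand-in, not the model from this issue.
import torch


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


scripted = torch.jit.script(Tiny()).eval()  # freeze requires an eval-mode ScriptModule
frozen = torch.jit.freeze(scripted)         # parameters/attributes become graph constants
print(frozen(torch.randn(1, 4)))
```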