mihai-spire opened this issue 4 years ago
To be honest, I'm not totally sure what's going on here because I'm not familiar enough with the exact internals of Pyro.
This might have something to do with the fact that we register priors to transformed versions of the parameters, rather than the parameters directly?
> we register priors to transformed versions of the parameters, rather than the parameters directly
Hmm... it seems that the priors we register in this example are attached to the original (untransformed / constrained) versions of the parameters, right? E.g. we require the `lengthscale` to be positive by assigning it a uniform prior over [0.01, 0.5]. I'm guessing Pyro maps these to an unconstrained space before inference, right?
```python
model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.base_kernel.register_prior("period_length_prior", UniformPrior(0.05, 2.5), "period_length")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")
```
`lengthscale`, `period_length`, `outputscale`, and `noise` are all "transformed" parameters in GPyTorch. We store an unconstrained version of these parameters (e.g. referred to as `_lengthscale`) and then transform it (usually via the softplus function) to the positive-valued parameter (e.g. referred to as `lengthscale`).

The priors are in some sense an additional constraint, but they are being applied to the transformed parameter (e.g. `lengthscale`) rather than the untransformed parameter (e.g. `_lengthscale`).
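A minimal sketch of that raw/transformed relationship, assuming the default softplus transform (the variable names here are illustrative, not GPyTorch's actual attributes):

```python
import math

def softplus(x):
    # maps an unconstrained real value to a positive one: log(1 + exp(x))
    return math.log1p(math.exp(x))

def inv_softplus(y):
    # maps a positive value back to the unconstrained space: log(exp(y) - 1)
    return math.log(math.expm1(y))

# GPyTorch stores the unconstrained ("raw") parameter and derives the
# constrained one from it; the priors above are placed on the constrained value.
raw_lengthscale = -2.0                   # unconstrained, can be any real number
lengthscale = softplus(raw_lengthscale)  # always positive

assert lengthscale > 0
assert abs(inv_softplus(lengthscale) - raw_lengthscale) < 1e-9
```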
@neerajprad @fehiepsi any idea what might be going on here? Should be easy to reproduce using the notebook linked in the issue.
@mihai-spire incidentally, if you're interested in speed you might try JAX + NumPyro, although there's no built-in support for GPs, so it won't be quite as plug-and-play.
The issue is that somewhere during JIT tracing, we are inserting a tensor (most likely a constant) with `requires_grad=True`, whereas `torch.jit` expects all such tensors to be arguments instead. As @mihai-spire pointed out in the Pyro issue, this can happen due to a code path that caches tensors with `requires_grad=True` on the first invocation and inserts them later during tracing. Maybe there is some caching that happens within transforms or even earlier (the backtrace should provide some clue)? @jacobrgardner can probably speak to that. If that is indeed the issue, one solution would be to provide an option that disables caching during JIT tracing.
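A hypothetical sketch of the failure pattern being described (plain Python, no torch; all names are made up): a function that caches a value on its first call will reuse that object on later calls, so a tracer that only records the later calls sees the cached value as a baked-in constant rather than something derived from its inputs.

```python
_cache = {}

def transform(x):
    # On the first call, compute and cache an auxiliary value.
    # Later calls reuse the cached object -- a JIT tracer that records
    # only those later calls would embed it as a fixed constant instead
    # of treating it as an argument, which is roughly what trips torch.jit
    # up when the cached tensor carries requires_grad=True.
    if "aux" not in _cache:
        _cache["aux"] = x * 2
    return x + _cache["aux"]

first = transform(3)    # populates the cache: aux = 6, returns 9
second = transform(10)  # reuses aux = 6, so the result is 16, not 30
```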
Hi, I have encountered this same issue. Have you found a way around it?
+1. It would be great to get some more clarity on the issue here. Using JIT should greatly speed things up.
This issue is basically superseded by #1578, as the bugs are caused by the same problem, and a fix to #1578 will resolve this issue as well. I've confirmed that `jit_compile` works with at least the fix I have so far.
I get the following error when I run the notebook with `jit_compile=True`:

```
~/opt/anaconda3/envs/dev/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
    292
    293         if self._compiled_fn:
--> 294             return self._compiled_fn(*vals)
    295
    296         with pyro.validation_enabled(False):

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Unsupported value kind: Tensor
```

The example works with `jit_compile=False`.
I'm trying to run your fully Bayesian GP example.
The notebook runs OK as-is. As you may expect, sampling is much slower when I increase the size of the training dataset. I've tried to enable `jit` compilation in the `pyro` NUTS sampler; after this change, the NUTS sampler crashes.
I've done some googling and found https://github.com/pyro-ppl/pyro/issues/2292 — this seems to indicate that I failed to properly register a prior, perhaps for the `noise_covar.noise` of my Gaussian likelihood? Is this true? In your example, I do see a noise prior being registered on the likelihood. If so, how do I register the missing prior? Or am I looking at this the wrong way? Thanks!