cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Bayesian GPs with pyro (NUTS) - example notebook crashes when jit_compile=True #1286

Open mihai-spire opened 4 years ago

mihai-spire commented 4 years ago

I'm trying to run your fully Bayesian GP example.

The notebook runs OK as-is. As you might expect, sampling is much slower when I increase the size of the training dataset, so I tried enabling JIT compilation in the Pyro NUTS sampler:

nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)

After this change, the NUTS sampler crashes:

Warmup:   0%|          | 0/300 [00:00, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-56cc00113944> in <module>()
     26 nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
     27 mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
---> 28 mcmc_run.run(train_x, train_y)

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    378         with optional(pyro.validation_enabled(not self.disable_validation),
    379                       self.disable_validation is not None):
--> 380             for x, chain_id in self.sampler.run(*args, **kwargs):
    381                 if num_samples[chain_id] == 0:
    382                     num_samples[chain_id] += 1

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    167             for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
    168                                        i if self.num_chains > 1 else None,
--> 169                                        *args, **kwargs):
    170                 yield sample, i  # sample, chain_id
    171             self.kernel.cleanup()

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    109 
    110 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 111     kernel.setup(warmup_steps, *args, **kwargs)
    112     params = kernel.initial_params
    113     # yield structure (key, value.shape) of params

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    304         if self.initial_params:
    305             z = {k: v.detach() for k, v in self.initial_params.items()}
--> 306             z_grads, potential_energy = potential_grad(self.potential_fn, z)
    307         else:
    308             z_grads, potential_energy = {}, self.potential_fn(self.initial_params)

/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
     80             return grads, z_nodes[0].new_tensor(float('nan'))
     81         else:
---> 82             raise e
     83 
     84     grads = grad(potential_energy, z_nodes)

/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
     73         node.requires_grad_(True)
     74     try:
---> 75         potential_energy = potential_fn(z)
     76     # deal with singular matrices
     77     except RuntimeError as e:

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
    287             if skip_jit_warnings:
    288                 _pe_jit = ignore_jit_warnings()(_pe_jit)
--> 289             self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
    290 
    291             result = self._compiled_fn(*vals)

/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    978                                                   var_lookup_fn,
    979                                                   strict,
--> 980                                                   _force_outplace)
    981 
    982     # Check the trace against new traces created from user-specified inputs

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _pe_jit(*zi)
    283             def _pe_jit(*zi):
    284                 params = dict(zip(names, zi))
--> 285                 return self._potential_fn(params)
    286 
    287             if skip_jit_warnings:

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn(self, params)
    259         cond_model = poutine.condition(self.model, params_constrained)
    260         model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
--> 261                                                           **self.model_kwargs)
    262         log_joint = self.trace_prob_evaluator.log_prob(model_trace)
    263         for name, t in self.transforms.items():

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    169                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    170                 exc = exc.with_traceback(traceback)
--> 171                 raise exc from None
    172             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    173         return ret

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError):
    167                 exc_type, exc_value, traceback = sys.exc_info()

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

<ipython-input-4-56cc00113944> in pyro_model(x, y)
     19 
     20 def pyro_model(x, y):
---> 21   model.pyro_sample_from_prior()
     22   output = model(x)
     23   loss = mll.pyro_factor(output, y)

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in pyro_sample_from_prior(self)
    318         parameters of the model that have GPyTorch priors registered to them.
    319         """
--> 320         return _pyro_sample_from_prior(module=self, memo=None, prefix="")
    321 
    322     def local_load_samples(self, samples_dict, memo, prefix):

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    427     for mname, module_ in module.named_children():
    428         submodule_prefix = prefix + ("." if prefix else "") + mname
--> 429         _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
    430 
    431 

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    421                     )
    422                 memo.add(prior)
--> 423                 prior = prior.expand(closure().shape)
    424                 value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
    425                 setting_closure(value)

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in closure()
    226 
    227             def closure():
--> 228                 return getattr(self, param_or_closure)
    229 
    230             if setting_closure is not None:

/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/gaussian_likelihood.py in noise(self)
     83     @property
     84     def noise(self) -> Tensor:
---> 85         return self.noise_covar.noise
     86 
     87     @noise.setter

/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/noise_models.py in noise(self)
     33     @property
     34     def noise(self):
---> 35         return self.raw_noise_constraint.transform(self.raw_noise)
     36 
     37     @noise.setter

/usr/local/lib/python3.6/dist-packages/gpytorch/constraints/constraints.py in transform(self, tensor)
    174 
    175     def transform(self, tensor):
--> 176         transformed_tensor = self._transform(tensor) if self.enforced else tensor
    177         return transformed_tensor
    178 

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
-1.2059
[ torch.FloatTensor{1} ]
Trace Shapes:
 Param Sites:
Sample Sites:

I've done some googling and found https://github.com/pyro-ppl/pyro/issues/2292, which seems to indicate that I failed to properly register a prior, perhaps for the noise_covar.noise of my Gaussian likelihood. Is that right? In your example, I do see a noise prior being registered, namely:

likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")

If so, how do I register the missing prior? Or am I looking at this the wrong way? Thanks!

jacobrgardner commented 4 years ago

To be honest, I'm not totally sure what's going on here because I'm not familiar enough with the exact internals of Pyro.

This might have something to do with the fact that we register priors to transformed versions of the parameters, rather than the parameters directly?

mihai-spire commented 4 years ago

we register priors to transformed versions of the parameters, rather than the parameters directly

Hmm... it seems that the priors we register in this example apply to the original (untransformed / constrained) versions of the parameters, right? E.g. we constrain the lengthscale to be positive by assigning it a uniform prior over [0.01, 0.5]. I'm guessing Pyro maps these to an unconstrained space before inference, right?

model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.base_kernel.register_prior("period_length_prior", UniformPrior(0.05, 2.5), "period_length")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")
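The guess about an unconstrained space can be made concrete. As far as we understand, Pyro's NUTS samples each site in an unconstrained space obtained with `biject_to` on the site's support, and for an interval support (such as a uniform prior's) PyTorch's `biject_to` is a sigmoid followed by an affine map. The pure-Python sketch below mimics that map for the `UniformPrior(0.01, 0.5)` lengthscale prior. The helper names are hypothetical and this is not Pyro's implementation; it only illustrates the reparameterization and the log-Jacobian correction added to the log density.

```python
import math


def sigmoid(u):
    """Logistic function, mapping the real line onto (0, 1)."""
    return 1.0 / (1.0 + math.exp(-u))


def to_constrained(u, low, high):
    """Map an unconstrained real u into the open interval (low, high)."""
    return low + (high - low) * sigmoid(u)


def to_unconstrained(x, low, high):
    """Inverse of to_constrained: logit of x's relative position in (low, high)."""
    p = (x - low) / (high - low)
    return math.log(p / (1.0 - p))


def log_abs_det_jacobian(u, low, high):
    """log |dx/du| of to_constrained; the sampler adds this to the log density."""
    s = sigmoid(u)
    return math.log(high - low) + math.log(s) + math.log(1.0 - s)


def unconstrained_log_density(u, low, high):
    """Log density in u-space of x ~ Uniform(low, high) pulled back through the map."""
    # A uniform prior has constant density 1 / (high - low) on its support.
    return -math.log(high - low) + log_abs_det_jacobian(u, low, high)


# Hypothetical usage with the lengthscale prior from the notebook.
low, high = 0.01, 0.5
x_mid = to_constrained(0.0, low, high)  # u = 0 lands at the interval midpoint
```

In u-space the uniform prior becomes the logistic density log s(u) + log(1 - s(u)), independent of the bounds, so the sampler never has to deal with the hard edges at 0.01 and 0.5.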

gpleiss commented 4 years ago

Lengthscale, period_length, outputscale, and noise are all "transformed" parameters in GPyTorch. We store an unconstrained version of each parameter (e.g. raw_lengthscale) and then transform it (usually via the softplus function) into the positive-valued parameter (e.g. lengthscale).

The priors are in some sense an additional constraint, but they are applied to the transformed parameter (e.g. lengthscale) rather than to the untransformed one (e.g. raw_lengthscale).
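The two-level parameterization described above can be sketched in pure Python: `softplus` maps the stored unconstrained value to the positive parameter, and the prior's log density is evaluated on that positive value. This is an illustration under assumptions, not GPyTorch code. Real GPyTorch constraint objects may also apply a lower bound, the helper names are hypothetical, and the value -1.2059 is borrowed from the traceback only as an example raw value (presumably `raw_noise`).

```python
import math


def softplus(raw):
    """log(1 + exp(raw)), written to avoid overflow for large |raw|."""
    return max(raw, 0.0) + math.log1p(math.exp(-abs(raw)))


def inverse_softplus(value):
    """log(exp(value) - 1): the raw value that softplus maps to `value` (> 0)."""
    return value + math.log(-math.expm1(-value))


def uniform_log_prob(x, low, high):
    """Log density of Uniform(low, high) at x (-inf outside the support)."""
    return -math.log(high - low) if low <= x <= high else -math.inf


# The stored, unconstrained parameter (example value, cf. raw_noise).
raw_noise = -1.2059
# The positive parameter the user sees, which is what the prior is registered on.
noise = softplus(raw_noise)

# Prior evaluated on the transformed value, as with UniformPrior(0.05, 0.3).
log_prior = uniform_log_prob(noise, 0.05, 0.3)
# The same prior evaluated on the raw value would give zero density.
log_prior_on_raw = uniform_log_prob(raw_noise, 0.05, 0.3)
```

This is why it matters which of the two values a prior is attached to: the raw value can be any real number, while the prior's support only covers the positive, transformed one.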

martinjankowiak commented 4 years ago

@neerajprad @fehiepsi any idea what might be going on here? should be easy to reproduce (?) using the notebook linked in the issue.

martinjankowiak commented 4 years ago

@mihai-spire incidentally, if you're interested in speed you might try JAX + NumPyro, although there's no built-in support for GPs, so it won't be quite as plug-and-play.

neerajprad commented 4 years ago

The issue is that somewhere during JIT tracing, we are inserting a tensor (most likely a constant) with requires_grad=True, whereas torch.jit expects all such tensors to be arguments instead. As @mihai-spire pointed out in the Pyro issue, this can happen due to a code path that caches tensors with requires_grad=True on the first invocation and inserts it later during tracing. Maybe there is some caching that happens within transforms or even earlier (the backtrace should provide some clue)? @jacobrgardner can probably speak to that. If that is indeed the issue, one solution would be to provide an option that disables caching during JIT tracing.
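To make this explanation concrete, here is a minimal plain-PyTorch sketch (not Pyro or GPyTorch code) that reproduces the same error message. `captured` is a hypothetical stand-in for a parameter such as `raw_noise` that is read inside the traced function instead of being passed to it; whether GPyTorch reaches that state through caching or by reading the module parameter directly is an assumption here, not something the sketch establishes.

```python
import torch
import torch.nn.functional as F

# Leaf tensor that requires grad and is closed over by the traced function,
# rather than passed in as one of its arguments.
captured = torch.tensor([-1.2059], requires_grad=True)
example_x = torch.zeros(1)


def reads_captured(x):
    return x + F.softplus(captured)


# Tracing fails: the tracer would have to bake `captured` into the graph as a
# constant, which it refuses to do for tensors that require grad.
try:
    torch.jit.trace(reads_captured, (example_x,))
    trace_failed = False
except RuntimeError as err:
    trace_failed = "requires grad as a constant" in str(err)


# Fix 1: make the tensor an explicit input, so the traced graph treats it as
# a value that can change between calls and that gradients can flow through.
def takes_it_as_input(x, raw):
    return x + F.softplus(raw)


traced = torch.jit.trace(takes_it_as_input, (example_x, captured))
out = traced(example_x, captured)

# Fix 2: detach it. Tracing now succeeds, but the value is frozen into the
# graph as a constant and no gradient reaches `captured` through it.
frozen = captured.detach()


def reads_frozen(x):
    return x + F.softplus(frozen)


traced_frozen = torch.jit.trace(reads_frozen, (example_x,))
```

Passing the tensor as an argument is the option the error message itself suggests ("Consider making it a parameter or input"), and it matches the point that torch.jit expects such tensors to arrive as arguments.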

yahoochen97 commented 3 years ago

Hi, I have encountered this same issue. Have you found a way around it?

sdaulton commented 3 years ago

+1. It would be great to get some more clarity on the issue here. Using JIT should greatly speed things up.

jacobrgardner commented 3 years ago

This issue is basically superseded by #1578, as the bugs are caused by the same problem, and a fix to #1578 will resolve this issue as well. I've confirmed that jit_compile works with at least the fix I have so far.

syerramilli commented 2 years ago

I get the following error when I run the notebook with jit_compile=True:

~/opt/anaconda3/envs/dev/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
    292 
    293         if self._compiled_fn:
--> 294             return self._compiled_fn(*vals)
    295 
    296         with pyro.validation_enabled(False):

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Unsupported value kind: Tensor

The example works with jit_compile=False.