dfm / tinygp

The tiniest of Gaussian Process libraries
https://tinygp.readthedocs.io
MIT License

SVI section of likelihoods tutorial throws runtime error #181

Closed: OliviaLynn closed this issue 1 year ago.

OliviaLynn commented 1 year ago

As the title says: the end of the likelihoods notebook (the second-to-last cell) renders with an error in the documentation.

The output is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_442/551631826.py in <module>
     18 optim = numpyro.optim.Adam(0.01)
     19 svi = numpyro.infer.SVI(model, guide, optim, numpyro.infer.Trace_ELBO(10))
---> 20 results = svi.run(jax.random.PRNGKey(55873), 3000, x, y=y, progress_bar=False)

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/infer/svi.py in run(self, rng_key, num_steps, progress_bar, stable_update, init_state, *args, **kwargs)
    340 
    341         if init_state is None:
--> 342             svi_state = self.init(rng_key, *args, **kwargs)
    343         else:
    344             svi_state = init_state

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/infer/svi.py in init(self, rng_key, *args, **kwargs)
    180         guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    181         model_trace = trace(replay(model_init, guide_trace)).get_trace(
--> 182             *args, **kwargs, **self.static_kwargs
    183         )
    184         params = {}

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    169         :return: `OrderedDict` containing the execution trace.
    170         """
--> 171         self(*args, **kwargs)
    172         return self.trace
    173 

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

/tmp/ipykernel_442/1231121528.py in model(x, y)
     13 def model(x, y=None):
     14     # The parameters of the GP model
---> 15     mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
     16     sigma = numpyro.sample("sigma", dist.HalfNormal(3.0))
     17     rho = numpyro.sample("rho", dist.HalfNormal(10.0))

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    217 
    218     # ...and use apply_stack to send it to the Messengers
--> 219     msg = apply_stack(initial_msg)
    220     return msg["value"]
    221 

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/primitives.py in apply_stack(msg)
     45     pointer = 0
     46     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47         handler.process_message(msg)
     48         # When a Messenger sets the "stop" field of a message,
     49         # it prevents any Messengers above it on the stack from being applied.

~/checkouts/readthedocs.org/user_builds/tinygp/envs/v0.1.1/lib/python3.7/site-packages/numpyro/handlers.py in process_message(self, msg)
    221                     return None
    222                 if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
--> 223                     raise RuntimeError(f"Site {name} must be sampled in trace.")
    224                 msg["value"] = guide_msg["value"]
    225                 msg["infer"] = guide_msg["infer"]

RuntimeError: Site mean must be sampled in trace.

dfm commented 1 year ago

Thanks for the report! It looks like you're viewing an older version of the docs, so there's not much we can do about that page. The latest and stable versions seem to be fine, though (e.g. https://tinygp.readthedocs.io/en/stable/tutorials/likelihoods.html), so perhaps those will do the trick for you?

OliviaLynn commented 1 year ago

Oops, good call! I got there from Google and didn't notice the version. Thanks!