probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License
667 stars, 76 forks

Erroneous sqrt in demo notebook #299

Closed: andrewwarrington closed this issue 1 year ago

andrewwarrington commented 1 year ago

Hi all,

In the lgssm_learning.ipynb demo, I think there shouldn't be a sqrt around the standard deviation used for the uncertainty bands in plot_predictions:

plt.fill_between(
    jnp.arange(num_timesteps),
    spc * i + smoothed_emissions[:, i] - 2 * jnp.sqrt(smoothed_emissions_std[i]),
    spc * i + smoothed_emissions[:, i] + 2 * jnp.sqrt(smoothed_emissions_std[i]),
    color=ln.get_color(),
    alpha=0.25,
)

given that the variable name ends in _std and the value returned by the model appears to be a standard deviation, not a variance: https://github.com/probml/dynamax/blob/d88050caaf27204866064f5e752d06159ce1bbb4/dynamax/linear_gaussian_ssm/models.py#L262
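To see why the extra sqrt matters, here is a minimal JAX sketch, assuming a scalar Gaussian emission with a made-up mean of 0 and standard deviation of 4. The helpers two_sigma_band and empirical_coverage are hypothetical illustrations, not part of dynamax. The sketch draws samples and measures how many fall inside mean ± 2 * std (the intended band) versus mean ± 2 * sqrt(std) (the band the notebook draws).

```python
import jax
import jax.numpy as jnp


def two_sigma_band(mean, std):
    """Hypothetical helper: (lower, upper) of mean +/- 2 * std.

    `std` is assumed to already be a standard deviation, so no sqrt is applied.
    """
    return mean - 2.0 * std, mean + 2.0 * std


def empirical_coverage(key, mean, std, lower, upper, num_samples=100_000):
    """Hypothetical helper: fraction of N(mean, std**2) samples inside [lower, upper]."""
    samples = mean + std * jax.random.normal(key, (num_samples,))
    inside = (samples >= lower) & (samples <= upper)
    return float(jnp.mean(inside))


key = jax.random.PRNGKey(0)
mean, std = 0.0, 4.0  # assumed values for illustration only

# Intended band: mean +/- 2 standard deviations.
lo_ok, hi_ok = two_sigma_band(mean, std)
# Band as drawn with the extra sqrt: mean +/- 2 * sqrt(std).
lo_bad = mean - 2.0 * jnp.sqrt(std)
hi_bad = mean + 2.0 * jnp.sqrt(std)

cov_ok = empirical_coverage(key, mean, std, lo_ok, hi_ok)
cov_bad = empirical_coverage(key, mean, std, lo_bad, hi_bad)
print(f"mean +/- 2*std:       coverage ~ {cov_ok:.3f}")  # close to 0.954
print(f"mean +/- 2*sqrt(std): coverage ~ {cov_bad:.3f}")  # close to 0.683 here
```

With std = 4 the sqrt turns a 2-sigma band into a 1-sigma band. The error is not a constant rescaling: sqrt(std) is smaller than std when std > 1 (band too narrow) and larger when std < 1 (band too wide), so the plotted uncertainty can be misleading in either direction.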

I think the same error is also present in the HMC version of the demo.

A

murphyk commented 1 year ago

Thanks. Fixed in https://github.com/probml/dynamax/commit/006e72959c2d12a49dfe1b1ce88a110672a64f1e