tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Input initial_level_prior inconsistent with initial_state_prior (sts.LocalLevel) #1481

Open howardya opened 2 years ago

howardya commented 2 years ago

This might turn out to be just a clarification; feel free to let me know and I can close the issue.

In the simplest cases, LocalLevel and LocalLinearTrend, the highlighted source lines define initial_state_prior as a MultivariateNormalDiag whose mean and scale are taken from initial_level_prior (and, for LocalLinearTrend, initial_slope_prior).
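For reference, a MultivariateNormalDiag is already a joint distribution of independent components; what it adds is the restriction that each marginal is Normal. The plain-Python sketch below checks this numerically for a two-dimensional [level, slope] state like LocalLinearTrend's. The helper names and numbers are illustrative assumptions, not TFP code.

```python
import math

def normal_log_pdf(x, loc, scale):
    """Log density of a univariate Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def mvn_diag_log_pdf(x, loc, scale_diag):
    """Log density of a multivariate Normal with covariance diag(scale_diag)**2,
    written as one quadratic form plus a log-determinant."""
    d = len(x)
    quad = sum(((xi - mi) / si) ** 2 for xi, mi, si in zip(x, loc, scale_diag))
    log_det = 2 * sum(math.log(si) for si in scale_diag)  # log|Sigma|
    return -0.5 * (quad + log_det + d * math.log(2 * math.pi))

# Illustrative [level, slope] state and moments (made-up numbers).
x = [1.3, -0.2]
loc = [1.0, 0.0]
scale_diag = [2.0, 0.5]

joint = mvn_diag_log_pdf(x, loc, scale_diag)
independent = sum(normal_log_pdf(xi, mi, si)
                  for xi, mi, si in zip(x, loc, scale_diag))
print(joint, independent)  # equal up to floating-point error
```

So the choice is not "joint vs. independent" but "Normal marginals vs. arbitrary marginals".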

Why is the prior defined as a MultivariateNormalDiag even though the initial_level_prior/initial_scale_prior inputs may not be Normal?

For example, suppose I have the following simple random walk model, where my initial_level_prior is a LogNormal distribution.

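import tensorflow_probability as tfp

tfd = tfp.distributions
sts = tfp.sts

# co2_by_month_training_data: assumed to be the monthly CO2 training series from the TFP STS example notebook.
trend = sts.LocalLevel(
    observed_time_series=co2_by_month_training_data,
    initial_level_prior=tfd.LogNormal(loc=2., scale=1.))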

If I check initial_state_prior, the prior has, as expected, been redefined as a MultivariateNormalDiag (which in this case is simply a univariate normal distribution):

trend.initial_state_prior

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[] event_shape=[1] dtype=float32>
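Assuming the component moment-matches, i.e. copies the mean and standard deviation of initial_level_prior into the MultivariateNormalDiag (my reading of the highlighted lines, not verified here), the LogNormal(2., 1.) prior above would become roughly Normal(12.18, 15.97). A plain-Python sketch of that arithmetic, which also shows the approximation putting about 22% of its mass on negative levels that the LogNormal rules out:

```python
import math

def lognormal_mean_stddev(loc, scale):
    """Mean and stddev of LogNormal(loc, scale); loc and scale parameterize
    the underlying Normal in log space, as in tfd.LogNormal."""
    mean = math.exp(loc + scale ** 2 / 2)
    var = (math.exp(scale ** 2) - 1) * math.exp(2 * loc + scale ** 2)
    return mean, math.sqrt(var)

def normal_cdf(x, loc, scale):
    """CDF of Normal(loc, scale) at x, via the error function."""
    return 0.5 * (1 + math.erf((x - loc) / (scale * math.sqrt(2))))

# Moment-match the LogNormal(2., 1.) prior from the example above.
mean, stddev = lognormal_mean_stddev(2.0, 1.0)
# Mass the moment-matched Normal puts on a negative initial level,
# which has probability zero under the LogNormal.
p_negative = normal_cdf(0.0, mean, stddev)
print(f"loc={mean:.2f}, scale={stddev:.2f}, P(level < 0)={p_negative:.2f}")
```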

My questions are:

  1. What is the rationale for overwriting initial_state_prior with a MultivariateNormalDiag, as opposed to a joint independent distribution?
  2. Is there a way for me to force initial_state_prior to be a user-defined distribution?

Thank you.

junpenglao commented 2 years ago

Per the formulation of a linear Gaussian state space model, the latent states (X in the figure below) need to be Gaussian; otherwise you cannot use the Kalman filter to compute log_prob, perform state updates, etc.

My understanding is that the priors defined through the tfp.sts.components APIs generally correspond to the blue part of the figure below, i.e. parameters that need to be inferred through some algorithm other than the Kalman filter (figure from Martin, O. A., Kumar, R., & Lao, J. (2021). Bayesian Modeling and Computation in Python):

[Figure: fig18_bsts_lgssm]
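To make the Gaussian requirement concrete, here is a minimal plain-Python Kalman filter for a scalar local level model. It is a sketch, not TFP's implementation, and the function and argument names are hypothetical. The initial level enters only through a mean and a variance, and each predict/update step relies on Gaussians staying Gaussian under linear maps and conditioning, which is what makes the exact log_prob computable.

```python
import math

def local_level_log_prob(y, level_scale, obs_scale, init_mean, init_stddev):
    """Exact marginal log-likelihood of a scalar local level model.

    Hypothetical model (names are illustrative, not TFP's):
      level[0] ~ Normal(init_mean, init_stddev)        # Gaussian initial prior
      level[t] = level[t-1] + Normal(0, level_scale)   # random-walk transition
      y[t]     = level[t]   + Normal(0, obs_scale)     # noisy observation
    """
    m, p = init_mean, init_stddev ** 2  # Gaussian belief about level[0]
    total = 0.0
    for t, obs in enumerate(y):
        if t > 0:
            # Predict: a linear map plus Gaussian noise keeps the belief Gaussian.
            p = p + level_scale ** 2
        # One-step predictive distribution of y[t] is Normal(m, p + obs_scale**2).
        s = p + obs_scale ** 2
        resid = obs - m
        total += -0.5 * (math.log(2 * math.pi * s) + resid ** 2 / s)
        # Update: conditioning a Gaussian on a linear-Gaussian observation
        # gives another Gaussian; k is the Kalman gain.
        k = p / s
        m = m + k * resid
        p = (1 - k) * p
    return total

print(local_level_log_prob([1.0, 2.0], level_scale=0.5, obs_scale=1.0,
                           init_mean=0.0, init_stddev=1.0))
```

With a LogNormal initial level, the first update would no longer yield a Gaussian, and the recursion above would stop being exact.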

I think restricting initial_level_prior in LocalLevel/LocalLinearTrend to tfd.Normal could make things clearer in these two cases, but there could be a design choice, or other places using this information, that I am not aware of. @davmre could probably provide more information.

howardya commented 2 years ago

@junpenglao Thank you for your insight; now I understand a bit more. The variables _initial_state_prior / self.initial_state_prior are indeed meant for the state space model; they are not used for VI.

For instance, in the example Jupyter notebook, initial_state_prior is not used at all. The function

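import tensorflow_probability as tfp
# co2_model: assumed to be the STS model built in the example notebook.
variational_posteriors = tfp.sts.build_factored_surrogate_posterior(
    model=co2_model)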

builds the prior using co2_model._joint_prior_distribution, which does use the initial priors defined by the user.

In summary: For VI and non-State Space Model analysis, the priors are consistent with user defined initial_level_prior/initial_scale_prior. For SSM, the latent state models are using MultivariateNormalDiag.

I'd appreciate it if someone could confirm my understanding above, and tell me whether this will change in the future.

junpenglao commented 2 years ago

"For VI and non-State Space Model analysis, the priors are consistent with user defined initial_level_prior/initial_scale_prior. For SSM, the latent state models are using MultivariateNormalDiag."

If by "non-State Space Model analysis" you mean inferring the hyperparameters of an SSM (the most obvious one being the SSM's sigma), then yes.
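As a rough illustration of that split, the plain-Python sketch below simulates a local level series and estimates the level noise scale (the "sigma" above) by maximizing the Kalman-filter marginal likelihood on a grid, while the latent levels themselves are integrated out analytically. All names and values are illustrative assumptions; the grid search merely stands in for the VI or HMC that tfp.sts offers for fitting such parameters, and obs_scale is held fixed at its true value for simplicity.

```python
import math
import random

def kalman_log_prob(y, level_scale, obs_scale, init_mean, init_stddev):
    """Marginal log-likelihood of a scalar local level model; the latent
    levels are integrated out by the Kalman filter."""
    m, p, total = init_mean, init_stddev ** 2, 0.0
    for t, obs in enumerate(y):
        if t > 0:
            p += level_scale ** 2        # predict (random-walk transition)
        s = p + obs_scale ** 2           # predictive variance of y[t]
        r = obs - m
        total -= 0.5 * (math.log(2 * math.pi * s) + r * r / s)
        k = p / s                        # update with the Kalman gain
        m, p = m + k * r, (1 - k) * p
    return total

# Simulate a random walk observed with noise (true level_scale = 0.5).
rng = random.Random(0)
level, y = 0.0, []
for _ in range(500):
    level += rng.gauss(0.0, 0.5)
    y.append(level + rng.gauss(0.0, 1.0))

# Grid search over the hyperparameter, with a diffuse Gaussian initial prior.
grid = [0.05 * i for i in range(1, 41)]  # 0.05 .. 2.0
best = max(grid, key=lambda s: kalman_log_prob(y, s, 1.0, 0.0, 10.0))
print(f"estimated level_scale: {best:.2f}")
```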