probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Implementation: specification of filtered means, covariances, and log-likelihood for EKF and UKF #289

Closed gerdm closed 1 year ago

gerdm commented 1 year ago

Closes #288

Added an output parameter to the extended_kalman_filter and unscented_kalman_filter functions. The output parameter specifies which filtered values to store in memory and return. By default, output is set to None, in which case the functions return the default filtered_means and filtered_covariances, so that no breaking changes are introduced to the corresponding smoothing functions.

Additionally, output accepts the term marginal_loglik, which returns the history of the marginal log-likelihood. The use of this parameter is tested in ekf_ukf_spiral.ipynb, which shows the marginal log-likelihood of the EKF vs. the UKF under two different choices of hyperparameters.
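What the "history of the marginal log-likelihood" contains follows from the chain rule, $\log p(y_{1:t}) = \sum_{s=1}^{t} \log p(y_s \mid y_{1:s-1})$: it is a cumulative sum of the one-step predictive log-densities that a filter already computes. The sketch below assumes a scalar linear-Gaussian random-walk model rather than the EKF/UKF from this PR, and `scalar_kf_loglik_history` is a hypothetical name, not a dynamax function.

```python
import numpy as np


def scalar_kf_loglik_history(ys, m0=0.0, p0=1.0, q=0.1, r=0.5):
    """Hypothetical helper: scalar Kalman filter for the random walk
    z_t = z_{t-1} + N(0, q), y_t = z_t + N(0, r), with prior z_1 ~ N(m0, p0).
    Returns the running log p(y_{1:t}) for t = 1..T."""
    m, p = m0, p0                      # predicted mean/variance of z_t
    increments = []
    for y in ys:
        s = p + r                      # predictive variance of y_t
        resid = y - m
        # log N(y_t | m, s): the one-step predictive log-density
        increments.append(-0.5 * (np.log(2 * np.pi * s) + resid ** 2 / s))
        k = p / s                      # Kalman gain
        m, p = m + k * resid, (1 - k) * p   # condition on y_t
        p = p + q                      # predict z_{t+1}
    return np.cumsum(increments)


ys = np.array([0.3, -0.1, 0.8, 1.2])
history = scalar_kf_loglik_history(ys)
print(history[-1])  # final entry = log p(y_{1:T}), the usual scalar output
```

The last entry of the history is the scalar marginal log-likelihood, so returning the history loses no information.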

[Two attached images, not reproduced]


Code was tested on


If the style is acceptable, I propose doing the same refactoring for the EKS and UKS, and also refactoring to store intermediate predictions (see #288).

gerdm commented 1 year ago

The failing tests are due to the list[str] annotation. However, I'm following this convention.

Which minimum Python version should we consider for this?

slinderman commented 1 year ago

Which minimum Python version should we consider for this?

We are using 3.7 for Colab compatibility. What version is necessary for subscripted types?

gerdm commented 1 year ago

Which minimum Python version should we consider for this?

We are using 3.7 for Colab compatibility. What version is necessary for subscripted types?

@slinderman, using built-in collection types such as list as generic types (e.g. list[str]) was introduced in Python 3.9, as proposed in PEP 585:

In type annotations you can now use built-in collection types such as list and dict as generic types instead of importing the corresponding capitalized types (e.g. List or Dict) from typing. Some other types in the standard library are also now generic, for example queue.Queue.
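For 3.7 compatibility, the capitalized aliases from `typing` are the usual workaround: `List[str]` works on 3.7 and 3.8, whereas evaluating `list[str]` there raises a `TypeError`. A minimal sketch; `select_outputs` is a hypothetical helper, not part of dynamax:

```python
from typing import Dict, List, Optional


def select_outputs(outputs: Dict[str, float],
                   output_fields: Optional[List[str]] = None) -> Dict[str, float]:
    """Hypothetical helper: keep only the requested entries of `outputs`.

    The annotations use typing.List/Dict, which work on Python 3.7 and 3.8;
    writing list[str] or dict[str, float] here would raise TypeError there.
    """
    if output_fields is None:
        return dict(outputs)
    return {k: v for k, v in outputs.items() if k in output_fields}


print(select_outputs({"filtered_means": 1.0, "predicted_means": 2.0}, ["filtered_means"]))
# {'filtered_means': 1.0}
```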

slinderman commented 1 year ago

Following up on this from my computer this time. (Sorry for the abbreviated comments earlier!). In an attempt to synthesize the suggestions above, how about this proposal:

  1. Add predicted means and covariances to the filtered posterior, and make some fields optional.

    class PosteriorGSSMFiltered(NamedTuple):
        # ...
        marginal_loglik: Scalar
        filtered_means: Optional[Float[Array, "ntime state_dim"]] = None
        filtered_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
        predicted_means: Optional[Float[Array, "ntime state_dim"]] = None
        predicted_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None

  2. Allow the user to specify which outputs to compute in the forward pass. I personally like the string list from @gerdm's original suggestion, but without list[str] types to maintain 3.7 compatibility. I would propose:

    def lgssm_filter(
        params: ParamsLGSSM,
        emissions: Float[Array, "ntime emission_dim"],
        inputs: Optional[Float[Array, "ntime input_dim"]]=None,
        output_fields: List[str]=['filtered_means', 'filtered_covariances', 'predicted_means', 'predicted_covariances'],
    ) -> PosteriorGSSMFiltered:

    Then update the filter function as follows:

        # ... condition/predict steps
    
        # Make the carry and outputs
        carry = (ll, pred_mean, pred_cov)
        # Keys must match the PosteriorGSSMFiltered field names and output_fields
        outputs = dict(filtered_means=filtered_mean,
                       filtered_covariances=filtered_cov,
                       predicted_means=pred_mean,
                       predicted_covariances=pred_cov)
        outputs = {k: v for k, v in outputs.items() if k in output_fields}
        return carry, outputs
    
    # Run the Kalman filter
    carry = (0.0, params.initial.mean, params.initial.cov)
    (ll, _, _), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
    return PosteriorGSSMFiltered(marginal_loglik=ll, **outputs)

    In this formulation, marginal_loglik is special since it is always included in the PosteriorGSSMFiltered object.
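A runnable end-to-end sketch of this proposal, assuming a toy scalar random-walk model in place of dynamax's LGSSM; `PosteriorFiltered` and `scalar_filter` are hypothetical names, not the library's API. The per-step dictionary is keyed by the same names as the NamedTuple fields, so `**outputs` fills exactly the requested fields and the rest stay None. `output_fields` is a tuple here, which keeps it hashable.

```python
from typing import NamedTuple, Optional, Sequence

import jax
import jax.numpy as jnp
from jax import lax


class PosteriorFiltered(NamedTuple):
    """Hypothetical stand-in for PosteriorGSSMFiltered: unrequested fields stay None."""
    marginal_loglik: jax.Array
    filtered_means: Optional[jax.Array] = None
    filtered_covariances: Optional[jax.Array] = None
    predicted_means: Optional[jax.Array] = None
    predicted_covariances: Optional[jax.Array] = None


def scalar_filter(emissions: jax.Array,
                  m0: float = 0.0, p0: float = 1.0, q: float = 0.1, r: float = 0.5,
                  output_fields: Sequence[str] = ("filtered_means", "filtered_covariances"),
                  ) -> PosteriorFiltered:
    """Toy Kalman filter for z_t = z_{t-1} + N(0, q), y_t = z_t + N(0, r), z_1 ~ N(m0, p0)."""

    def _step(carry, y):
        ll, pred_mean, pred_cov = carry
        # Condition on y_t and accumulate log p(y_t | y_{1:t-1})
        s = pred_cov + r
        resid = y - pred_mean
        ll = ll - 0.5 * (jnp.log(2 * jnp.pi * s) + resid ** 2 / s)
        gain = pred_cov / s
        filtered_mean = pred_mean + gain * resid
        filtered_cov = (1.0 - gain) * pred_cov
        # Here "predicted" means p(z_t | y_{1:t-1}), the prior used for y_t
        outputs = dict(filtered_means=filtered_mean,
                       filtered_covariances=filtered_cov,
                       predicted_means=pred_mean,
                       predicted_covariances=pred_cov)
        outputs = {name: v for name, v in outputs.items() if name in output_fields}
        # Predict z_{t+1} for the next step
        carry = (ll, filtered_mean, filtered_cov + q)
        return carry, outputs

    dtype = emissions.dtype
    carry = (jnp.zeros((), dtype), jnp.asarray(m0, dtype), jnp.asarray(p0, dtype))
    (ll, _, _), outputs = lax.scan(_step, carry, emissions)
    return PosteriorFiltered(marginal_loglik=ll, **outputs)


ys = jnp.array([0.3, -0.1, 0.8, 1.2])
post = scalar_filter(ys, output_fields=("filtered_means",))
print(post.filtered_means.shape)   # (4,)
print(post.filtered_covariances)   # None
```

Because the dictionary is filtered at trace time, `lax.scan` only stacks and returns the requested arrays.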

I could see an argument for returning an array of partial marginal log-likelihoods $[\log p(y_1), \log p(y_{1:2}), \ldots, \log p(y_{1:T})]$, but I'm afraid that would break existing code if we changed marginal_loglik to output an array rather than a scalar. Alternatively, I could be convinced to add marginal_loglik to our list of output_fields and change the return statement to

PosteriorGSSMFiltered(marginal_loglik=ll if "marginal_loglik" in output_fields else None, **outputs)

What do you guys think? Last but not least, we should make clear in the docstring that the function will not be jittable with a list of strings as input unless you mark output_fields as a static argument.
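A minimal sketch of the static-argument point, using a hypothetical toy function `select` rather than lgssm_filter. `output_fields` changes the structure of the return value, so it must be known at trace time; `jax.jit(..., static_argnames=...)` does that, but static arguments must be hashable, so a tuple works where a list does not.

```python
import jax
import jax.numpy as jnp


def select(x, output_fields=("mean", "var")):
    """Hypothetical toy: output_fields decides which entries are returned,
    so it must be static (fixed at trace time) under jit."""
    outputs = {"mean": jnp.mean(x), "var": jnp.var(x)}
    return {name: v for name, v in outputs.items() if name in output_fields}


# Static arguments must be hashable: pass a tuple, not a list.
select_jit = jax.jit(select, static_argnames="output_fields")

x = jnp.arange(4.0)
print(select_jit(x, output_fields=("mean",)))  # a dict with only the "mean" entry
```

Each distinct value of `output_fields` triggers a separate compilation, which is usually acceptable since it rarely changes between calls.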

murphyk commented 1 year ago

Comments inline

On Mon, Nov 28, 2022 at 8:13 PM Scott Linderman @.***> wrote:

What do you guys think?

Love it

Last but not least, we should make clear in the docstring that the function will not be jittable with a list of strings as input unless you mark output_fields as a static argument.

Yes, let’s make it static.

gerdm commented 1 year ago

Re: 2.,

def lgssm_filter(
    params: ParamsLGSSM,
    emissions:  Float[Array, "ntime emission_dim"],
    inputs: Optional[Float[Array, "ntime input_dim"]]=None,
    output_fields: List[str]=['filtered_means', 'filtered_covariances', 'predicted_means', 'predicted_covariances'], 
) -> PosteriorGSSMFiltered:

I propose making the default value of output_fields None rather than a list, since mutable default arguments such as lists are generally discouraged in Python. We could use output_fields: ...=None and then set output_fields = ["filtered_means", ...] unless the user passes a list with the desired output values:

if output_fields is None:
    output_fields = ["filtered_means", ...]
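Why the list default is discouraged: Python evaluates a default value once, when the function is defined, so every call that relies on it shares the same list object, and any mutation leaks into later calls. A minimal illustration with hypothetical function names:

```python
from typing import List, Optional

DEFAULT_FIELDS = ("filtered_means", "filtered_covariances")


def fields_bad(output_fields: List[str] = ["filtered_means"]) -> List[str]:
    # The default list is created once, at definition time,
    # and is shared by every call that relies on it.
    output_fields.append("marginal_loglik")
    return output_fields


def fields_good(output_fields: Optional[List[str]] = None) -> List[str]:
    # None as a sentinel: a fresh list is built on every call.
    if output_fields is None:
        output_fields = list(DEFAULT_FIELDS)
    output_fields.append("marginal_loglik")
    return output_fields


print(fields_bad())   # ['filtered_means', 'marginal_loglik']
print(fields_bad())   # ['filtered_means', 'marginal_loglik', 'marginal_loglik']
print(fields_good())  # ['filtered_means', 'filtered_covariances', 'marginal_loglik']
```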

Re: 3., "returning an array of partial marginal log likelihoods", perhaps you meant something like

PosteriorGSSMFiltered(marginal_loglik=ll if "marginal_loglik" not in output_fields else None, **outputs)

to obtain the final marginal log-likelihood if "marginal_loglik" is not specified; otherwise, we have None and it gets rewritten with **outputs. But if this is the case, we can also write

outputs = {"marginal_loglik": ll, **outputs}
PosteriorGSSMFiltered(**outputs)

and marginal_loglik gets rewritten to the history of values only if it's specified. Is this what you meant?
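One caveat worth checking for the keyword form: Python rejects a keyword supplied both explicitly and through `**outputs`, so `PosteriorGSSMFiltered(marginal_loglik=..., **outputs)` raises a `TypeError` whenever `outputs` also contains "marginal_loglik". The dict-merge form avoids this because later keys in a dict literal override earlier ones. A minimal check; the `Posterior` class and the numbers are hypothetical:

```python
from typing import NamedTuple, Optional


class Posterior(NamedTuple):
    """Hypothetical stand-in for PosteriorGSSMFiltered."""
    marginal_loglik: Optional[object] = None
    filtered_means: Optional[object] = None


ll_final = -3.2
outputs = {"marginal_loglik": [-1.0, -2.1, -3.2], "filtered_means": [0.1, 0.2, 0.3]}

# Supplying marginal_loglik explicitly AND via **outputs is a TypeError.
try:
    Posterior(marginal_loglik=None, **outputs)
except TypeError as err:
    print("TypeError:", err)

# Dict merge: the history replaces the scalar only when it was requested.
post = Posterior(**{"marginal_loglik": ll_final, **outputs})
print(post.marginal_loglik)  # [-1.0, -2.1, -3.2]
```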

gerdm commented 1 year ago

See the latest comments for the most recent proposed change.