probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Update hmm parameters to use dataclasses instead of dicts. #253

Closed gileshd closed 2 years ago

gileshd commented 2 years ago

Param dict --> dataclass

This PR changes hmm parameters to be stored in nested dataclasses rather than dictionaries.

This introduces four kinds of dataclass: one for the initial parameters, one for the transition parameters, one for the emission parameters, and a top-level class that holds an instance of each of the above. Something like:

import chex
from jaxtyping import Array, Float

@chex.dataclass
class InitialParams:
    probs: Float[Array, "state_dim"]

@chex.dataclass
class TransitionsParams:
    transition_matrix: Float[Array, "state_dim state_dim"]

@chex.dataclass
class EmissionsParams:
    means: Float[Array, "state_dim emission_dim"]
    scales: Float[Array, "state_dim emission_dim"]

@chex.dataclass
class Params:
    initial: InitialParams
    transitions: TransitionsParams
    emissions: EmissionsParams

This PR defines the appropriate dataclasses and updates the relevant model/emission objects to use/return a dataclass rather than a nested dictionary object. The naming convention is to prepend "Params" to the relevant class: Params*, e.g. ParamsPoissonHMMEmissions, ParamsPoissonHMM, ParamsStandardHMMInitialState, ...

The fields of a dataclass can be read with either 'square-bracket notation', like a dictionary (e.g. params['probs']), or with 'dot notation' (e.g. params.probs). However, only dot notation can be used to set the value of a field. Therefore, where necessary, square-bracket access has been replaced by dot access.
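
The read/write asymmetry can be illustrated with a minimal standard-library sketch. It does not use chex; `FieldMapping` is a hypothetical mixin that mimics a dataclass which is readable like a mapping but has no item assignment, which is the behaviour described above.

```python
from collections.abc import Mapping
from dataclasses import dataclass, fields


class FieldMapping(Mapping):
    """Hypothetical mixin: exposes dataclass fields through read-only item access."""

    def __getitem__(self, key):
        if key not in {f.name for f in fields(self)}:
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self):
        return (f.name for f in fields(self))

    def __len__(self):
        return len(fields(self))


@dataclass
class InitialParams(FieldMapping):
    probs: list


params = InitialParams(probs=[0.5, 0.5])

# Reading works with either notation.
assert params["probs"] is params.probs

# Setting works with dot notation only.
params.probs = [0.9, 0.1]

try:
    params["probs"] = [0.2, 0.8]  # a Mapping defines no __setitem__
    item_assignment_ok = True
except TypeError:
    item_assignment_ok = False
```

Code that previously wrote `params['probs'] = ...` therefore has to be switched to `params.probs = ...`.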

to/from unconstrained

To reflect the change to dataclasses, the to_unconstrained and from_unconstrained functions have been updated and their behaviour has changed slightly.

The function to_unconstrained previously returned two dictionaries, unc_params and fixed_params, containing the unconstrained and fixed parameters respectively. It now returns the unconstrained parameters in a dictionary (as before) alongside the original params dataclass which contains all parameters (not just the fixed parameters).

The function from_unconstrained takes a dictionary of unconstrained parameters (as before) as well as a dataclass containing all of the parameters (orig_params). For each parameter in unc_params, it replaces the corresponding leaf in orig_params with the appropriately constrained value from unc_params.
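
A rough sketch of this round trip, using only the standard library rather than the actual dynamax code. The `TRAINABLE` registry, the log/exp transform and the tuple-valued leaves are assumptions made for illustration; only the function names and the shape of their return values follow the description above.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class EmissionsParams:
    means: tuple   # treated as fixed in this sketch
    scales: tuple  # positive, treated as trainable in this sketch


@dataclass(frozen=True)
class Params:
    emissions: EmissionsParams


# Hypothetical registry: trainable leaves and their (to_unconstrained, to_constrained) maps.
TRAINABLE = {("emissions", "scales"): (math.log, math.exp)}


def to_unconstrained(params):
    """Return (dict of unconstrained trainable leaves, the full original params)."""
    unc_params = {}
    for (group, name), (to_unc, _) in TRAINABLE.items():
        leaf = getattr(getattr(params, group), name)
        unc_params.setdefault(group, {})[name] = tuple(to_unc(v) for v in leaf)
    return unc_params, params


def from_unconstrained(unc_params, orig_params):
    """Replace each leaf named in unc_params by its constrained value."""
    params = orig_params
    for group, leaves in unc_params.items():
        new_group = getattr(params, group)
        for name, unc_leaf in leaves.items():
            _, to_con = TRAINABLE[(group, name)]
            constrained = tuple(to_con(v) for v in unc_leaf)
            new_group = replace(new_group, **{name: constrained})
        params = replace(params, **{group: new_group})
    return params


params = Params(emissions=EmissionsParams(means=(0.0, 1.0), scales=(1.0, 2.0)))
unc_params, orig_params = to_unconstrained(params)
restored = from_unconstrained(unc_params, orig_params)
```

Because the second return value is the complete dataclass, from_unconstrained has no need to merge a separate dictionary of fixed parameters. Leaves that do not appear in unc_params are carried over unchanged.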

Possible future changes/additions:

At present, all the tests in hmm/ pass on this branch, but this change has most likely broken almost all of the hmm notebooks. Fixing them is a priority.

Other possible updates include:

gileshd commented 2 years ago

It looks like this has broken some of the lgssm models (two tests are failing). I'll try to fix this now.

gileshd commented 2 years ago

lgssm parameters have now been similarly updated to nested dataclasses:

@chex.dataclass
class ParamsLGSSMInitial:
    mean: Float[Array, "state_dim"]
    cov: Float[Array, "state_dim state_dim"]

@chex.dataclass
class ParamsLGSSMDynamics:
    weights: Float[Array, "state_dim state_dim"]
    bias: Float[Array, "state_dim"]
    input_weights: Float[Array, "state_dim input_dim"]
    cov: Float[Array, "state_dim state_dim"]

@chex.dataclass
class ParamsLGSSMEmissions:
    weights: Float[Array, "emission_dim state_dim"]
    bias: Float[Array, "emission_dim"]
    input_weights: Float[Array, "emission_dim input_dim"]
    cov: Float[Array, "emission_dim emission_dim"]

@chex.dataclass
class ParamsLGSSM:
    initial: ParamsLGSSMInitial
    dynamics: ParamsLGSSMDynamics
    emissions: ParamsLGSSMEmissions
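
For reference, these fields line up with the usual linear Gaussian state space model. The symbols below are assumed notation rather than anything defined in this PR; the shapes agree with the annotations above.

```latex
\begin{aligned}
z_1 &\sim \mathcal{N}(m, S), \\
z_t &= F z_{t-1} + B u_t + b + \epsilon_t, & \epsilon_t &\sim \mathcal{N}(0, Q), \\
y_t &= H z_t + D u_t + d + \delta_t, & \delta_t &\sim \mathcal{N}(0, R).
\end{aligned}
```

Here m and S are initial.mean and initial.cov. F, B, b and Q are the dynamics weights, input_weights, bias and cov. H, D, d and R are the emissions weights, input_weights, bias and cov.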

slinderman commented 2 years ago

Looks good @gileshd! I'm going to merge this into main and then update it with the Parameter refactor I've been working on.