gileshd closed this 2 years ago.
It looks like this has broken some of the lgssm models (two tests are failing). I'll try to fix this now.

The lgssm parameters have now been similarly updated to nested dataclasses:
```python
import chex
from jaxtyping import Array, Float


# Parameters of the initial state distribution.
@chex.dataclass
class ParamsLGSSMInitial:
    mean: Float[Array, "state_dim"]
    cov: Float[Array, "state_dim state_dim"]


# Parameters of the linear-Gaussian state transition.
@chex.dataclass
class ParamsLGSSMDynamics:
    weights: Float[Array, "state_dim state_dim"]
    bias: Float[Array, "state_dim"]
    input_weights: Float[Array, "state_dim input_dim"]
    cov: Float[Array, "state_dim state_dim"]


# Parameters of the linear-Gaussian emission model.
@chex.dataclass
class ParamsLGSSMEmissions:
    weights: Float[Array, "emission_dim state_dim"]
    bias: Float[Array, "emission_dim"]
    input_weights: Float[Array, "emission_dim input_dim"]
    cov: Float[Array, "emission_dim emission_dim"]


# Top-level container holding one instance of each component.
@chex.dataclass
class ParamsLGSSM:
    initial: ParamsLGSSMInitial
    dynamics: ParamsLGSSMDynamics
    emissions: ParamsLGSSMEmissions
```
Looks good @gileshd! I'm going to merge this into main and then update it with the Parameter refactor I've been working on.
Param dict --> dataclass
This PR changes HMM parameters to be stored in nested dataclasses rather than dictionaries.
It introduces four types of dataclass: one for the initial-state parameters, one for the transition parameters, one for the emission parameters, and a top-level class that holds an instance of each of the other three.
This PR defines the appropriate dataclasses and updates the relevant model/emission objects to use and return a dataclass rather than a nested dictionary. The naming convention is to prepend "Params" to the name of the relevant class, giving `Params*`, e.g. `ParamsPoissonHMMEmissions`, `ParamsPoissonHMM`, `ParamsStandardHMMInitialState`, ...

The fields of a dataclass can be accessed either with 'square-bracket notation', as for a dictionary (e.g. `params['probs']`), or with 'dot notation' (e.g. `params.probs`). However, only dot notation can be used when setting the value of a field. Therefore, where necessary, square-bracket access has been replaced by dot access.

to/from unconstrained
To reflect the change to dataclasses, the `to_unconstrained` and `from_unconstrained` functions have been updated, and their behaviour has changed slightly.

The function `to_unconstrained` previously returned two dictionaries, `unc_params` and `fixed_params`, containing the unconstrained and fixed parameters respectively. It now returns the unconstrained parameters in a dictionary (as before) alongside the original `params` dataclass, which contains all parameters (not just the fixed ones).

The function `from_unconstrained` takes a dictionary of unconstrained parameters, `unc_params` (as before), together with a dataclass containing all of the parameters, `orig_params`. For each parameter in `unc_params`, it replaces the corresponding leaf in `orig_params` with the appropriately constrained value from `unc_params`.

Possible future changes/additions:
At present all the tests in `hmm/` are passing in this branch, but this change has most likely broken almost all of the hmm notebooks; fixing them is a priority.

Other possible updates include:

- `ParamsLinearRegressionHMM` and `ParamsLinearAutoregressiveHMM` have identical fields and types and should perhaps be combined.