probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Add Dirichlet classifier demo and add utility functions. #285

Closed. gileshd closed this 1 year ago.

gileshd commented 1 year ago

Demo

Add a demo of online classification using online linear regression with a Dirichlet pseudo-count approximation.
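A minimal sketch of the Dirichlet pseudo-count step, assuming the lognormal moment-matching transformation described by Milios et al. (2018). Each one-hot label becomes pseudo-counts alpha = y + alpha_eps, and each class gets a regression target log(alpha) - sigma^2/2 with noise variance sigma^2 = log(1/alpha + 1). The function name `dirichlet_regression_targets` and the default `alpha_eps=0.01` are illustrative assumptions, not the demo's code.

```python
import jax.numpy as jnp

def dirichlet_regression_targets(labels, num_classes, alpha_eps=0.01):
    """Hypothetical helper: map integer labels to per-class regression targets.

    Lognormal moment matching of Gamma(alpha, 1) pseudo-counts:
      alpha   = one_hot(label) + alpha_eps
      sigma2  = log(1 / alpha + 1)        # per-class noise variance
      y_tilde = log(alpha) - sigma2 / 2   # per-class regression target
    """
    one_hot = jnp.eye(num_classes)[labels]   # (N, C)
    alpha = one_hot + alpha_eps              # (N, C) Dirichlet pseudo-counts
    sigma2 = jnp.log(1.0 / alpha + 1.0)      # (N, C)
    y_tilde = jnp.log(alpha) - 0.5 * sigma2  # (N, C)
    return y_tilde, sigma2

y_tilde, sigma2 = dirichlet_regression_targets(jnp.array([0, 2]), num_classes=3)
```

Because the noise variance for the observed class (alpha = 1 + alpha_eps) is much smaller than for the other classes (alpha = alpha_eps), each of the resulting regression problems is heteroscedastic, so the downstream regression needs a per-observation noise term.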

The demo presents an sklearn-style classifier which uses the trick from Milios et al. (2018) to convert a classification problem into a set of regression problems, which are then solved in an online fashion using lgssm_filter (as in the "Online linear regression using Kalman filtering" demo).
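To make the "online" part concrete, here is a hand-written sketch of the Kalman-filter recursion for Bayesian linear regression with static weights (identity dynamics, no process noise). It is not dynamax's lgssm_filter; the name `online_linear_regression` and its arguments are hypothetical. Each step conditions the weight posterior on one observation y_t = x_t · w + noise, with a per-observation noise variance so that it can take heteroscedastic targets such as those produced by the Dirichlet transformation.

```python
import jax
import jax.numpy as jnp

def online_linear_regression(X, y, obs_var, prior_var=1.0):
    """Hypothetical sketch: filter w in y_t = x_t @ w + eps_t, eps_t ~ N(0, obs_var[t]).

    Prior w ~ N(0, prior_var * I); weights are static, so there is no predict step.
    Returns the final posterior mean/cov and the posterior mean after every step.
    """
    D = X.shape[1]
    init = (jnp.zeros(D), prior_var * jnp.eye(D))

    def step(carry, inputs):
        mean, cov = carry
        x, y_t, r = inputs
        s = x @ cov @ x + r                 # predictive variance of y_t
        k = cov @ x / s                     # Kalman gain
        mean = mean + k * (y_t - x @ mean)  # condition mean on y_t
        cov = cov - s * jnp.outer(k, k)     # condition covariance on y_t
        return (mean, cov), mean

    (mean, cov), mean_history = jax.lax.scan(step, init, (X, y, obs_var))
    return mean, cov, mean_history

X = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
y = X @ jnp.array([1.0, -2.0, 0.5])
r = jnp.full(50, 0.1)
mean, cov, mean_history = online_linear_regression(X, y, r)
```

After the final step the mean coincides with the batch Bayesian linear-regression posterior for the same prior and noise. For C classes one could run C independent filters, e.g. `jax.vmap(online_linear_regression, in_axes=(None, 1, 1))(X, y_tilde, sigma2)`; the demo appears instead to track all class weights in a single flattened state passed to lgssm_filter.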

Utils

The demo introduces two new helper utils to the library:

Plotting params

Some nice plotting parameters are added to utils/plotting.py which can be loaded using:

import matplotlib.pyplot as plt
from dynamax.utils.plotting import custom_rcparams_notebook
plt.rcParams.update(custom_rcparams_notebook)

Default params:

(Figure: example plot using the default plotting params)

Custom params:

(Figure: example plot using the custom plotting params)

There is currently a problem with accessing the right fonts on Colab; we might have to remove the pretty-font option.

make_lgssm_params

Adds a helper function to construct a ParamsLGSSM object. At present there are two ways to construct an LGSSM parameter object:

  1. import all of the different parameter classes (ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions) and manually construct the nested combination.
  2. use the initialize method of a LinearGaussianSSM object.

The first of these options feels a bit unwieldy and requires more work on the part of the user. In the demo, the functional interface to the inference code is used inside a custom classifier class, so creating an instance of the full model object just to build a parameter object feels like overkill.

For these reasons I added a helper function that lets the user build a ParamsLGSSM object by passing each parameter as an argument to make_lgssm_params. Arguments follow the format [initial|dynamics|emissions]_[weights|bias|etc...]; e.g. the initial_mean argument sets params.initial.mean and dynamics_cov sets params.dynamics.cov. The biases and input weights default to None, in which case zeros of the appropriate shape are returned, although this might be unnecessary in light of the @preprocess_args decorator.
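The flat-argument idea can be sketched as follows. This is a hypothetical stand-in, not the PR's make_lgssm_params: it uses locally defined NamedTuples that only mirror the nesting of ParamsLGSSM / ParamsLGSSMInitial / ParamsLGSSMDynamics / ParamsLGSSMEmissions, and the `input_dim` argument used to size the default input weights is an assumption.

```python
from typing import NamedTuple
import jax.numpy as jnp

# Local stand-ins for the nested parameter containers; NOT the dynamax classes.
class InitialParams(NamedTuple):
    mean: jnp.ndarray
    cov: jnp.ndarray

class LinearParams(NamedTuple):
    weights: jnp.ndarray
    bias: jnp.ndarray
    input_weights: jnp.ndarray
    cov: jnp.ndarray

class LGSSMParams(NamedTuple):
    initial: InitialParams
    dynamics: LinearParams
    emissions: LinearParams

def make_lgssm_params_sketch(
    initial_mean, initial_cov,
    dynamics_weights, dynamics_cov,
    emissions_weights, emissions_cov,
    dynamics_bias=None, dynamics_input_weights=None,
    emissions_bias=None, emissions_input_weights=None,
    input_dim=0,
):
    """Hypothetical helper: build nested params from [initial|dynamics|emissions]_[field] args."""
    state_dim = initial_mean.shape[-1]
    emission_dim = emissions_cov.shape[-1]
    # Biases and input weights left as None default to zeros of the matching shape.
    if dynamics_bias is None:
        dynamics_bias = jnp.zeros(state_dim)
    if dynamics_input_weights is None:
        dynamics_input_weights = jnp.zeros((state_dim, input_dim))
    if emissions_bias is None:
        emissions_bias = jnp.zeros(emission_dim)
    if emissions_input_weights is None:
        emissions_input_weights = jnp.zeros((emission_dim, input_dim))
    return LGSSMParams(
        initial=InitialParams(mean=initial_mean, cov=initial_cov),
        dynamics=LinearParams(dynamics_weights, dynamics_bias, dynamics_input_weights, dynamics_cov),
        emissions=LinearParams(emissions_weights, emissions_bias, emissions_input_weights, emissions_cov),
    )

# Example: 2-d latent state, 1-d observations, no control inputs.
params = make_lgssm_params_sketch(
    initial_mean=jnp.zeros(2), initial_cov=jnp.eye(2),
    dynamics_weights=jnp.eye(2), dynamics_cov=0.1 * jnp.eye(2),
    emissions_weights=jnp.ones((1, 2)), emissions_cov=jnp.eye(1),
)
```

The caller only supplies the quantities it cares about; the nesting, attribute names, and zero defaults live in one place.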

Options

Ideally the utils and demo would be separate commits but I wasn't very hygienic with my commits so they have become a bit entangled. I am happy to separate them out though if we don't want to include either of the utility additions.

murphyk commented 1 year ago

This looks great. However, it would be useful to add a comment that explains

 emission_matrix = jnp.stack([jnp.kron(jnp.eye(input_dim), row) for row in X_plus_bias])

I assume the state space of the model is the (D, C) weight matrix W flattened into a DC-dimensional vector w, that the input at time t is x_t in R^D, and that you want to compute logits = W^T x_t using the state z_t = w and emission matrix H_t. It was not entirely clear how this bookkeeping works...
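A small sketch of the bookkeeping, under the assumption that the identity in the snippet has size C (the number of regression outputs, one per class) and that the state is the row-major flattening of a (C, D) weight matrix, equivalently the stacked columns of the (D, C) matrix W. Under those assumptions H_t = kron(I_C, x_t^T) is block diagonal with one copy of x_t^T per class, so H_t w returns one logit per class. The helper name `kron_emission_matrices` and the toy sizes are illustrative only.

```python
import jax.numpy as jnp

def kron_emission_matrices(X_plus_bias, num_classes):
    """Hypothetical helper: H_t = kron(I_C, x_t^T), shape (C, C * D), stacked over t."""
    return jnp.stack(
        [jnp.kron(jnp.eye(num_classes), row[None, :]) for row in X_plus_bias]
    )

# Toy sizes: T = 4 steps, D = 3 features (bias column included), C = 2 classes.
X_plus_bias = jnp.arange(12.0).reshape(4, 3)
W = jnp.arange(6.0).reshape(2, 3)                  # one row of weights per class, (C, D)
w = W.reshape(-1)                                  # state vector, (C * D,)
H = kron_emission_matrices(X_plus_bias, 2)         # (T, C, C * D)
logits_via_state = jnp.einsum("tij,j->ti", H, w)   # H_t @ w for every t
logits_direct = X_plus_bias @ W.T                  # W @ x_t for every t
```

Both routes give the same (T, C) array of logits, which is the identity the emission matrix relies on.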

murphyk commented 1 year ago

It might make sense to factor out CMGFEstimator into a separate file (call it cmgf_logreg_estimator.py) and add it to the ggsm/demos directory. This can be used by Peter's notebook, too, and avoids code duplication.

gileshd commented 1 year ago

I think moving CMGFEstimator, in its current form, to a .py file would require adding sklearn as a dependency which is probably something we want to avoid.

Or at least we would have to be clear that the demos/ directories aren’t covered by our requirements file. Currently it seems to be clear that notebooks can have extra dependencies but not necessarily .py files.

slinderman commented 1 year ago

As long as it's not imported when you load dynamax, you can put dependencies in there. We can make it a notebook dependency in setup.cfg. Scikit-learn is also a (temporary) dependency due to KMeans.

Best, Scott

(Replying by email to Giles Harper-Donnelly's comment of Nov 18, 2022, at 3:29 PM.)

gileshd commented 1 year ago

Ah great ok, I had forgotten that we include a list of notebook dependencies. Just wanted to make sure that people don't have to go about searching for libraries when they run the demos.