probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Non-determinism in SKLearn KMeans initialisation #327

Closed andrewwarrington closed 1 year ago

andrewwarrington commented 1 year ago

The KMeans initialisation used by many of the models relies on the SKLearn implementation. Because the SKLearn call is not seeded, it breaks the determinism of the JAX code: the same JAX key can yield different initial parameters on different runs. I assume you will eventually include a JAX implementation of KMeans, but until then, a simple fix is:

import jax.random as jr
from sklearn.cluster import KMeans

key, subkey = jr.split(key)  # Split off a subkey to derive the SKLearn seed.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value (exclusive upper bound).
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
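A minimal standalone sketch of the same fix, assuming `emissions` has shape `(..., emission_dim)`. The helper name `kmeans_init_means` and its signature are hypothetical, not dynamax API; `n_init=10` is set explicitly (an assumption) so the result does not depend on the installed scikit-learn version's default. Calling it twice with the same JAX key should return identical cluster centers, which is the determinism the fix restores.

```python
import numpy as np
import jax.random as jr
from sklearn.cluster import KMeans


def kmeans_init_means(key, emissions, num_states):
    """Hypothetical helper: seed scikit-learn's KMeans from a JAX PRNG key.

    `emissions` is assumed to have shape (..., emission_dim); all leading
    axes are flattened before clustering. Callers should pass a fresh
    (already split) key.
    """
    emission_dim = emissions.shape[-1]
    # Derive an integer seed in [0, 2**31 - 1) from the JAX key.
    seed = int(jr.randint(key, shape=(), minval=0, maxval=2147483647))
    km = KMeans(num_states, n_init=10, random_state=seed)
    km.fit(np.asarray(emissions).reshape(-1, emission_dim))
    return km.cluster_centers_


# Synthetic data: 3 sequences of 100 two-dimensional emissions.
data = np.random.default_rng(0).normal(size=(3, 100, 2))

# Same JAX key -> same SKLearn seed -> same cluster centers.
means_a = kmeans_init_means(jr.PRNGKey(0), data, num_states=4)
means_b = kmeans_init_means(jr.PRNGKey(0), data, num_states=4)
```

Without `random_state`, SKLearn falls back to NumPy's global RNG, so `means_a` and `means_b` could differ even though the JAX key is the same.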

The original flaw and solution are demonstrated in this notebook.
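For comparison, a minimal sketch of what a pure-JAX alternative could look like: plain Lloyd's algorithm with centers initialised from a random subset of the data (not the k-means++ initialisation SKLearn uses by default). `jax_kmeans` is a hypothetical name and this is not dynamax code. Because all randomness flows from `key`, the result is reproducible without a separate integer seed.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def jax_kmeans(key, x, num_clusters, num_iters=50):
    """Hypothetical sketch: Lloyd's k-means on x of shape (num_points, dim)."""
    # Initialise centers as a random subset of the data points, chosen by `key`.
    idx = jr.choice(key, x.shape[0], shape=(num_clusters,), replace=False)
    centers = x[idx]

    def step(centers, _):
        # Squared distances from every point to every center: (N, K).
        d2 = jnp.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        labels = jnp.argmin(d2, axis=1)
        one_hot = jax.nn.one_hot(labels, num_clusters)  # (N, K)
        counts = one_hot.sum(axis=0)  # points per cluster, (K,)
        sums = one_hot.T @ x  # per-cluster coordinate sums, (K, D)
        means = sums / jnp.maximum(counts, 1.0)[:, None]
        # Keep the previous center for any cluster that received no points.
        new_centers = jnp.where(counts[:, None] > 0, means, centers)
        return new_centers, None

    # Run a fixed number of assignment/update iterations.
    centers, _ = jax.lax.scan(step, centers, None, length=num_iters)
    return centers


x = jr.normal(jr.PRNGKey(1), (200, 2))
centers_a = jax_kmeans(jr.PRNGKey(0), x, num_clusters=3)
centers_b = jax_kmeans(jr.PRNGKey(0), x, num_clusters=3)
```

A fixed iteration count keeps the loop compatible with `jax.jit`; a production version would likely add a convergence check and multiple restarts.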

Thanks, A

murphyk commented 1 year ago

LGTM.