probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Support for kmeans initialization with vmap #315

Open ghuckins opened 1 year ago

ghuckins commented 1 year ago

Hi there,

When I try to use vmap to vectorize a function that includes a kmeans initialization, I get the following error:

jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[11396,7])>with<BatchTrace(level=1/0)>

And here's the code that produces the error:

    import jax.numpy as jnp
    from jax import vmap
    from dynamax.hidden_markov_model import GaussianHMM

    # latdim, obsdim, length, data1, data2 and get_key() are defined earlier in my code
    hmm = GaussianHMM(latdim, obsdim)
    data1 = jnp.array(data1)
    data2 = jnp.array(data2)
    # leave-one-out folds
    data1_train = jnp.stack([jnp.concatenate([data1[:i], data1[i+1:]]) for i in range(len(data1))])
    data2_train = jnp.stack([jnp.concatenate([data2[:i], data2[i+1:]]) for i in range(len(data2))])

    base_params1, props1 = hmm.initialize(key=get_key(), method="kmeans", emissions=data1[:length,:,:])
    params1, _ = hmm.fit_em(base_params1, props1, data1[:length,:,:], num_iters=100, verbose=False)
    base_params2, props2 = hmm.initialize(key=get_key(), method="kmeans", emissions=data2[:length,:,:])
    params2, _ = hmm.fit_em(base_params2, props2, data2[:length,:,:], num_iters=100, verbose=False)

    def _fit_fold(train, test, params):
        base_params, props = hmm.initialize(key=get_key(), method="kmeans", emissions=train[:length,:,:])
        fit_params, _ = hmm.fit_em(base_params, props, train[:length,:,:], num_iters=100, verbose=False)
        return (hmm.marginal_log_prob(fit_params, test) > hmm.marginal_log_prob(params, test)).astype(int)

    correct1 = jnp.sum(vmap(_fit_fold, in_axes=[0, 0, None])(data1_train, data1, params2))

The error traces back to scikit-learn's KMeans. The problem seems to be that scikit-learn uses numpy functions rather than jax functions, so the traced arrays inside vmap can't be handled. Would it be possible to update hmm.initialize so that it could be used in vectorized functions?
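
For what it's worth, the same failure can be reproduced outside dynamax with a minimal sketch like the following (the helper name is just for illustration): any attempt to convert a traced jax array to a concrete numpy array, which scikit-learn effectively does internally, raises this error under vmap.

    import numpy as np
    import jax
    import jax.numpy as jnp

    def _uses_numpy(x):
        # roughly what scikit-learn does internally: convert the input to a numpy array
        return np.asarray(x).sum()

    # raises jax._src.errors.TracerArrayConversionError, just like the kmeans initialization
    jax.vmap(_uses_numpy)(jnp.ones((3, 4)))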

Thanks!

gileshd commented 1 year ago

Hi @ghuckins, thanks for showing interest in the library! Yes, unfortunately the sklearn bits of code won't naturally play nice with many of jax's tools.

From what I can tell, updating the "kmeans" option in hmm.initialize to use a jax-compatible implementation of the kmeans algorithm would involve writing, testing, and maintaining our own jax kmeans implementation, which might unfortunately be outside the scope of this library (unless there is really great demand for it).
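
For a sense of scale, a bare-bones Lloyd's-algorithm version in pure jax is not huge. The sketch below is just an illustration (untested, with random points as the initial centroids); the maintenance burden comes from the edge cases sklearn already handles, such as empty clusters, kmeans++ initialization, and convergence checks.

    import jax
    import jax.numpy as jnp

    def kmeans(key, data, k, num_iters=100):
        """data: (n, d) array; returns (k, d) centroids and (n,) cluster assignments."""
        n, d = data.shape
        init_idx = jax.random.choice(key, n, shape=(k,), replace=False)
        init_centroids = data[init_idx]

        def step(centroids, _):
            # assign each point to its nearest centroid
            dists = jnp.sum((data[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
            assignments = jnp.argmin(dists, axis=-1)
            # recompute each centroid as the mean of its assigned points
            one_hot = jax.nn.one_hot(assignments, k)            # (n, k)
            counts = jnp.maximum(one_hot.sum(axis=0), 1.0)      # (k,), avoid division by zero
            new_centroids = (one_hot.T @ data) / counts[:, None]
            return new_centroids, None

        centroids, _ = jax.lax.scan(step, init_centroids, None, length=num_iters)
        dists = jnp.sum((data[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        return centroids, jnp.argmin(dists, axis=-1)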

Depending on your precise use case there might be some reasonably straightforward workarounds. For instance, it might be possible to use the sklearn "kmeans" initialization to generate the appropriate initial parameters for your hmms, which could then be passed as input into a function which could be vmapped over your data.
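
In rough, untested pseudocode, reusing the names from your snippet above (hmm, data1_train, data1, length, params2, get_key) and assuming fit_em itself behaves under vmap once the sklearn step is lifted out, that workaround might look something like:

    import jax.numpy as jnp
    from jax import vmap
    from jax.tree_util import tree_map

    # Run the sklearn-backed kmeans initialization eagerly, fold by fold,
    # on concrete arrays (no tracers involved here).
    inits = [hmm.initialize(key=get_key(), method="kmeans",
                            emissions=fold[:length, :, :])
             for fold in data1_train]
    init_params = tree_map(lambda *leaves: jnp.stack(leaves),
                           *[params for params, _ in inits])
    props = inits[0][1]  # the properties are identical across folds

    # Only the purely-jax fitting/scoring step gets vmapped.
    def _fit_fold(init, train, test):
        fit_params, _ = hmm.fit_em(init, props, train[:length, :, :],
                                   num_iters=100, verbose=False)
        return (hmm.marginal_log_prob(fit_params, test)
                > hmm.marginal_log_prob(params2, test)).astype(int)

    correct1 = jnp.sum(vmap(_fit_fold)(init_params, data1_train, data1))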

I hope that provides some idea of a way forward. If you wanted to share more details about your use case, I would be happy to try to give more precise advice.

ghuckins commented 1 year ago

Hey Giles, thanks for the reply! I actually did find a Jax implementation of k-means online and am working on incorporating it into my own codebase; I could share it once I'm done, if that would be helpful. Just let me know!

murphyk commented 1 year ago

Hi @ghuckins. It would be great if you added your jax implementation of kmeans.