chongxi / spiketag

Next generation of spike sorting package for BMI
BSD 3-Clause "New" or "Revised" License
6 stars 4 forks source link

Background automatic re-sorting on other core #19

Open chongxi opened 5 years ago

chongxi commented 5 years ago

First version, which allows a single backend: https://github.com/chongxi/spiketag/commit/1e6a1c57f1c5e8f8545369df1e08f95c4e8d1747

Second version, which allows several backends (tested with two backends): https://github.com/chongxi/spiketag/commit/98949e3a96204b88fafa75e86fc8f495fad7e45c

for fet.toclu():

# Create a fresh backend for this request and keep it in the backend list.
# NOTE(review): the `cluster` class shown later in this thread takes no
# constructor arguments; this call passes `self.clu_status` -- confirm against
# the multi-backend commit.
self.backend.append(cluster(self.clu_status))
# Start the fit on the newest backend for this group's features and labels.
self.backend[-1].fit(method, self.fet[group_id], self.clu[group_id], **kwargs)
chongxi commented 5 years ago

The ipyparallel non-blocking API is used. Here is my test code; it was applied in https://github.com/chongxi/spiketag/commit/1e6a1c57f1c5e8f8545369df1e08f95c4e8d1747 as fet.backend, and I later extended the backend to a list of these objects to support multiple non-blocking instances.

class cluster():
    """Non-blocking clustering backend that runs on a load-balanced view.

    A clustering job is submitted with `fit`, which returns immediately.
    When the job finishes, the resulting labels are written into the
    supplied `clu` object through its `fill` method.
    """

    def __init__(self):
        # `ipp` is presumably `ipyparallel` imported at module scope
        # (`import ipyparallel as ipp`) -- TODO confirm; it is not imported
        # by this snippet.
        self.client = ipp.Client()
        # The view must be taken from the client stored on this instance;
        # a bare `client` name does not exist in this scope.
        self.cpu = self.client.load_balanced_view()
        # Dispatch table: method name -> function submitted to the view.
        self.clu_func = {'hdbscan': self._hdbscan,
                         'dpgmm':   self._dpgmm }

    def fit(self, clu_method, fet, clu, **kwargs):
        """Submit a clustering job and return without waiting for it.

        Parameters
        ----------
        clu_method : str
            Key into `self.clu_func` ('hdbscan' or 'dpgmm').
        fet : array-like
            Feature matrix forwarded to the clustering function as `fet`.
        clu : object
            Label container; must provide `fill(labels)`, which is called
            with the job's result once it is done.
        **kwargs
            Extra keyword arguments forwarded to the clustering function.

        Raises
        ------
        KeyError
            If `clu_method` is not a key of `self.clu_func`.
        """
        self.fet = fet
        self.clu = clu
        func = self.clu_func[clu_method]
        print(func)
        ar = self.cpu.apply_async(func, fet=fet, **kwargs)

        def get_result(ar):
            # Runs when the asynchronous job completes: push the labels
            # into the label container registered by this fit call.
            self.clu.fill(ar.get())

        ar.add_done_callback(get_result)

    @staticmethod
    def _dpgmm(fet, n_comp, max_iter):
        """Fit a Dirichlet-process Gaussian mixture and return point labels.

        Parameters
        ----------
        fet : array-like
            Feature matrix to cluster.
        n_comp : int
            Passed as `n_components` to the mixture model.
        max_iter : int
            Passed as `max_iter` to the mixture model.
        """
        # Imported inside the function so the import happens wherever the
        # function is executed.
        from sklearn.mixture import BayesianGaussianMixture as DPGMM
        dpgmm = DPGMM(
            n_components=n_comp, covariance_type='full', weight_concentration_prior=1e-3,
            weight_concentration_prior_type='dirichlet_process', init_params="kmeans",
            # Honor the caller's iteration budget instead of a fixed 100.
            max_iter=max_iter, random_state=0, verbose=0, verbose_interval=10) # init can be "kmeans" or "random"
        dpgmm.fit(fet)
        label = dpgmm.predict(fet)
        return label

    @staticmethod
    def _hdbscan(fet, min_cluster_size, leaf_size, eom_or_leaf):
        """Cluster `fet` with HDBSCAN and return the labels shifted by +1.

        Parameters
        ----------
        fet : numpy.ndarray
            Feature matrix; cast to float64 before fitting.
        min_cluster_size : int
            Passed as `min_cluster_size` to HDBSCAN.
        leaf_size : int
            Passed as `leaf_size` to HDBSCAN.
        eom_or_leaf : str
            Passed as `cluster_selection_method` ('eom' or 'leaf').
        """
        # Imported inside the function so the imports happen wherever the
        # function is executed.
        import hdbscan
        import numpy as np
        hdbcluster = hdbscan.HDBSCAN(min_samples=2,
                     min_cluster_size=min_cluster_size,
                     leaf_size=leaf_size,
                     gen_min_span_tree=True,
                     algorithm='boruvka_kdtree',
                     core_dist_n_jobs=1,
                     prediction_data=False,
                     cluster_selection_method=eom_or_leaf) # eom or leaf
        clusterer = hdbcluster.fit(fet.astype(np.float64))
        # Shift labels up by one (HDBSCAN's noise label -1 becomes 0).
        return clusterer.labels_+1
chongxi commented 5 years ago

Test it

import os

import numpy as np

from spiketag.base import CLU

# Start a local ipcluster with 16 engines for the backend to connect to.
# NOTE(review): the engines may need a moment to come up before `cluster()`
# can connect -- confirm whether a wait is needed here.
os.popen('ipcluster start -n {}'.format(16))

# Random features (10000 points, 4 dimensions) with all labels set to 0.
fet = np.random.randn(10000, 4)
clu = CLU(np.zeros((10000,)).astype(np.int64))
clu._id = 3


@clu.connect
def on_cluster(*args, **kwargs):
    # Print the id and membership when the connected event fires.
    print(clu._id, clu.membership)


print(clu.membership)

method = 'dpgmm'   # clustering method
dd = cluster()
# Compare with `==`; a single `=` here is a syntax error.
if method == 'dpgmm':
    dd.fit('dpgmm', fet, clu,
           n_comp=8, max_iter=400)
else:
    dd.fit('hdbscan', fet, clu,
           min_cluster_size=18, leaf_size=40, eom_or_leaf='eom')

# `fit` does not wait for the job, so this may still show the initial
# membership; the updated one is printed by `on_cluster` once it fires.
print(clu.membership)
chongxi commented 5 years ago

nonblocking.ipynb.tar.gz

chongxi commented 5 years ago

The next step is to build a status object which serves as a tory in troy to monitor the status of the clu. The fet.clu_status will be upgraded from {False, None, True ....} to {troy, troy, troy ...}

Every status troy should have multiple states and be related to a group: