SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
MIT License
500 stars 187 forks source link

`MultiSortingComparison` and `MultiTemplateComparison` optimal assignment #2911

Open florian6973 opened 4 months ago

florian6973 commented 4 months ago


I have been reading the literature on Multidimensional Assignment Problems / Entity Matching to better understand how multiple sortings or templates are matched, and I realized that the current method does not always seem to return the optimal matching.

To check my hypothesis, I modified the `BaseMultiComparison` class to create a minimal working example (MWE) that reproduces the potential issue (example from https://arxiv.org/pdf/2112.03346, page 3). It can be run as a script:

from collections import OrderedDict
from copy import deepcopy
import numpy as np

class BaseMultiComparison:
    """
    Base class for graph-based multi comparison classes.

    It handles graph operations, comparisons, and agreements.

    In this minimal working example the constructor inputs and the pairwise
    comparisons are hard-coded (three objects ``'a'``, ``'b'``, ``'c'``, each
    with units ``'1'``, ``'2'``, ``'3'``; example from arXiv:2112.03346, p. 3),
    so that the whole pipeline can be run without real sortings or templates.

    NOTE: ``_clean_graph`` resolves conflicts greedily, by dropping the
    lowest-weight edges around duplicated objects. Nothing in it searches for
    the globally optimal multi-dimensional assignment, which is what this
    example is meant to probe.
    """

    def __init__(self):
        # Imported here so that a missing networkx fails at construction time.
        import networkx as nx  # noqa: F401

        # The real class forwards its arguments to BaseComparison; in this MWE
        # they are replaced by the fixed toy values below.
        # BaseComparison.__init__(
        #     self,
        #     object_list=object_list,
        #     name_list=name_list,
        #     match_score=match_score,
        #     chance_score=chance_score,
        #     verbose=verbose,
        # )
        # self.match_score = 0.3
        self.name_list = ['a', 'b', 'c']
        # NOTE: object_list doubles as the unit ids of every object (see _populate_nodes).
        self.object_list = ['1', '2', '3']
        self._verbose = True

        self.graph = None
        self.subgraphs = None
        self.clean_graph = None

    def _compute_all(self):
        """Run the pipeline: pairwise comparison, graph, graph cleaning, agreement."""
        self._do_comparison()
        self._do_graph()
        self._clean_graph()
        self._do_agreement()

    def _populate_nodes(self):
        """Add one graph node per ``(object name, unit id)`` pair."""
        for name in self.name_list:
            for unit_id in self.object_list:
                self.graph.add_node((name, unit_id))

    def units(self):
        """Return a deep copy of the agreement units built by ``_do_agreement``."""
        return deepcopy(self._new_units)

    def compute_subgraphs(self):
        """
        Computes subgraphs of connected components.

        Uses ``self.clean_graph`` when it exists, otherwise ``self.graph``.

        Returns
        -------
        sg_object_names: list
            List of sorter names for each node in the connected component subgraph
        sg_units: list
            List of unit ids for each node in the connected component subgraph
        """
        if self.clean_graph is not None:
            g = self.clean_graph
        else:
            g = self.graph

        import networkx as nx

        subgraphs = (g.subgraph(c).copy() for c in nx.connected_components(g))
        sg_object_names = []
        sg_units = []
        for sg in subgraphs:
            object_names = []
            unit_names = []
            for node in sg.nodes:
                object_names.append(node[0])
                unit_names.append(node[1])
            sg_object_names.append(object_names)
            sg_units.append(unit_names)
        return sg_object_names, sg_units

    def _do_comparison(
        self,
    ):
        """
        Hard-coded stand-in for the pairwise comparisons.

        ``self.comparisons`` maps each ``(name_i, name_j)`` pair to
        ``{unit_i: (matched unit_j, agreement score)}``. The 'a'/'b' matches
        cross over (a1-b2 and a2-b1 at 0.6), while 'b'/'c' and 'a'/'c' match
        unit-for-unit at 1.0.
        """
        # do pairwise matching
        if self._verbose:
            print("Multicomparison step 1: pairwise comparison")

        self.comparisons = {
            ('a', 'b'): {
                '1': ('2', 0.6), '2': ('1', 0.6), '3': ('3', 1.0)
            },
            ('b', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
            },
            ('a', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
            },
        }

    def _do_graph(self):
        """
        Build ``self.graph``.

        The graph has one node per (object, unit) and one weighted edge per
        pairwise match; the edge weight is the agreement score.
        """
        if self._verbose:
            print("Multicomparison step 2: make graph")

        import networkx as nx

        self.graph = nx.Graph()
        # nodes
        self._populate_nodes()

        # edges
        for comp_name, comp in self.comparisons.items():
            for u1 in comp.keys():
                u2 = comp[u1][0]
                # -1 presumably marks an unmatched unit (upstream convention)
                if u2 != -1:
                    name_1, name_2 = comp_name
                    node1 = name_1, u1
                    node2 = name_2, u2
                    score = comp[u1][1]
                    self.graph.add_edge(node1, node2, weight=score)

        # the graph is symmetrical
        self.graph = self.graph.to_undirected()

    def _clean_graph(self):
        """
        Build ``self.clean_graph`` from ``self.graph`` by pruning edges.

        For every object that appears more than once in a connected component,
        the ``len(duplicates) - 1`` lowest-weight edges touching its duplicated
        nodes are removed. This is a greedy rule, not an optimal assignment.
        """
        if self._verbose:
            print("Multicomparison step 3: clean graph")
        clean_graph = self.graph.copy()
        import networkx as nx

        subgraphs = (clean_graph.subgraph(c).copy() for c in nx.connected_components(clean_graph))
        removed_nodes = 0
        for sg in subgraphs:
            object_names = [node[0] for node in sg.nodes]
            sorters, counts = np.unique(object_names, return_counts=True)

            # NOTE(review): ``counts`` is computed once per *original* component.
            # After one object's edges are pruned, the component may already be
            # split, yet later objects are still pruned. For example, here the 'a'
            # and 'b' passes remove both 0.6 edges, which already separates c1
            # from c2, but the 'c' pass still removes a 1.0 edge. Kept as-is to
            # mirror the upstream behaviour under investigation.
            if np.any(counts > 1):
                for sorter in sorters[counts > 1]:
                    # Compare the object name explicitly: ``sorter in n`` would also
                    # match a node whose unit id happens to equal the object name.
                    nodes_duplicate = [n for n in sg.nodes if n[0] == sorter]
                    # get edges
                    edges_duplicates = []
                    weights_duplicates = []
                    for n in nodes_duplicate:
                        for e in sg.edges(n, data=True):
                            edges_duplicates.append(e)
                            weights_duplicates.append(e[2]["weight"])

                    # remove extra edges (lowest weights first)
                    n_edges_to_remove = len(nodes_duplicate) - 1
                    remove_idxs = np.argsort(weights_duplicates)[:n_edges_to_remove]
                    # Plain list indexing avoids building an object array out of
                    # ragged (node, node, data) tuples.
                    edges_to_remove = [edges_duplicates[i] for i in remove_idxs]

                    for edge_to_remove in edges_to_remove:
                        clean_graph.remove_edge(edge_to_remove[0], edge_to_remove[1])
                        sg.remove_edge(edge_to_remove[0], edge_to_remove[1])
                        if self._verbose:
                            print(f"Removed edge: {edge_to_remove}")

                    # remove extra nodes (as a second step to not affect edge removal).
                    # Nodes are dropped only from the working subgraph ``sg`` (so later
                    # passes ignore them); ``clean_graph`` keeps them, minus the edges.
                    for edge_to_remove in edges_to_remove:
                        if edge_to_remove[0] in nodes_duplicate:
                            node_to_remove = edge_to_remove[0]
                        else:
                            node_to_remove = edge_to_remove[1]
                        if node_to_remove in sg.nodes:
                            sg.remove_node(node_to_remove)
                            if self._verbose:
                                print(f"Removed node: {node_to_remove}")
                            removed_nodes += 1

        if self._verbose:
            print(f"Removed {removed_nodes} duplicate nodes")
        self.clean_graph = clean_graph

    def _do_agreement(self):
        """
        Turn each connected component of ``self.clean_graph`` into an agreement unit.

        Each entry of ``self._new_units`` stores:

        - the mean edge weight (0 for an isolated node);
        - the unit id per object, in ``name_list`` order;
        - the number of nodes.
        """
        # extract agreement from graph
        if self._verbose:
            print("Multicomparison step 4: extract agreement from graph")

        self._new_units = {}

        # save new units
        import networkx as nx

        self.subgraphs = [self.clean_graph.subgraph(c).copy() for c in nx.connected_components(self.clean_graph)]
        for new_unit, sg in enumerate(self.subgraphs):
            edges = list(sg.edges(data=True))
            if len(edges) > 0:
                avg_agr = np.mean([d["weight"] for u, v, d in edges])
            else:
                avg_agr = 0
            object_unit_ids = {}
            for node in sg.nodes:
                object_name, unit_name = node
                object_unit_ids[object_name] = unit_name
            # sort dict based on name list
            sorted_object_unit_ids = OrderedDict()
            for name in self.name_list:
                if name in object_unit_ids:
                    sorted_object_unit_ids[name] = object_unit_ids[name]
            self._new_units[new_unit] = {
                "avg_agreement": avg_agr,
                "unit_ids": sorted_object_unit_ids,
                "agreement_number": len(sg.nodes),
            }
# Build the example object.
# NOTE(review): nothing after this line runs the pipeline; presumably
# ``b._compute_all()`` followed by inspecting ``b.units()`` was intended.
b = BaseMultiComparison()

In your view, is my MWE a correct adaptation of the example from the literature to the SpikeInterface framework? If so, have you already considered other methods, or should we think more about how to solve this issue?



zm711 commented 4 months ago


We will take a look at this soon. We are in the middle of a SpikeInterface hackathon, but we are super curious about this. It is a little hard for me to read the code without a nice diff view. Could you also post the same code with comments on the lines you changed, to make the comparison a bit easier? If we haven't responded by next week, please ping us again!

florian6973 commented 4 months ago

Thanks for your reply!

Sure, here are some more details:

$$\begin{array}{c|ccc} & b_1 & b_2 & b_3 \\ \hline a_1 & 0.4 & \textbf{0.6} & 0.6 \\ a_2 & \textbf{0.6} & 0.6 & 0.6 \\ a_3 & 0.6 & 0.6 & \textbf{1} \end{array}$$

$$\begin{array}{c|ccc} & c_1 & c_2 & c_3 \\ \hline b_1 & \textbf{1} & 0.1 & 0.1 \\ b_2 & 0.1 & \textbf{1} & 0.1 \\ b_3 & 0.1 & 0.1 & \textbf{1} \end{array}$$

$$\begin{array}{c|ccc} & c_1 & c_2 & c_3 \\ \hline a_1 & \textbf{1} & 0.1 & 0.1 \\ a_2 & 0.1 & \textbf{1} & 0.1 \\ a_3 & 0.1 & 0.1 & \textbf{1} \end{array}$$

    def __init__(self):
        # Abbreviated excerpt of the MWE constructor: only the hard-coded toy
        # inputs that stand in for the real constructor arguments are shown.
        # object_list doubles as the unit ids of every object (see _populate_nodes).
        self.name_list = ['a', 'b', 'c']   
        self.object_list = ['1', '2', '3']   

    # def _compare_ij(self, i, j):
    #   raise NotImplementedError

    # def _populate_nodes(self):
    #    raise NotImplementedError

    def _populate_nodes(self):
        for name in self.name_list:
            for unit_id in self.object_list:
                self.graph.add_node((name, unit_id))

#     def _do_comparison(
#         self,
#     ):
#         # do pairwise matching
#         if self._verbose:
#             print("Multicomparison step 1: pairwise comparison")

#         self.comparisons = {}
#         for i in range(len(self.object_list)):
#             for j in range(i + 1, len(self.object_list)):
#                 if self.name_list is not None:
#                     name_i = self.name_list[i]
#                     name_j = self.name_list[j]
#                 else:
#                     name_i = "object i"
#                     name_j = "object j"
#                 if self._verbose:
#                     print(f"  Comparing: {name_i} and {name_j}")
#                 comp = self._compare_ij(i, j)
#                 self.comparisons[(name_i, name_j)] = comp

    def _do_comparison(
        # do pairwise matching
        if self._verbose:
            print("Multicomparison step 1: pairwise comparison")

        self.comparisons = {
            ('a', 'b'): {
                '1': ('2', 0.6), '2': ('1', 0.6), '3': ('3', 1.0)
             ('b', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
             ('a', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)

    def _do_graph(self):
        # Excerpt: the "# ..." line stands for omitted lines of this method
        # (presumably the verbose print, graph creation and node population).
        # The commented-out loop appears to be the original (upstream) version,
        # which reads the Hungarian match and the agreement-score table from
        # comparison objects. The active loop below is the MWE replacement, which
        # reads (matched unit, score) tuples from plain dicts.
# ...
#  for comp_name, comp in self.comparisons.items():
#     for u1 in comp.hungarian_match_12.index.values:
#         u2 = comp.hungarian_match_12[u1]
#         if u2 != -1:
#             name_1, name_2 = comp_name
#             node1 = name_1, u1
#             node2 = name_2, u2
#             score = comp.agreement_scores.loc[u1, u2]
#             self.graph.add_edge(node1, node2, weight=score)

        for comp_name, comp in self.comparisons.items():
            for u1 in comp.keys():
                u2 = comp[u1][0]
                # presumably -1 marks an unmatched unit, as in the upstream loop
                if u2 != -1:
                    name_1, name_2 = comp_name
                    node1 = name_1, u1
                    node2 = name_2, u2
                    score = comp[u1][1]
                    self.graph.add_edge(node1, node2, weight=score)

I hope this is clearer. I am not sure whether I am fully correct, but I was trying to properly understand the multiple comparison module, which is why I am asking.

Have a good hackathon :)

By the way, if you are in Boston at some point, we could discuss it in person if needed :)

zm711 commented 4 months ago

Hey @florian6973,

Thanks for the well wishes. We could definitely meet at some point; if you're on the Slack, just send me a message. But I think @alejoe91 is better placed to look over this one: I didn't work on the initial code, so he knows it much better.

JoeZiminski commented 3 months ago

Thanks a lot for this, @florian6973 — a super interesting and detailed investigation. I will definitely look into this while working on #2626; please feel free to give any feedback and thoughts on the plan I posted there.