chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Neuralized K-Means: make k-means amenable to neural network explanations #198

Open jackmcrider opened 11 months ago


Description

K-Means finds centroids $\boldsymbol{\mu}_1,\ldots,\boldsymbol{\mu}_K$ and assigns each data point $\boldsymbol{x}$ to the cluster $c = {\rm argmin}_k \lbrace \Vert\boldsymbol{x} - \boldsymbol{\mu}_k\Vert^2\rbrace$, or in code:

c = torch.argmin(torch.cdist(x, centroids)**2, dim=1)  # x: (N, D), centroids: (K, D) -> c: (N,)
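A self-contained sketch of the assignment for a batch of points. The tensors `x` and `centroids` are made-up toy data; `dim=1` takes the argmin over clusters separately for each point, so the result has one cluster index per row of `x`.

```python
import torch

torch.manual_seed(0)
# Hypothetical toy data: N=6 points in D=2 dimensions, K=3 centroids.
x = torch.randn(6, 2)
centroids = torch.tensor([[0.0, 0.0], [3.0, 3.0], [-3.0, 3.0]])

# Squared Euclidean distances between every point and every centroid, shape (N, K).
d2 = torch.cdist(x, centroids) ** 2

# Hard cluster assignment: index of the nearest centroid for each point, shape (N,).
c = torch.argmin(d2, dim=1)
print(c)
```

Squaring does not change the argmin, but it keeps the code aligned with the squared-distance formula above.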

Neither the assignment $c$ nor the distance $\Vert\boldsymbol{x} - \boldsymbol{\mu}_c\Vert^2$ is really explainable. The assignment $c$ is discrete (piecewise constant in $\boldsymbol{x}$), and the distance measures dissimilarity to the cluster, which is essentially the opposite of what we want to explain.

Fixes

The paper *From Clustering to Cluster Explanation via Neural Networks* (2022) presents a method based on neuralization. It shows that the k-means cluster assignment can be rewritten exactly in terms of a set of piecewise-linear functions

$$f_c(\boldsymbol{x}) = \min_{k\neq c}\lbrace \boldsymbol{w}_{ck}^\top\boldsymbol{x} + b_{ck}\rbrace$$

and the original k-means cluster assignment can be recovered as $c = {\rm argmax}_k\lbrace f_k(\boldsymbol{x})\rbrace$.
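One way to see where $\boldsymbol{w}_{ck}$ and $b_{ck}$ can come from is to expand the difference of squared distances; the $\Vert\boldsymbol{x}\Vert^2$ terms cancel and leave an affine function of $\boldsymbol{x}$. This is a derivation sketch, not quoted from the paper, which may use a different positive scaling (any positive scaling leaves the argmax unchanged):

```math
\begin{aligned}
\Vert\boldsymbol{x}-\boldsymbol{\mu}_k\Vert^2 - \Vert\boldsymbol{x}-\boldsymbol{\mu}_c\Vert^2
&= \big(\Vert\boldsymbol{x}\Vert^2 - 2\boldsymbol{\mu}_k^\top\boldsymbol{x} + \Vert\boldsymbol{\mu}_k\Vert^2\big) - \big(\Vert\boldsymbol{x}\Vert^2 - 2\boldsymbol{\mu}_c^\top\boldsymbol{x} + \Vert\boldsymbol{\mu}_c\Vert^2\big) \\
&= \underbrace{2(\boldsymbol{\mu}_c-\boldsymbol{\mu}_k)^\top}_{\boldsymbol{w}_{ck}^\top}\boldsymbol{x} + \underbrace{\Vert\boldsymbol{\mu}_k\Vert^2 - \Vert\boldsymbol{\mu}_c\Vert^2}_{b_{ck}}
\end{aligned}
```

With this choice, $f_c(\boldsymbol{x}) > 0$ exactly when $\boldsymbol{\mu}_c$ is strictly the nearest centroid, while $f_{k'}(\boldsymbol{x}) < 0$ for every other $k'$ (its minimum includes the negative term for $k = c$), so ${\rm argmax}_k f_k(\boldsymbol{x})$ recovers the k-means assignment.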

The cluster discriminant $f_c$ measures cluster membership (rather than dissimilarity) and is structurally amenable to LRP and related methods, since it is a linear layer followed by a min-pooling over $k$. It can also be placed on top of a neural network feature extractor to make deep cluster assignments explainable.
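A minimal PyTorch sketch of the neuralized layer, assuming the weights $\boldsymbol{w}_{ck} = 2(\boldsymbol{\mu}_c - \boldsymbol{\mu}_k)$ and biases $b_{ck} = \Vert\boldsymbol{\mu}_k\Vert^2 - \Vert\boldsymbol{\mu}_c\Vert^2$ obtained by expanding $\Vert\boldsymbol{x}-\boldsymbol{\mu}_k\Vert^2 - \Vert\boldsymbol{x}-\boldsymbol{\mu}_c\Vert^2$. The class `NeuralizedKMeans` is hypothetical (it is not part of zennit's API) and the centroids are made-up toy values; the sketch only illustrates the linear-layer-plus-min-pooling structure and checks that ${\rm argmax}_k f_k$ matches the k-means assignment.

```python
import torch
from torch import nn


class NeuralizedKMeans(nn.Module):
    """Hypothetical sketch: f_c(x) = min_{k != c} (w_ck^T x + b_ck),
    built as an nn.Linear over all pairs (c, k) followed by a min over k."""

    def __init__(self, centroids: torch.Tensor):
        super().__init__()
        K, D = centroids.shape
        assert K >= 2, "need at least two clusters for a min over k != c"
        self.K = K
        # All ordered pairs (c, k) with k != c, grouped by c (K-1 pairs per c).
        pairs = [(c, k) for c in range(K) for k in range(K) if k != c]
        ci = torch.tensor([p[0] for p in pairs])
        ki = torch.tensor([p[1] for p in pairs])
        mu_c, mu_k = centroids[ci], centroids[ki]
        self.linear = nn.Linear(D, len(pairs))
        with torch.no_grad():
            # Assumed parameterization: w_ck = 2 (mu_c - mu_k), b_ck = ||mu_k||^2 - ||mu_c||^2.
            self.linear.weight.copy_(2 * (mu_c - mu_k))
            self.linear.bias.copy_((mu_k ** 2).sum(dim=1) - (mu_c ** 2).sum(dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, K*(K-1)) -> (N, K, K-1): row c holds the scores for all pairs (c, k != c).
        z = self.linear(x).view(-1, self.K, self.K - 1)
        # Min-pooling over k gives the discriminants f_c(x), shape (N, K).
        return z.min(dim=2).values


# Usage on toy data: compare the neuralized assignment with plain k-means assignment.
torch.manual_seed(0)
centroids = torch.tensor([[0.0, 0.0], [3.0, 3.0], [-3.0, 3.0]])
x = torch.randn(100, 2) * 2
f = NeuralizedKMeans(centroids)(x)
c_kmeans = torch.cdist(x, centroids).argmin(dim=1)
c_neural = f.argmax(dim=1)
print(torch.equal(c_kmeans, c_neural))
```

To use it after a feature extractor, the extracted features (rather than raw inputs) would be passed as `x`; choosing an LRP rule for the min-pooling step is left open in this sketch.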

Additional Information