MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License

Why does `visualize_topics` use `n_neighbors=2` #1900

Open: zilch42 opened this issue 6 months ago

zilch42 commented 6 months ago

Hi Maarten,

Just wondering why the UMAP model in visualize_topics uses n_neighbors=2 rather than something like 15 (UMAP's default)?

https://github.com/MaartenGr/BERTopic/blob/424cefc68ede08ff9f1c7e56ee6103c16c1429c6/bertopic/plotting/_topics.py#L79

A value of 2 results in a lot of topics sitting directly on top of one another, whereas higher values give better topic spread and are more reminiscent of the topic spread when visualizing documents.

[Image: visualize_topics output with n_neighbors=2]

[Image: visualize_topics output with n_neighbors=15]
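A rough way to see why a very small n_neighbors can fragment the layout: UMAP starts from a k-nearest-neighbour graph, and in umap-learn the point itself is typically counted as its own first neighbour, so n_neighbors=2 leaves each topic with just one real neighbour. The stdlib-only sketch below (knn_components is a hypothetical helper, not BERTopic or UMAP code, and it ignores UMAP's edge weights and layout optimisation) counts how many disconnected pieces such a graph has.

```python
import math
import random
from typing import List, Sequence


def knn_components(points: Sequence[Sequence[float]], n_neighbors: int) -> int:
    """Count connected components of a symmetrised k-nearest-neighbour graph.

    Follows the convention that n_neighbors includes the point itself, so each
    point gets n_neighbors - 1 edges to its closest *other* points.
    Edge weights and the layout optimisation are deliberately ignored.
    """
    n = len(points)
    parent = list(range(n))

    def find(i: int) -> int:
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i, p in enumerate(points):
        others = sorted((j for j in range(n) if j != i),
                        key=lambda j: math.dist(p, points[j]))
        for j in others[: n_neighbors - 1]:
            # Undirected edge: a crude stand-in for UMAP's fuzzy union.
            parent[find(i)] = find(j)

    return len({find(i) for i in range(n)})


# Deterministic toy case: two well-separated pairs on a line.
pairs = [[0.0], [1.0], [10.0], [11.0]]
print(knn_components(pairs, n_neighbors=2))  # → 2 (each pair is its own island)
print(knn_components(pairs, n_neighbors=3))  # → 1

# 30 random, purely illustrative "topic embeddings" in 5 dimensions.
random.seed(0)
topics: List[List[float]] = [[random.gauss(0, 1) for _ in range(5)] for _ in range(30)]
print(knn_components(topics, n_neighbors=2), knn_components(topics, n_neighbors=15))
```

With n_neighbors=2 every component has at least two points but the graph usually splits into several islands, while larger values tend to join everything into one piece. This is only one plausible contributor to the clumping in the first image, not a confirmed diagnosis.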

MaartenGr commented 6 months ago

It was a while ago, but I remember a lower value being necessary when there are only a few topics. That is the risk of tuning a parameter on a single dataset: it might "overfit" to the current situation. The difficulty here is finding a value that works for any number of topics. In other words, does your solution also work if you have 3 topics?

zilch42 commented 5 months ago

Hi Maarten, ah, that sounds plausible. I've started testing; however, it seems that visualize_topics doesn't work at all with low numbers of topics, regardless of n_neighbors.

from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data'][0:5000]

topic_model = BERTopic(nr_topics=4)
topics, probs = topic_model.fit_transform(docs)
topic_model.get_topic_info()

[Image: output of topic_model.get_topic_info()]

topic_model.visualize_topics()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 1
----> 1 topic_model.visualize_topics()

File c:\Users\[path]\lib\site-packages\bertopic\_bertopic.py:2249, in BERTopic.visualize_topics(self, topics, top_n_topics, custom_labels, title, width, height)
   2216 """ Visualize topics, their sizes, and their corresponding words
   2217 
   2218 This visualization is highly inspired by LDAvis, a great visualization
   (...)
   2246 ```
   2247 """
   2248 check_is_fitted(self)
-> 2249 return plotting.visualize_topics(self,
   2250                                  topics=topics,
   2251                                  top_n_topics=top_n_topics,
   2252                                  custom_labels=custom_labels,
   2253                                  title=title,
   2254                                  width=width,
   2255                                  height=height)

File c:\Users\[path]\lib\site-packages\bertopic\plotting\_topics.py:79, in visualize_topics(topic_model, topics, top_n_topics, custom_labels, title, width, height)
     77 if topic_model.topic_embeddings_ is not None:
     78     embeddings = topic_model.topic_embeddings_[indices]
---> 79     embeddings = UMAP(n_neighbors=2, n_components=2, metric='cosine', random_state=42).fit_transform(embeddings)
     80 else:
     81     embeddings = topic_model.c_tf_idf_.toarray()[indices]

File c:\Users\[path]\lib\site-packages\umap\umap_.py:2887, in UMAP.fit_transform(self, X, y, force_all_finite)
   2851 def fit_transform(self, X, y=None, force_all_finite=True):
   2852     """Fit X into an embedded space and return that transformed
   2853     output.
   2854 
   (...)
   2885         Local radii of data points in the embedding (log-transformed).
   2886     """
-> 2887     self.fit(X, y, force_all_finite)
   2888     if self.transform_mode == "embedding":
   2889         if self.output_dens:

File c:\Users\[path]\lib\site-packages\umap\umap_.py:2780, in UMAP.fit(self, X, y, force_all_finite)
   2776 if self.transform_mode == "embedding":
   2777     epochs = (
   2778         self.n_epochs_list if self.n_epochs_list is not None else self.n_epochs
   2779     )
-> 2780     self.embedding_, aux_data = self._fit_embed_data(
   2781         self._raw_data[index],
   2782         epochs,
   2783         init,
   2784         random_state,  # JH why raw data?
   2785     )
   2787     if self.n_epochs_list is not None:
   2788         if "embedding_list" not in aux_data:

File c:\Users\[path]\lib\site-packages\umap\umap_.py:2826, in UMAP._fit_embed_data(self, X, n_epochs, init, random_state)
   2822 def _fit_embed_data(self, X, n_epochs, init, random_state):
   2823     """A method wrapper for simplicial_set_embedding that can be
   2824     replaced by subclasses.
   2825     """
-> 2826     return simplicial_set_embedding(
   2827         X,
   2828         self.graph_,
   2829         self.n_components,
   2830         self._initial_alpha,
   2831         self._a,
   2832         self._b,
   2833         self.repulsion_strength,
   2834         self.negative_sample_rate,
   2835         n_epochs,
   2836         init,
   2837         random_state,
   2838         self._input_distance_func,
   2839         self._metric_kwds,
   2840         self.densmap,
   2841         self._densmap_kwds,
   2842         self.output_dens,
   2843         self._output_distance_func,
   2844         self._output_metric_kwds,
   2845         self.output_metric in ("euclidean", "l2"),
   2846         self.random_state is None,
   2847         self.verbose,
   2848         tqdm_kwds=self.tqdm_kwds,
   2849     )

File c:\Users\[path]\lib\site-packages\umap\umap_.py:1106, in simplicial_set_embedding(data, graph, n_components, initial_alpha, a, b, gamma, negative_sample_rate, n_epochs, init, random_state, metric, metric_kwds, densmap, densmap_kwds, output_dens, output_metric, output_metric_kwds, euclidean_output, parallel, verbose, tqdm_kwds)
   1102     embedding = noisy_scale_coords(
   1103         embedding, random_state, max_coord=10, noise=0.0001
   1104     )
   1105 elif isinstance(init, str) and init == "spectral":
-> 1106     embedding = spectral_layout(
   1107         data,
   1108         graph,
   1109         n_components,
   1110         random_state,
   1111         metric=metric,
   1112         metric_kwds=metric_kwds,
   1113     )
   1114     # We add a little noise to avoid local minima for optimization to come
   1115     embedding = noisy_scale_coords(
   1116         embedding, random_state, max_coord=10, noise=0.0001
   1117     )

File c:\Users\[path]\lib\site-packages\umap\spectral.py:304, in spectral_layout(data, graph, dim, random_state, metric, metric_kwds, tol, maxiter)
    263 def spectral_layout(
    264     data,
    265     graph,
   (...)
    271     maxiter=0
    272 ):
    273     """
    274     Given a graph compute the spectral embedding of the graph. This is
    275     simply the eigenvectors of the laplacian of the graph. Here we use the
   (...)
    302         The spectral embedding of the graph.
    303     """
--> 304     return _spectral_layout(
    305         data=data,
    306         graph=graph,
    307         dim=dim,
    308         random_state=random_state,
    309         metric=metric,
    310         metric_kwds=metric_kwds,
    311         init="random",
    312         tol=tol,
    313         maxiter=maxiter
    314     )

File c:\Users\[path]\lib\site-packages\umap\spectral.py:521, in _spectral_layout(data, graph, dim, random_state, metric, metric_kwds, init, method, tol, maxiter)
    518 X[:, 0] = sqrt_deg / np.linalg.norm(sqrt_deg)
    520 if method == "eigsh":
--> 521     eigenvalues, eigenvectors = scipy.sparse.linalg.eigsh(
    522         L,
    523         k,
    524         which="SM",
    525         ncv=num_lanczos_vectors,
    526         tol=tol or 1e-4,
    527         v0=np.ones(L.shape[0]),
    528         maxiter=maxiter or graph.shape[0] * 5,
    529     )
    530 elif method == "lobpcg":
    531     with warnings.catch_warnings():

File c:\Users\[path]\lib\site-packages\scipy\sparse\linalg\_eigen\arpack\arpack.py:1608, in eigsh(A, k, M, sigma, which, v0, ncv, maxiter, tol, return_eigenvectors, Minv, OPinv, mode)
   1603 warnings.warn("k >= N for N * N square matrix. "
   1604               "Attempting to use scipy.linalg.eigh instead.",
   1605               RuntimeWarning, stacklevel=2)
   1607 if issparse(A):
-> 1608     raise TypeError("Cannot use scipy.linalg.eigh for sparse A with "
   1609                     "k >= N. Use scipy.linalg.eigh(A.toarray()) or"
   1610                     " reduce k.")
   1611 if isinstance(A, LinearOperator):
   1612     raise TypeError("Cannot use scipy.linalg.eigh for LinearOperator "
   1613                     "A with k >= N.")

TypeError: Cannot use scipy.linalg.eigh for sparse A with k >= N. Use scipy.linalg.eigh(A.toarray()) or reduce k.

The cutoff for the above error seems to be around 4-5 topics, including outliers.
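The traceback ends in scipy's eigsh refusing a sparse matrix when k >= N. A minimal sketch of that condition, assuming (as in recent umap-learn releases, not confirmed by the traceback itself) that the spectral initialisation requests k = n_components + 1 eigenvectors of an N x N graph Laplacian, where N is the number of topic embeddings passed to UMAP. spectral_init_would_fail is a hypothetical helper.

```python
def spectral_init_would_fail(n_points: int, n_components: int = 2) -> bool:
    """Predict the 'k >= N' TypeError raised by eigsh in the traceback above.

    Assumption: umap-learn's spectral layout asks eigsh for
    k = n_components + 1 eigenvectors of an N x N sparse Laplacian.
    """
    k = n_components + 1
    return k >= n_points


for n in range(2, 7):
    # True for N <= 3 when n_components=2, under the assumption above.
    print(n, spectral_init_would_fail(n))
```

This models only that one branch. Whether N includes the outlier topic depends on how visualize_topics selects topics, which is not shown here. Since the traceback shows spectral_layout is only reached when init == "spectral", passing init="random" to UMAP might sidestep this particular error, though that is untested here.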


Once there are 5 or more topics and UMAP actually runs, its success does not seem to be sensitive to n_neighbors.

Here are 4 topics (excluding outliers) with n_neighbors = 2:

[Image: visualize_topics output, 4 topics, n_neighbors=2]

and with n_neighbors = 15:

[Image: visualize_topics output, 4 topics, n_neighbors=15]

It does change the distance relationships between topics when there are only a few of them (which may or may not be desirable at that scale), so if you felt it would be better, you could use something like n_neighbors = min(15, len(topic_list)) to control it for smaller topic models.
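A minimal sketch of the min(15, len(topic_list)) suggestion as a hypothetical helper. The floor of 2 is an added safeguard, because umap-learn raises a ValueError for n_neighbors < 2.

```python
def choose_n_neighbors(n_topics: int, cap: int = 15) -> int:
    """Hypothetical helper: n_neighbors = min(cap, n_topics), floored at 2.

    The floor exists because umap-learn rejects n_neighbors < 2.
    """
    return max(2, min(cap, n_topics))


print(choose_n_neighbors(4))   # → 4
print(choose_n_neighbors(40))  # → 15
```

In _topics.py this would slot into the existing call, e.g. UMAP(n_neighbors=choose_n_neighbors(len(embeddings)), n_components=2, metric='cosine', random_state=42). It only changes behaviour below the cap, so larger topic models get a fixed value of 15.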

MaartenGr commented 5 months ago

Thanks for running the experiments!

Hmmm, I'm not sure what the best solution would be here, but I think I like yours the most: as long as the number of topics is below a certain value (like 15), we use the number of topics we actually have.

It might be even better to expose the UMAP model as a parameter. That way, users could also choose the parameters themselves, as long as the model creates 2D embeddings.

What do you think?
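A sketch of the "pass the model in" idea, under stated assumptions: reduce_to_2d, reducer and default_factory are hypothetical names, not part of BERTopic. Any object with a scikit-learn style fit_transform (such as a configured UMAP instance) is accepted, and the only hard check is that the output is 2D.

```python
from typing import Any, Callable, List, Optional, Sequence

Matrix = List[List[float]]


def reduce_to_2d(
    embeddings: Sequence[Sequence[float]],
    reducer: Optional[Any] = None,
    default_factory: Optional[Callable[[], Any]] = None,
) -> Matrix:
    """Hypothetical helper: reduce topic embeddings with a user-supplied model.

    Falls back to default_factory() when no reducer is given, then checks
    that every output row has exactly two coordinates.
    """
    if reducer is None:
        if default_factory is None:
            raise ValueError("Provide a reducer or a default_factory.")
        reducer = default_factory()
    reduced = [list(map(float, row)) for row in reducer.fit_transform(embeddings)]
    if any(len(row) != 2 for row in reduced):
        raise ValueError("The reducer must return 2D embeddings (n_components=2).")
    return reduced


class FirstTwoColumns:
    """Toy stand-in for a real reducer: keeps the first two coordinates."""

    def fit_transform(self, X):
        return [row[:2] for row in X]


print(reduce_to_2d([[1, 2, 3], [4, 5, 6]], reducer=FirstTwoColumns()))
```

This mirrors how BERTopic's constructor already accepts a umap_model for document embeddings, and it would also allow non-UMAP reducers, at the cost of validating their output after the fact.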

zilch42 commented 5 months ago

Yes, I would be happy with exposing the UMAP parameters, maybe through a dict like UMAP_kwargs. I considered suggesting that, but I know you try to keep the parameter space small :)
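A minimal sketch of the dict-based variant, assuming a hypothetical umap_kwargs parameter on visualize_topics. The defaults are copied from the current call on line 79 of _topics.py; user values override them, and n_components is pinned to 2 because the plot is 2D.

```python
from typing import Any, Dict, Optional

# Defaults taken from the existing UMAP call in bertopic/plotting/_topics.py.
DEFAULT_UMAP_KWARGS: Dict[str, Any] = {
    "n_neighbors": 2,
    "n_components": 2,
    "metric": "cosine",
    "random_state": 42,
}


def build_umap_kwargs(umap_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge hypothetical user-supplied umap_kwargs over the defaults."""
    merged = {**DEFAULT_UMAP_KWARGS, **(umap_kwargs or {})}
    if merged["n_components"] != 2:
        raise ValueError("visualize_topics plots in 2D, so n_components must be 2.")
    return merged


print(build_umap_kwargs({"n_neighbors": 15}))
# → {'n_neighbors': 15, 'n_components': 2, 'metric': 'cosine', 'random_state': 42}
```

The merged dict would then be used as UMAP(**build_umap_kwargs(umap_kwargs)). Compared with accepting a whole model object, a dict makes the 2D constraint easy to enforce up front but ties the visualization to UMAP.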

MaartenGr commented 5 months ago

Thanks for taking the parameter space into account! For .visualize_topics, I think we can add such a parameter, since it does not directly affect the parameters most people use, namely those in __init__. This is especially true because the 2D reduction is currently chosen for you, and some flexibility would be welcome.