zilch42 opened this issue 6 months ago
It was a while ago, but I remember that a lower value was necessary when you have few topics. That is the risk of tuning a parameter on a single dataset: it might "overfit" to the current situation. The difficulty here is finding a value that works with any number of topics. In other words, does your solution also work if you have only 3 topics?
Hi Maarten, ah, that sounds plausible. I've started testing; however, it seems that visualize_topics doesn't work at all with low numbers of topics, regardless of n_neighbors.
```python
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

# Small test corpus: the first 5,000 posts of 20 Newsgroups
docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'][0:5000]
# Reduce the model to only a few topics
topic_model = BERTopic(nr_topics=4)
topics, probs = topic_model.fit_transform(docs)
topic_model.get_topic_info()
# Fails with the TypeError shown below
topic_model.visualize_topics()
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 topic_model.visualize_topics()
File c:\Users\[path]\lib\site-packages\bertopic\_bertopic.py:2249, in BERTopic.visualize_topics(self, topics, top_n_topics, custom_labels, title, width, height)
2216 """ Visualize topics, their sizes, and their corresponding words
2217
2218 This visualization is highly inspired by LDAvis, a great visualization
(...)
2246 ```
2247 """
2248 check_is_fitted(self)
-> 2249 return plotting.visualize_topics(self,
2250 topics=topics,
2251 top_n_topics=top_n_topics,
2252 custom_labels=custom_labels,
2253 title=title,
2254 width=width,
2255 height=height)
File c:\Users\[path]\lib\site-packages\bertopic\plotting\_topics.py:79, in visualize_topics(topic_model, topics, top_n_topics, custom_labels, title, width, height)
77 if topic_model.topic_embeddings_ is not None:
78 embeddings = topic_model.topic_embeddings_[indices]
---> 79 embeddings = UMAP(n_neighbors=2, n_components=2, metric='cosine', random_state=42).fit_transform(embeddings)
80 else:
81 embeddings = topic_model.c_tf_idf_.toarray()[indices]
File c:\Users\[path]\lib\site-packages\umap\umap_.py:2887, in UMAP.fit_transform(self, X, y, force_all_finite)
2851 def fit_transform(self, X, y=None, force_all_finite=True):
2852 """Fit X into an embedded space and return that transformed
2853 output.
2854
(...)
2885 Local radii of data points in the embedding (log-transformed).
2886 """
-> 2887 self.fit(X, y, force_all_finite)
2888 if self.transform_mode == "embedding":
2889 if self.output_dens:
File c:\Users\[path]\lib\site-packages\umap\umap_.py:2780, in UMAP.fit(self, X, y, force_all_finite)
2776 if self.transform_mode == "embedding":
2777 epochs = (
2778 self.n_epochs_list if self.n_epochs_list is not None else self.n_epochs
2779 )
-> 2780 self.embedding_, aux_data = self._fit_embed_data(
2781 self._raw_data[index],
2782 epochs,
2783 init,
2784 random_state, # JH why raw data?
2785 )
2787 if self.n_epochs_list is not None:
2788 if "embedding_list" not in aux_data:
File c:\Users\[path]\lib\site-packages\umap\umap_.py:2826, in UMAP._fit_embed_data(self, X, n_epochs, init, random_state)
2822 def _fit_embed_data(self, X, n_epochs, init, random_state):
2823 """A method wrapper for simplicial_set_embedding that can be
2824 replaced by subclasses.
2825 """
-> 2826 return simplicial_set_embedding(
2827 X,
2828 self.graph_,
2829 self.n_components,
2830 self._initial_alpha,
2831 self._a,
2832 self._b,
2833 self.repulsion_strength,
2834 self.negative_sample_rate,
2835 n_epochs,
2836 init,
2837 random_state,
2838 self._input_distance_func,
2839 self._metric_kwds,
2840 self.densmap,
2841 self._densmap_kwds,
2842 self.output_dens,
2843 self._output_distance_func,
2844 self._output_metric_kwds,
2845 self.output_metric in ("euclidean", "l2"),
2846 self.random_state is None,
2847 self.verbose,
2848 tqdm_kwds=self.tqdm_kwds,
2849 )
File c:\Users\[path]\lib\site-packages\umap\umap_.py:1106, in simplicial_set_embedding(data, graph, n_components, initial_alpha, a, b, gamma, negative_sample_rate, n_epochs, init, random_state, metric, metric_kwds, densmap, densmap_kwds, output_dens, output_metric, output_metric_kwds, euclidean_output, parallel, verbose, tqdm_kwds)
1102 embedding = noisy_scale_coords(
1103 embedding, random_state, max_coord=10, noise=0.0001
1104 )
1105 elif isinstance(init, str) and init == "spectral":
-> 1106 embedding = spectral_layout(
1107 data,
1108 graph,
1109 n_components,
1110 random_state,
1111 metric=metric,
1112 metric_kwds=metric_kwds,
1113 )
1114 # We add a little noise to avoid local minima for optimization to come
1115 embedding = noisy_scale_coords(
1116 embedding, random_state, max_coord=10, noise=0.0001
1117 )
File c:\Users\[path]\lib\site-packages\umap\spectral.py:304, in spectral_layout(data, graph, dim, random_state, metric, metric_kwds, tol, maxiter)
263 def spectral_layout(
264 data,
265 graph,
(...)
271 maxiter=0
272 ):
273 """
274 Given a graph compute the spectral embedding of the graph. This is
275 simply the eigenvectors of the laplacian of the graph. Here we use the
(...)
302 The spectral embedding of the graph.
303 """
--> 304 return _spectral_layout(
305 data=data,
306 graph=graph,
307 dim=dim,
308 random_state=random_state,
309 metric=metric,
310 metric_kwds=metric_kwds,
311 init="random",
312 tol=tol,
313 maxiter=maxiter
314 )
File c:\Users\[path]\lib\site-packages\umap\spectral.py:521, in _spectral_layout(data, graph, dim, random_state, metric, metric_kwds, init, method, tol, maxiter)
518 X[:, 0] = sqrt_deg / np.linalg.norm(sqrt_deg)
520 if method == "eigsh":
--> 521 eigenvalues, eigenvectors = scipy.sparse.linalg.eigsh(
522 L,
523 k,
524 which="SM",
525 ncv=num_lanczos_vectors,
526 tol=tol or 1e-4,
527 v0=np.ones(L.shape[0]),
528 maxiter=maxiter or graph.shape[0] * 5,
529 )
530 elif method == "lobpcg":
531 with warnings.catch_warnings():
File c:\Users\[path]\lib\site-packages\scipy\sparse\linalg\_eigen\arpack\arpack.py:1608, in eigsh(A, k, M, sigma, which, v0, ncv, maxiter, tol, return_eigenvectors, Minv, OPinv, mode)
1603 warnings.warn("k >= N for N * N square matrix. "
1604 "Attempting to use scipy.linalg.eigh instead.",
1605 RuntimeWarning, stacklevel=2)
1607 if issparse(A):
-> 1608 raise TypeError("Cannot use scipy.linalg.eigh for sparse A with "
1609 "k >= N. Use scipy.linalg.eigh(A.toarray()) or"
1610 " reduce k.")
1611 if isinstance(A, LinearOperator):
1612 raise TypeError("Cannot use scipy.linalg.eigh for LinearOperator "
1613 "A with k >= N.")
TypeError: Cannot use scipy.linalg.eigh for sparse A with k >= N. Use scipy.linalg.eigh(A.toarray()) or reduce k.
The cutoff for the above error seems to be around 4-5 topics, including outliers.
Once there are 5 or more topics and UMAP actually runs, whether it succeeds does not seem to depend on n_neighbors.
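The cutoff can be read off the traceback: UMAP's spectral initialisation calls scipy.sparse.linalg.eigsh on an N x N graph Laplacian, where N is the number of topics being embedded, and eigsh raises a TypeError for sparse input whenever k >= N. The sketch below encodes that arithmetic. It assumes (from umap-learn's spectral.py, not from this thread) that the number of requested eigenvectors is k = n_components + 1; the helper name spectral_init_possible is hypothetical.

```python
def spectral_init_possible(n_points: int, n_components: int = 2) -> bool:
    """Return True if UMAP's spectral init can call eigsh on n_points rows.

    Assumption: umap-learn requests k = n_components + 1 eigenvectors of an
    n_points x n_points sparse Laplacian, and scipy's eigsh raises a
    TypeError for sparse input whenever k >= n_points.
    """
    k = n_components + 1
    return k < n_points


# For the 2D topic map (n_components=2), k = 3, so at least 4 embedded topics are needed
for n in range(2, 7):
    print(n, spectral_init_possible(n))
```

Depending on whether the outlier topic is one of the embedded rows, this would line up with the observed cutoff of about 4-5 topics.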
Here are 4 topics (excluding outliers) with n_neighbors = 2 and with n_neighbors = 15:
[Figure: topic maps for n_neighbors = 2 and n_neighbors = 15]
It does change the distances between topics when there are only a few of them (which may or may not be desirable at that scale), so you could use something like n_neighbors = min(15, len(topic_list)) to adapt it for smaller topic models if you felt that would be better.
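A sketch of the min(15, len(topic_list)) idea as a hypothetical helper, topic_map_n_neighbors. It additionally clamps the result to the range [2, n_topics - 1]. That clamp is an assumption based on umap-learn's behaviour (it rejects n_neighbors < 2 and, in recent versions, warns and truncates n_neighbors to N - 1 when it is not smaller than the number of points N), not something tested in this thread.

```python
def topic_map_n_neighbors(n_topics: int, upper: int = 15) -> int:
    """Pick n_neighbors for the 2D topic map.

    Follows the proposal n_neighbors = min(15, n_topics), but clamps the
    value to [2, n_topics - 1] (assumption: UMAP requires n_neighbors >= 2
    and would otherwise truncate it to n_topics - 1 with a warning).
    """
    return max(2, min(upper, n_topics - 1))


for n_topics in (3, 8, 15, 50):
    print(n_topics, topic_map_n_neighbors(n_topics))
```

The upper clamp only makes explicit what UMAP would do anyway; plain min(15, n_topics) should behave the same apart from the warning.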
Thanks for running the experiments!
Hmm, I'm not sure what the best solution would be here, but I think I like your solution the most: as long as the number of topics is below a certain value (like 15), we use the number of topics we actually have.
It might be even better to expose the UMAP model as a parameter. That way, users can also choose the parameters themselves, as long as they create 2D embeddings.
What do you think?
Yes, I would be happy with exposing the UMAP parameters, maybe through a dict like UMAP_kwargs or something similar. I considered suggesting that, but I know you try to keep the parameter space small :)
Thanks for taking the parameter space into account! For .visualize_topics, I think we can add such a parameter, since it does not touch the parameters most people use, namely those in __init__. This is especially reasonable because the choice of a 2D representation is already made for you, so some flexibility would be welcome.
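A minimal sketch of the UMAP_kwargs idea, assuming a hypothetical umap_kwargs dict passed to visualize_topics (not an existing BERTopic parameter in the version discussed here). The defaults mirror the UMAP call visible in the traceback (n_neighbors=2, n_components=2, metric='cosine', random_state=42); user values override them, and n_components is checked because the plot is 2D. The names merge_umap_kwargs and DEFAULT_TOPIC_UMAP_KWARGS are illustrative.

```python
from typing import Any, Dict, Optional

# Defaults copied from the UMAP(...) call in bertopic/plotting/_topics.py (see traceback)
DEFAULT_TOPIC_UMAP_KWARGS: Dict[str, Any] = {
    "n_neighbors": 2,
    "n_components": 2,
    "metric": "cosine",
    "random_state": 42,
}


def merge_umap_kwargs(umap_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge hypothetical user-supplied UMAP kwargs over the defaults.

    The result would be passed on as UMAP(**params); n_components must stay 2
    because visualize_topics draws a 2D map.
    """
    # Build a fresh dict so the module-level defaults are never mutated
    params = {**DEFAULT_TOPIC_UMAP_KWARGS, **(umap_kwargs or {})}
    if params["n_components"] != 2:
        raise ValueError("visualize_topics needs 2D embeddings (n_components=2)")
    return params


print(merge_umap_kwargs({"n_neighbors": 15}))
```

Merging into a new dict keeps the defaults stable across calls, so one user's overrides never leak into the next plot.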
Hi Maarten,
Just wondering why the UMAP model in visualize_topics uses n_neighbors=2 rather than something like 15? https://github.com/MaartenGr/BERTopic/blob/424cefc68ede08ff9f1c7e56ee6103c16c1429c6/bertopic/plotting/_topics.py#L79
A value of 2 results in many topics sitting directly on top of one another, whereas higher values spread the topics out better and look more like the spread you see when visualizing documents.
[Figure: topic map with n_neighbors=2]
[Figure: topic map with n_neighbors=15]
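One plausible reason for the overlap, offered as an illustration rather than a description of UMAP's internals: umap-learn counts each point as its own nearest neighbour, so n_neighbors=2 leaves every topic with a single real neighbour, to which it is then tied strongly. The sketch below (pure Python, Euclidean distances on made-up 1D coordinates, hypothetical helper names) counts the connected components of a symmetrised k-nearest-neighbour graph. With one neighbour per point the graph splits into isolated pairs, while a larger neighbourhood links everything, which presumably gives the layout more to anchor the topics relative to one another.

```python
from typing import List, Set


def knn_graph(points: List[float], k: int) -> List[Set[int]]:
    """Symmetrised k-nearest-neighbour graph (self excluded) on 1D points."""
    n = len(points)
    adj: List[Set[int]] = [set() for _ in range(n)]
    for i in range(n):
        # Sort the other points by distance to point i and keep the k closest
        others = sorted((abs(points[i] - points[j]), j) for j in range(n) if j != i)
        for _, j in others[:k]:
            adj[i].add(j)
            adj[j].add(i)  # symmetrise: an edge in either direction counts
    return adj


def count_components(adj: List[Set[int]]) -> int:
    """Count connected components with an iterative depth-first search."""
    seen: Set[int] = set()
    components = 0
    for start in range(len(adj)):
        if start in seen:
            continue
        components += 1
        stack = [start]
        seen.add(start)
        while stack:
            node = stack.pop()
            for nxt in adj[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
    return components


# Six hypothetical "topic" positions forming three tight pairs
topics = [0.0, 1.0, 10.0, 11.0, 20.0, 21.0]
# Assumption: UMAP's n_neighbors counts the point itself, so n_neighbors=2 ~ k=1 here
print(count_components(knn_graph(topics, k=1)))  # → 3 (three isolated pairs)
print(count_components(knn_graph(topics, k=5)))  # → 1 (every point linked to every other)
```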