MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License
6.2k stars 767 forks source link

partial_fit example code does not work #2097

Open IsaacGreenMachine opened 4 months ago

IsaacGreenMachine commented 4 months ago

Have you searched existing issues? 🔎

Describe the bug

Running the `partial_fit` example code from the documentation throws an error.

Thanks for the help!

Stack trace:

ValueError                                Traceback (most recent call last)
Cell In[10], line 21
     19 # Incrementally fit the topic model by training on 1000 documents at a time
     20 for index in range(0, len(docs), 1000):
---> 21     topic_model.partial_fit(docs[index: index+1000])

File ~/.venv/lib/python3.12/site-packages/bertopic/_bertopic.py:725, in BERTopic.partial_fit(self, documents, embeddings, y)
    722 umap_embeddings = self._reduce_dimensionality(embeddings, y, partial_fit=True)
    724 # Cluster reduced embeddings
--> 725 documents, self.probabilities_ = self._cluster_embeddings(umap_embeddings, documents, partial_fit=True)
    726 topics = documents.Topic.to_list()
    728 # Map and find new topics

File ~/.venv/lib/python3.12/site-packages/bertopic/_bertopic.py:3771, in BERTopic._cluster_embeddings(self, umap_embeddings, documents, partial_fit, y)
   3769 logger.info("Cluster - Start clustering the reduced embeddings")
   3770 if partial_fit:
-> 3771     self.hdbscan_model = self.hdbscan_model.partial_fit(umap_embeddings)
   3772     labels = self.hdbscan_model.labels_
   3773     documents["Topic"] = labels

File ~/.venv/lib/python3.12/site-packages/sklearn/base.py:1473, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1466     estimator._validate_params()
   1468 with config_context(
   1469     skip_parameter_validation=(
   1470         prefer_skip_nested_validation or global_skip_validation
   1471     )
   1472 ):
-> 1473     return fit_method(estimator, *args, **kwargs)

File ~/.venv/lib/python3.12/site-packages/sklearn/cluster/_kmeans.py:2277, in MiniBatchKMeans.partial_fit(self, X, y, sample_weight)
   2274     self._n_since_last_reassign = 0
   2276 with _get_threadpool_controller().limit(limits=1, user_api="blas"):
-> 2277     _mini_batch_step(
   2278         X,
   2279         sample_weight=sample_weight,
   2280         centers=self.cluster_centers_,
   2281         centers_new=self.cluster_centers_,
   2282         weight_sums=self._counts,
   2283         random_state=self._random_state,
   2284         random_reassign=self._random_reassign(),
   2285         reassignment_ratio=self.reassignment_ratio,
   2286         verbose=self.verbose,
   2287         n_threads=self._n_threads,
   2288     )
   2290 if self.compute_labels:
   2291     self.labels_, self.inertia_ = _labels_inertia_threadpool_limit(
   2292         X,
   2293         sample_weight,
   2294         self.cluster_centers_,
   2295         n_threads=self._n_threads,
   2296     )

File ~/.venv/lib/python3.12/site-packages/sklearn/cluster/_kmeans.py:1633, in _mini_batch_step(X, sample_weight, centers, centers_new, weight_sums, random_state, random_reassign, reassignment_ratio, verbose, n_threads)
   1578 """Incremental update of the centers for the Minibatch K-Means algorithm.
   1579 
   1580 Parameters
   (...)
   1628     the centers.
   1629 """
   1630 # Perform label assignment to nearest centers
   1631 # For better efficiency, it's better to run _mini_batch_step in a
   1632 # threadpool_limit context than using _labels_inertia_threadpool_limit here
-> 1633 labels, inertia = _labels_inertia(X, sample_weight, centers, n_threads=n_threads)
   1635 # Update centers according to the labels
   1636 if sp.issparse(X):

File ~/.venv/lib/python3.12/site-packages/sklearn/cluster/_kmeans.py:813, in _labels_inertia(X, sample_weight, centers, n_threads, return_inertia)
    810     _labels = lloyd_iter_chunked_dense
    811     _inertia = _inertia_dense
--> 813 _labels(
    814     X,
    815     sample_weight,
    816     centers,
    817     centers_new=None,
    818     weight_in_clusters=None,
    819     labels=labels,
    820     center_shift=center_shift,
    821     n_threads=n_threads,
    822     update_centers=False,
    823 )
    825 if return_inertia:
    826     inertia = _inertia(X, sample_weight, centers, labels, n_threads)

File _k_means_lloyd.pyx:26, in sklearn.cluster._k_means_lloyd.lloyd_iter_chunked_dense()

ValueError: Buffer dtype mismatch, expected 'const double' but got 'float'

Reproduction

Copied and pasted from the example code here:

I'm on an M3 MacBook Pro with Python 3.12.4, scikit-learn 1.5.1, bertopic 0.16.3, numpy 1.26.4, and scipy 1.14.0.

# Reproduction of BERTopic's online-learning example: train a topic model
# incrementally with `partial_fit`, using sub-models that expose partial_fit.
#
# NOTE(review): on scikit-learn 1.5.1 (and, per a later comment, 1.5.2) the
# loop at the bottom raised
#   ValueError: Buffer dtype mismatch, expected 'const double' but got 'float'
# from inside MiniBatchKMeans.partial_fit. Commenters report that downgrading
# scikit-learn to 1.4.2 avoids it. The message suggests a float32/float64
# mismatch between the reduced embeddings and the k-means cluster centers.
# That is unconfirmed; TODO verify before relying on a dtype-cast workaround.
from sklearn.datasets import fetch_20newsgroups
from sklearn.cluster import MiniBatchKMeans
from sklearn.decomposition import IncrementalPCA
from bertopic.vectorizers import OnlineCountVectorizer
from bertopic import BERTopic

# Prepare documents: the full ("all") 20 Newsgroups corpus, with headers,
# footers and quoted replies stripped from each post.
docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))["data"]

# Prepare sub-models that support online learning
# IncrementalPCA occupies BERTopic's `umap_model` slot. Per the traceback,
# BERTopic.partial_fit reduces embeddings with it via
# `_reduce_dimensionality(..., partial_fit=True)`.
umap_model = IncrementalPCA(n_components=5)
# MiniBatchKMeans occupies the `hdbscan_model` slot. BERTopic calls its
# `partial_fit` on the reduced embeddings and reads `labels_` afterwards.
# This call is where the reported ValueError is raised.
cluster_model = MiniBatchKMeans(n_clusters=50, random_state=0)
vectorizer_model = OnlineCountVectorizer(stop_words="english", decay=.01)

topic_model = BERTopic(umap_model=umap_model,
                       hdbscan_model=cluster_model,
                       vectorizer_model=vectorizer_model)

# Incrementally fit the topic model by training on 1000 documents at a time
# (the final slice holds whatever documents remain, possibly fewer than 1000).
for index in range(0, len(docs), 1000):
    topic_model.partial_fit(docs[index: index+1000])

BERTopic Version

0.16.3

MaartenGr commented 4 months ago

Hmm, I have seen this issue with a recent scikit-learn update, but it seems there isn't a fix yet. You could try the solution suggested here, perhaps applying it after each partial fit, to see whether that helps.

bbrk24 commented 2 months ago

The workaround in the linked issue didn't work for me, but what did work was downgrading scikit-learn to 1.4.2.

bazanov-aleksey commented 1 week ago

> The workaround in the linked issue didn't work for me, but what did work was downgrading scikit-learn to 1.4.2.

Thank you, it works with 1.4.2!!! (I was previously using scikit-learn 1.5.2.)

I'm on Ubuntu 22.04 with Python 3.11.0, bertopic 0.16.4, numpy 2.0.2, and scipy 1.14.1.