MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License
5.99k stars 752 forks source link

BERTopic with Sklearn Gaussian Mixture Model #1312

Closed drmwnrafi closed 1 year ago

drmwnrafi commented 1 year ago

I'm trying to use GMM with BERTopic for clustering. GMM uses the sklearn library. Here's the code:

from sentence_transformers import SentenceTransformer

sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
new_embedd = sentence_model.encode(tokopedia['final'])

from sklearn.decomposition import PCA
gmm = GaussianMixture(n_components=3)
dim_model = PCA(n_components=2)
topic_model_gmm = BERTopic(embedding_model=sentence_model, hdbscan_model=gmm, umap_model=dim_model)

And this the error: TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'

but when using KMeans it works fine

MaartenGr commented 1 year ago

Could you share the full error log? That will help to understand what the exact issue is.

drmwnrafi commented 1 year ago

Sorry, I'm new in this field. I'm confused about the error log, do you mean this?

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-0d41ca5531fe> in <cell line: 1>()
----> 1 topics, probs = topic_model_gmm.fit_transform(tokopedia['final'])

9 frames
/usr/local/lib/python3.10/dist-packages/bertopic/_bertopic.py in fit_transform(self, documents, embeddings, images, y)
    387 
    388         # Cluster reduced embeddings
--> 389         documents, probabilities = self._cluster_embeddings(umap_embeddings, documents, y=y)
    390 
    391         # Sort and Map Topic IDs by their frequency

/usr/local/lib/python3.10/dist-packages/bertopic/_bertopic.py in _cluster_embeddings(self, umap_embeddings, documents, partial_fit, y)
   3225                 labels = y
   3226             documents['Topic'] = labels
-> 3227             self._update_topic_size(documents)
   3228 
   3229         # Some algorithms have outlier labels (-1) that can be tricky to work

/usr/local/lib/python3.10/dist-packages/bertopic/_bertopic.py in _update_topic_size(self, documents)
   3518         """
   3519         self.topic_sizes_ = collections.Counter(documents.Topic.values.tolist())
-> 3520         self.topics_ = documents.Topic.astype(int).tolist()
   3521 
   3522     def _extract_words_per_topic(self,

/usr/local/lib/python3.10/dist-packages/pandas/core/generic.py in astype(self, dtype, copy, errors)
   6238         else:
   6239             # else, only a single dtype is given
-> 6240             new_data = self._mgr.astype(dtype=dtype, copy=copy, errors=errors)
   6241             return self.constructor(new_data).finalize_(self, method="astype")
   6242 

/usr/local/lib/python3.10/dist-packages/pandas/core/internals/managers.py in astype(self, dtype, copy, errors)
    446 
    447     def astype(self: T, dtype, copy: bool = False, errors: str = "raise") -> T:
--> 448         return self.apply("astype", dtype=dtype, copy=copy, errors=errors)
    449 
    450     def convert(

/usr/local/lib/python3.10/dist-packages/pandas/core/internals/managers.py in apply(self, f, align_keys, ignore_failures, **kwargs)
    350                     applied = b.apply(f, **kwargs)
    351                 else:
--> 352                     applied = getattr(b, f)(**kwargs)
    353             except (TypeError, NotImplementedError):
    354                 if not ignore_failures:

/usr/local/lib/python3.10/dist-packages/pandas/core/internals/blocks.py in astype(self, dtype, copy, errors)
    524         values = self.values
    525 
--> 526         new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)
    527 
    528         new_values = maybe_coerce_values(new_values)

/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py in astype_array_safe(values, dtype, copy, errors)
    297 
    298     try:
--> 299         new_values = astype_array(values, dtype, copy=copy)
    300     except (ValueError, TypeError):
    301         # e.g. astype_nansafe can fail on object-dtype of strings

/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py in astype_array(values, dtype, copy)
    228 
    229     else:
--> 230         values = astype_nansafe(values, dtype, copy=copy)
    231 
    232     # in pandas we don't store numpy str dtypes, so convert to object

/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py in astype_nansafe(arr, dtype, copy, skipna)
    168     if copy or is_object_dtype(arr.dtype) or is_object_dtype(dtype):
    169         # Explicit copy, or required since NumPy can't view from / to object.
--> 170         return arr.astype(dtype, copy=True)
    171 
    172     return arr.astype(dtype, copy=copy)

TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'

This is the full code :

from sentence_transformers import SentenceTransformer

sentence_model = SentenceTransformer("all-mpnet-base-v2")
new_embedd = sentence_model.encode(tokopedia['final'])

from sklearn.decomposition import PCA

pca = PCA(n_components=2)
PCA = pca.fit(new_embedd)
embeddings_pca = pca.transform(new_embedd)

from sklearn.metrics import silhouette_score

S=[]

K=range(2,12)

for k in K:
    model = GaussianMixture(n_components=k)
    labels = model.fit_predict(embeddings_pca)
    S.append(silhouette_score(embeddings_pca, labels))

plt.figure(figsize=(16,8), dpi=300)
plt.plot(K, S, 'bo-', color='black')
plt.xlabel('k')
plt.ylabel('Silhouette Score')
plt.title('Identify the number of clusters using Silhouette Score')
plt.show()

from sklearn.decomposition import PCA

gmm = GaussianMixture(n_components=3)
dim_model = PCA(n_components=2)
topic_model_gmm = BERTopic(embedding_model=sentence_model, hdbscan_model=gmm, umap_model=dim_model)

topics, probs = topic_model_gmm.fit_transform(tokopedia['final'])

Thank you,

MaartenGr commented 1 year ago

I believe that your model generates an empty .labels_ attribute which is necessary for BERTopic to use the underlying model. Perhaps you can adjust the model such that it does so:

class ClusterModel:
    def __init__(self, model):
        self.model= model

    def fit(self, X):
        return self

    def predict(self, X):
        predictions = self.model.predict(X)
        self.labels_ = predictions
        return predictions

cluster_model = ClusterModel(GaussianMixture(n_components=3))
topic_model = BERTopic(hdbscan_model=cluster_model)

Not sure whether this is correct (haven't tested this out) but it might work.

drmwnrafi commented 1 year ago

I've tried your suggestion, but still the same error

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

[<ipython-input-209-0d41ca5531fe>](https://localhost:8080/#) in <cell line: 1>()
----> 1 topics, probs = topic_model_gmm.fit_transform(tokopedia['final'])

9 frames

[/usr/local/lib/python3.10/dist-packages/bertopic/_bertopic.py](https://localhost:8080/#) in fit_transform(self, documents, embeddings, images, y)
    387 
    388         # Cluster reduced embeddings
--> 389         documents, probabilities = self._cluster_embeddings(umap_embeddings, documents, y=y)
    390 
    391         # Sort and Map Topic IDs by their frequency

[/usr/local/lib/python3.10/dist-packages/bertopic/_bertopic.py](https://localhost:8080/#) in _cluster_embeddings(self, umap_embeddings, documents, partial_fit, y)
   3225                 labels = y
   3226             documents['Topic'] = labels
-> 3227             self._update_topic_size(documents)
   3228 
   3229         # Some algorithms have outlier labels (-1) that can be tricky to work

[/usr/local/lib/python3.10/dist-packages/bertopic/_bertopic.py](https://localhost:8080/#) in _update_topic_size(self, documents)
   3518         """
   3519         self.topic_sizes_ = collections.Counter(documents.Topic.values.tolist())
-> 3520         self.topics_ = documents.Topic.astype(int).tolist()
   3521 
   3522     def _extract_words_per_topic(self,

[/usr/local/lib/python3.10/dist-packages/pandas/core/generic.py](https://localhost:8080/#) in astype(self, dtype, copy, errors)
   6238         else:
   6239             # else, only a single dtype is given
-> 6240             new_data = self._mgr.astype(dtype=dtype, copy=copy, errors=errors)
   6241             return self._constructor(new_data).__finalize__(self, method="astype")
   6242 

[/usr/local/lib/python3.10/dist-packages/pandas/core/internals/managers.py](https://localhost:8080/#) in astype(self, dtype, copy, errors)
    446 
    447     def astype(self: T, dtype, copy: bool = False, errors: str = "raise") -> T:
--> 448         return self.apply("astype", dtype=dtype, copy=copy, errors=errors)
    449 
    450     def convert(

[/usr/local/lib/python3.10/dist-packages/pandas/core/internals/managers.py](https://localhost:8080/#) in apply(self, f, align_keys, ignore_failures, **kwargs)
    350                     applied = b.apply(f, **kwargs)
    351                 else:
--> 352                     applied = getattr(b, f)(**kwargs)
    353             except (TypeError, NotImplementedError):
    354                 if not ignore_failures:

[/usr/local/lib/python3.10/dist-packages/pandas/core/internals/blocks.py](https://localhost:8080/#) in astype(self, dtype, copy, errors)
    524         values = self.values
    525 
--> 526         new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)
    527 
    528         new_values = maybe_coerce_values(new_values)

[/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py](https://localhost:8080/#) in astype_array_safe(values, dtype, copy, errors)
    297 
    298     try:
--> 299         new_values = astype_array(values, dtype, copy=copy)
    300     except (ValueError, TypeError):
    301         # e.g. astype_nansafe can fail on object-dtype of strings

[/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py](https://localhost:8080/#) in astype_array(values, dtype, copy)
    228 
    229     else:
--> 230         values = astype_nansafe(values, dtype, copy=copy)
    231 
    232     # in pandas we don't store numpy str dtypes, so convert to object

[/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py](https://localhost:8080/#) in astype_nansafe(arr, dtype, copy, skipna)
    168     if copy or is_object_dtype(arr.dtype) or is_object_dtype(dtype):
    169         # Explicit copy, or required since NumPy can't view from / to object.
--> 170         return arr.astype(dtype, copy=True)
    171 
    172     return arr.astype(dtype, copy=copy)

TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'

This the full code:

from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer

class ClusterModel:
    def __init__(self, model):
        self.model= model

    def fit(self, X):
        return self

    def predict(self, X):
        predictions = self.model.predict(X)
        self.labels_ = predictions
        return predictions

sentence_model = SentenceTransformer("all-mpnet-base-v2")

gmm = ClusterModel(GaussianMixture(n_components=3, random_state=random_state))
dim_model = PCA(n_components=2)
topic_model_gmm = BERTopic(embedding_model=sentence_model, hdbscan_model=gmm, umap_model=dim_model)
drmwnrafi commented 1 year ago

This code can run without error, but i don't know the code correct or not :

from sklearn.decomposition import PCA

class ClusterModel:
    def __init__(self, model):
        self.model = model

    def fit(self, X, y=None):
        self.model.fit(X)
        self.labels_ = self.model.predict(X)
        return self

    def predict(self, X):
        predictions = self.model.predict(X)
        self.labels_ = predictions
        return predictions

gmm = ClusterModel(GaussianMixture(n_components=3, random_state=random_state))
dim_model = PCA(n_components=2)
topic_model_gmm = BERTopic(language='indonesian', embedding_model=sentence_model, hdbscan_model=gmm, umap_model=dim_model)
MaartenGr commented 1 year ago

Ah right, I forgot to update the .fit function. It seems correct.