MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License

Replacing Sub-models with already fitted models #1318

Closed daso94 closed 1 year ago

daso94 commented 1 year ago

Hello,

I was wondering whether it is possible to replace the sub-models (hdbscan_model, umap_model) of a topic_model with already fitted instances of the corresponding models. Let's say I want to switch between MiniBatchKMeans models with different numbers of clusters as the hdbscan_model. Right now, I always have to create a new instance of a topic model and fit the whole pipeline. Is it possible to just replace the sub-model?

Kind regards!

MaartenGr commented 1 year ago

There are a number of tricks to do this. Even with already fitted models, the input embeddings still need to be reduced in dimensionality, so that part of the pipeline still has to run. Having said that, you can skip the dimensionality reduction step by following this guide, which helps especially if you have pre-calculated the reduced embeddings. Moreover, you can wrap a fitted dimensionality reduction or clustering model in a small class that exposes the same methods but whose fit does nothing, so that the model is not fitted again.

For example:

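from bertopic import BERTopic


class DimensionalityReduction:
    """Wraps an already fitted dimensionality reduction model."""

    def __init__(self, my_fitted_dim_model):
        self.my_fitted_dim_model = my_fitted_dim_model

    def fit(self, X, y=None):
        # Already fitted, so do nothing. `y` is accepted because BERTopic may pass it.
        return self

    def transform(self, X):
        return self.my_fitted_dim_model.transform(X)


# my_fitted_dim_model: your already fitted model (e.g. UMAP or PCA)
dim_model = DimensionalityReduction(my_fitted_dim_model)
topic_model = BERTopic(umap_model=dim_model)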
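
As a quick sanity check of the wrapper, here is a minimal sketch that runs without BERTopic installed. ToyFittedReducer is a hypothetical stand-in for a real fitted model, and the fit_calls counter exists only to show that the wrapper never refits it.

```python
class ToyFittedReducer:
    """Hypothetical stand-in for a fitted reducer such as UMAP or PCA."""

    def __init__(self):
        self.fit_calls = 0
        self.offset = None

    def fit(self, X, y=None):
        self.fit_calls += 1
        # "Learn" the mean of the first coordinate.
        self.offset = sum(row[0] for row in X) / len(X)
        return self

    def transform(self, X):
        # Keep only the first coordinate, centred on the learned offset.
        return [[row[0] - self.offset] for row in X]


class DimensionalityReduction:
    """Same wrapper as above: fit is a no-op, transform delegates."""

    def __init__(self, my_fitted_dim_model):
        self.my_fitted_dim_model = my_fitted_dim_model

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return self.my_fitted_dim_model.transform(X)


# Fit the underlying model once, outside the topic model.
fitted = ToyFittedReducer().fit([[1.0, 5.0], [3.0, 7.0]])

wrapper = DimensionalityReduction(fitted)
new_data = [[10.0, 0.0], [20.0, 0.0]]
wrapper.fit(new_data)                  # no-op: learned state is untouched
reduced = wrapper.transform(new_data)  # uses the state learned earlier

print(fitted.fit_calls, fitted.offset, reduced)
```

The underlying model is fitted exactly once, and later fit calls on the wrapper leave its learned state alone.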

The same could be done for the clustering model:

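class ClusterModel:
    """Wraps an already fitted clustering model."""

    def __init__(self, my_fitted_cluster_model):
        self.my_fitted_cluster_model = my_fitted_cluster_model
        # BERTopic reads labels_ after fit(), so they must line up with
        # the documents passed to the topic model.
        self.labels_ = self.my_fitted_cluster_model.labels_

    def fit(self, X, y=None):
        # Already fitted, so do nothing.
        return self

    def predict(self, X):
        return self.my_fitted_cluster_model.predict(X)


# my_fitted_cluster_model: your already fitted model (e.g. MiniBatchKMeans)
cluster_model = ClusterModel(my_fitted_cluster_model)
topic_model = BERTopic(hdbscan_model=cluster_model)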
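
To tie this back to the original question, here is a rough sketch of preparing several fitted clustering models side by side, with the reduced embeddings computed only once. It runs without BERTopic installed. ToyFittedClusterer is a hypothetical stand-in for fitted MiniBatchKMeans models, PassThroughReduction is an illustrative name for a reducer that returns its input unchanged, and the data is made up.

```python
class PassThroughReduction:
    """Returns its input unchanged; for embeddings that are already reduced."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


class ClusterModel:
    """Same wrapper as above: fit is a no-op, predict delegates."""

    def __init__(self, my_fitted_cluster_model):
        self.my_fitted_cluster_model = my_fitted_cluster_model
        self.labels_ = my_fitted_cluster_model.labels_

    def fit(self, X, y=None):
        return self

    def predict(self, X):
        return self.my_fitted_cluster_model.predict(X)


class ToyFittedClusterer:
    """Hypothetical stand-in for a fitted k-means model on 1-D points."""

    def __init__(self, centers, training_points):
        self.centers = centers
        self.labels_ = self.predict(training_points)

    def predict(self, X):
        # Assign each point to the index of its nearest center.
        return [
            min(range(len(self.centers)), key=lambda k: abs(row[0] - self.centers[k]))
            for row in X
        ]


# Made-up reduced embeddings, computed once and reused for every variant.
reduced_embeddings = [[0.0], [0.2], [5.0], [5.1], [9.9], [10.0]]

# Clustering models fitted ahead of time, keyed by number of clusters.
fitted_models = {
    2: ToyFittedClusterer([0.0, 8.0], reduced_embeddings),
    3: ToyFittedClusterer([0.0, 5.0, 10.0], reduced_embeddings),
}

# One (reduction, clustering) pair of sub-models per variant.
submodels = {
    n: (PassThroughReduction(), ClusterModel(model))
    for n, model in fitted_models.items()
}

for n, (dim_model, cluster_model) in submodels.items():
    cluster_model.fit(reduced_embeddings)  # no-op: labels come from the earlier fit
    print(n, cluster_model.labels_)
```

Each pair could then be passed to a new topic model as in the snippets above, i.e. BERTopic(umap_model=dim_model, hdbscan_model=cluster_model), followed by fit_transform(docs, embeddings=reduced_embeddings). The topic representation step still runs each time, and labels_ has to be in the same order as docs. Real scikit-learn models expose labels_ as a NumPy array rather than a list.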

daso94 commented 1 year ago

Thank you very much!