MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License

Problem with saving the model #1431

Open donottakemyusername opened 1 year ago

donottakemyusername commented 1 year ago

Hi, I am using the partial_fit function to perform incremental learning with BERTopic. When I tried to save the model with safetensors, I got the following error: KeyError: 'tokenizer'. The error is raised in bertopic/_save_utils.py when the function tries to recreate the CountVectorizer: it deletes parameters from cv_params that don't actually exist. I saved the model with:

model.save('some_directory', serialization="safetensors", save_ctfidf=True)

and here is the error I got:

/python3.9/site-packages/bertopic/_save_utils.py in save_ctfidf_config(model, path)
    293     # Recreate CountVectorizer
    294     cv_params = model.vectorizer_model.get_params()
--> 295     del cv_params["tokenizer"], cv_params["preprocessor"], cv_params["dtype"]
    296     if not isinstance(cv_params["analyzer"], str):
    297         del cv_params["analyzer"]

KeyError: 'tokenizer'

I ran model.vectorizer_model.get_params(), and it returns only 2 parameters: {'decay': 0.05, 'delete_min_df': None}. Is there anything I've done wrong? Thank you!
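A likely explanation for the two-key result: scikit-learn's get_params() builds its key list from the parameters named in the estimator class's own __init__ signature and skips **kwargs. The sketch below assumes (not verified against the installed BERTopic version) that OnlineCountVectorizer takes decay and delete_min_df explicitly and forwards every other CountVectorizer argument through **kwargs. KwargsVectorizer is a hypothetical stand-in, not BERTopic's class.

```python
from sklearn.feature_extraction.text import CountVectorizer


class KwargsVectorizer(CountVectorizer):
    """Hypothetical stand-in for OnlineCountVectorizer: two explicit
    parameters, everything else forwarded to CountVectorizer via **kwargs."""

    def __init__(self, decay=None, delete_min_df=None, **kwargs):
        self.decay = decay
        self.delete_min_df = delete_min_df
        super().__init__(**kwargs)


vec = KwargsVectorizer(decay=0.05, stop_words="english", ngram_range=(2, 2))

# get_params() only inspects KwargsVectorizer.__init__, so the forwarded
# CountVectorizer arguments are missing even though they are set on the object.
params = vec.get_params()
print(params)
print(vec.stop_words, vec.tokenizer)

# The same deletion that save_ctfidf_config performs then fails on the first key.
try:
    del params["tokenizer"], params["preprocessor"], params["dtype"]
except KeyError as err:
    print("KeyError:", err)
```

The attributes still exist on the instance (vec.tokenizer is None); they are just not reported by get_params(), which is what the save code relies on.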

MaartenGr commented 1 year ago

I am not sure whether you actually did something wrong here. Could you share your full code for training and saving the model? In the meantime, you could still use serialization="pickle", but that might not be what you are looking for.
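The pickle route is probably unaffected because pickle stores an object's attributes directly instead of rebuilding it from get_params(). A small check of that idea, using a hypothetical KwargsVectorizer that assumes OnlineCountVectorizer forwards CountVectorizer arguments through **kwargs (this is not BERTopic's own code or its save path):

```python
import pickle

from sklearn.feature_extraction.text import CountVectorizer


class KwargsVectorizer(CountVectorizer):
    """Hypothetical stand-in for OnlineCountVectorizer; assumes it forwards
    CountVectorizer arguments via **kwargs."""

    def __init__(self, decay=None, delete_min_df=None, **kwargs):
        self.decay = decay
        self.delete_min_df = delete_min_df
        super().__init__(**kwargs)


vec = KwargsVectorizer(decay=0.05, stop_words="english", ngram_range=(2, 2))
vec.fit(["the quick brown fox", "the lazy brown dog"])

# pickle copies the instance state, so forwarded arguments and the fitted
# vocabulary survive even though get_params() does not report them.
restored = pickle.loads(pickle.dumps(vec))
print(restored.decay, restored.stop_words, restored.ngram_range)
print(sorted(restored.vocabulary_))
```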

daviddexter commented 1 year ago

Hi @MaartenGr, I'm experiencing the same problem. Here is my code:


from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.vectorizers import ClassTfidfTransformer, OnlineCountVectorizer
from river import cluster, stream
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import IncrementalPCA


class WrappedRiverClusterAlgo:
    """Wraps a River clustering model so that BERTopic can train it on chunks
    of data, similar to online training.
    """
    def __init__(self, model):
        self.model = model

    def partial_fit(self, umap_embeddings):
        # Update the clustering model one embedding at a time. learn_one
        # updates the model in place (newer River releases return None),
        # so its return value is not reassigned to self.model.
        for umap_embedding, _ in stream.iter_array(umap_embeddings):
            self.model.learn_one(umap_embedding)

        # Assign a cluster label to every embedding in this chunk
        labels = []
        for umap_embedding, _ in stream.iter_array(umap_embeddings):
            label = self.model.predict_one(umap_embedding)
            labels.append(label)

        self.labels_ = labels
        return self

# Step 1 - Extract embeddings
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Step 2 - Reduce dimensionality
umap_model = IncrementalPCA(n_components=5)

# Step 3 - Cluster reduced embeddings
cluster_model = WrappedRiverClusterAlgo(cluster.CluStream())

# Step 4 - Tokenize topics
vectorizer_model = OnlineCountVectorizer(stop_words="english", decay=.01, delete_min_df=10.00,
                                         ngram_range=(2, 2))

# Step 5 - Create topic representation
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

representation_model = KeyBERTInspired(nr_repr_docs=15, random_state=100)

# All steps together
topic_model = BERTopic(
    embedding_model=embedding_model,
    umap_model=umap_model,
    hdbscan_model=cluster_model,
    vectorizer_model=vectorizer_model,
    ctfidf_model=ctfidf_model,
    calculate_probabilities=True,
    representation_model=representation_model,
    nr_topics="auto",
    verbose=True)

# `dataset` (an iterable of document batches) and `model_safatensors_path`
# are defined elsewhere in the original code.
topics = []
for data in dataset:
    topic_model.partial_fit(data)
    topics.extend(topic_model.topics_)

# Update model topics attribute
topic_model.topics_ = topics

# Save the model
topic_model.save(model_safatensors_path, serialization="safetensors", save_ctfidf=True,
                 save_embedding_model="sentence-transformers/all-MiniLM-L6-v2")

Additionally, here is the traceback:

Traceback (most recent call last):
  File "/Desktop/projects/app/runner_model.py", line 179, in <module>
    model()
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/Desktop/projects/app/runner_model.py", line 133, in model_with_bert_topic
    use_mmr=usemmr,use_keybert=usekeybert).model()
  File "/Desktop/projects/app/app/nlp_engine/use/__init__.py", line 142, in model
    self.online_training(WrappedRiverClusterAlgo(cluster.CluStream()))
  File "/Desktop/projects/app/app/nlp_engine/use/__init__.py", line 204, in online_training
    topic_model.save(model_safatensors_path,
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/bertopic/_bertopic.py", line 2963, in save
    save_utils.save_ctfidf_config(model=self, path=save_directory / 'ctfidf_config.json')
  File "/.cache/pypoetry/virtualenvs/DQ6JMim6-py3.10/lib/python3.10/site-packages/bertopic/_save_utils.py", line 350, in save_ctfidf_config
    del cv_params["tokenizer"], cv_params["preprocessor"], cv_params["dtype"]
KeyError: 'tokenizer'

donottakemyusername commented 1 year ago

Yeah, I think my code is similar. The problem is that our model's vectorizer has no parameters such as "tokenizer" or "preprocessor". When I call model.vectorizer_model.get_params(), it returns only 2 parameters: {'decay': 0.1, 'delete_min_df': None}. So when the save_ctfidf_config function calls del cv_params["tokenizer"], cv_params["preprocessor"], cv_params["dtype"] without checking, it raises a KeyError. I am not entirely sure why these parameters are deleted without checking first; I am going to try removing those lines and see whether things work properly. In the meantime, if you could let us know whether there is anything we can do, that would be really helpful. Thank you!
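On why those parameters are deleted at all: presumably (this is an inference from the file name ctfidf_config.json, not something confirmed from the BERTopic source) the remaining parameters are written to JSON, and some CountVectorizer parameters cannot be encoded. For a plain CountVectorizer the default dtype is a NumPy type object, and tokenizer, preprocessor, and analyzer may be arbitrary callables. A minimal check with scikit-learn and the standard json module:

```python
import json

from sklearn.feature_extraction.text import CountVectorizer

params = CountVectorizer(stop_words="english", ngram_range=(2, 2)).get_params()

# dtype defaults to numpy.int64 (a type object), which json cannot encode.
try:
    json.dumps(params)
except TypeError as err:
    print("TypeError:", err)

# Drop the same keys that save_ctfidf_config deletes.
for key in ("tokenizer", "preprocessor", "dtype"):
    params.pop(key)
if not isinstance(params["analyzer"], str):
    params.pop("analyzer")  # a callable analyzer cannot be stored in JSON either

# What is left is plain strings, numbers, booleans, None, and a tuple
# (ngram_range), which json writes as a list.
config = json.dumps(params)
print(config)
```

With a full CountVectorizer parameter set this works; with the two-key dictionary from the online vectorizer, the first deletion already fails.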

Just an update: I suspect this serialization method (safetensors) does not work for incremental learning setups that use OnlineCountVectorizer, and only works with the regular CountVectorizer. Please correct me if I am wrong.
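One possible direction, sketched under assumptions: if the online vectorizer also reported the parameters it inherits from CountVectorizer, get_params() would contain tokenizer, preprocessor, and dtype again. The sketch overrides scikit-learn's private _get_param_names classmethod, an internal hook that may change between versions. PatchedVectorizer is hypothetical and is not how BERTopic addresses this; it again assumes the **kwargs-forwarding constructor.

```python
from sklearn.base import clone
from sklearn.feature_extraction.text import CountVectorizer


class PatchedVectorizer(CountVectorizer):
    """Hypothetical **kwargs-forwarding vectorizer that also reports the
    parameters it inherits from CountVectorizer."""

    def __init__(self, decay=None, delete_min_df=None, **kwargs):
        self.decay = decay
        self.delete_min_df = delete_min_df
        super().__init__(**kwargs)

    @classmethod
    def _get_param_names(cls):
        # Private scikit-learn hook: merge this class's own __init__ names
        # with CountVectorizer's, so get_params() and clone() see both.
        own = super()._get_param_names()
        inherited = CountVectorizer._get_param_names()
        return sorted(set(own) | set(inherited))


vec = PatchedVectorizer(decay=0.01, stop_words="english", ngram_range=(2, 2))
params = vec.get_params()
print("tokenizer" in params, "dtype" in params)
print(params["decay"], params["stop_words"], params["ngram_range"])

# clone() rebuilds the estimator from get_params(), so this checks that the
# reported parameters are enough to re-create an equivalent vectorizer.
cloned = clone(vec)
print(cloned.decay, cloned.stop_words, cloned.ngram_range)
```

Whether BERTopic's loader would then recreate an OnlineCountVectorizer (rather than a plain CountVectorizer) from such a config is not something this sketch checks.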

MaartenGr commented 1 year ago

I think this is an issue with OnlineCountVectorizer not properly inheriting everything from its base class, CountVectorizer. Those lines should not be removed, since without them you would not be able to re-create the vectorizer. I believe two things need to change: