MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License

Array mismatch when trying to fit new data #2037

Open Kotik001 opened 5 months ago

Kotik001 commented 5 months ago

Hello! Thank you very much for this tool and library, it's awesome!

I should apologise in advance for my basic (possibly stupid) question, but I can't solve it on my own.

First, I fit the model on a pandas DataFrame with 20293 rows (texts):

from bertopic import BERTopic
from bertopic.representation import MaximalMarginalRelevance

representation_model = MaximalMarginalRelevance(diversity=0.4)

topic_model = BERTopic(
  embedding_model=sentence_model,
  umap_model=umap_model,
  hdbscan_model=hdbscan_model,
  vectorizer_model=vectorizer_model,
  ctfidf_model=ctfidf_model,
  representation_model=representation_model,
  top_n_words=6,
  verbose=True,
)

topics, probs = topic_model.fit_transform(data['filtered_words'], embeddings=embeddings)

Then I reduce the automatically found topics to 15 and reduce outliers with the c-TF-IDF strategy (roughly as sketched below).
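
For reference, a minimal sketch of those two steps, assuming the reduce_topics / reduce_outliers / update_topics methods:

docs = data['filtered_words']

# reduce the automatically found topics to 15
topic_model.reduce_topics(docs, nr_topics=15)

# reassign outlier documents (-1) using the c-TF-IDF strategy
reduced_topics = topic_model.reduce_outliers(docs, topic_model.topics_, strategy="c-tf-idf")
topic_model.update_topics(docs, topics=reduced_topics)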

I get a resulting DataFrame with every document assigned to a topic:

result = topic_model.get_document_info(data['filtered_words'])
result_df = data.join(result, how='inner')

Then I save the model to use it with new data:

topic_model.save("testmodel", serialization="pytorch", save_ctfidf=True, save_embedding_model=sentence_model)

Now I'm trying to get document info only for the new data:

sentence_model = SentenceTransformer("cointegrated/LaBSE-en-ru")
embeddings = sentence_model.encode(data['filtered_words'], show_progress_bar=True)
topic_model11 = BERTopic.load("testmodel")
new_topics, new_probs = topic_model11.transform(data['filtered_words'], embeddings=embeddings)  # here the new data is used
new_topics
document_info = topic_model11.get_document_info(combined_df['filtered_words'])  # here combined_df is just the new data, to illustrate the problem

And I get this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[188], line 1
----> 1 document_info = topic_model11.get_document_info(combined_df['filtered_words'])

File ~\anaconda3\envs\test_env3\lib\site-packages\bertopic\_bertopic.py:1633, in BERTopic.get_document_info(self, docs, df, metadata)
   1631     document_info["Topic"] = self.topics_
   1632 else:
-> 1633     document_info = pd.DataFrame({"Document": docs, "Topic": self.topics_})
   1635 # Add topic info through `.get_topic_info()`
   1636 topic_info = self.get_topic_info().drop("Count", axis=1)

File ~\anaconda3\envs\test_env3\lib\site-packages\pandas\core\frame.py:733, in DataFrame.__init__(self, data, index, columns, dtype, copy)
    727     mgr = self._init_mgr(
    728         data, axes={"index": index, "columns": columns}, dtype=dtype, copy=copy
    729     )
    731 elif isinstance(data, dict):
    732     # GH#38939 de facto copy defaults to False only in non-dict cases
--> 733     mgr = dict_to_mgr(data, index, columns, dtype=dtype, copy=copy, typ=manager)
    734 elif isinstance(data, ma.MaskedArray):
    735     from numpy.ma import mrecords

File ~\anaconda3\envs\test_env3\lib\site-packages\pandas\core\internals\construction.py:503, in dict_to_mgr(data, index, columns, dtype, typ, copy)
    499     else:
    500         # dtype check to exclude e.g. range objects, scalars
    501         arrays = [x.copy() if hasattr(x, "dtype") else x for x in arrays]
--> 503 return arrays_to_mgr(arrays, columns, index, dtype=dtype, typ=typ, consolidate=copy)

File ~\anaconda3\envs\test_env3\lib\site-packages\pandas\core\internals\construction.py:114, in arrays_to_mgr(arrays, columns, index, dtype, verify_integrity, typ, consolidate)
    111 if verify_integrity:
    112     # figure out the index, if necessary
    113     if index is None:
--> 114         index = _extract_index(arrays)
    115     else:
    116         index = ensure_index(index)

File ~\anaconda3\envs\test_env3\lib\site-packages\pandas\core\internals\construction.py:690, in _extract_index(data)
    685     if lengths[0] != len(index):
    686         msg = (
    687             f"array length {lengths[0]} does not match index "
    688             f"length {len(index)}"
    689         )
--> 690         raise ValueError(msg)
    691 else:
    692     index = default_index(lengths[0])

ValueError: array length 20293 does not match index length 111

I figured out that if I concatenate the new data with the old data and then call .get_document_info, it works, except that there is somehow one excess row, which I delete from the new data so it doesn't lead to the mismatch error. Let me know if you need more info or my whole data/code. Thanks!

piplist_toshare.txt

MaartenGr commented 5 months ago

The .get_document_info function only works for the documents on which you fitted the model and currently does not work with documents that were not part of the fitting process. You would have to manually create a similar dataframe yourself.
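
For example, a rough sketch of such a dataframe (not part of the library, just one way to assemble it; it assumes you kept the new_topics and new_probs returned by .transform on the new documents):

import pandas as pd

# hypothetical helper: build a document-info-like frame for unseen documents
def document_info_for_new_docs(topic_model, docs, topics, probs=None):
    df = pd.DataFrame({"Document": list(docs), "Topic": topics})
    if probs is not None:
        df["Probability"] = probs
    # attach topic names/representations from the fitted model
    topic_info = topic_model.get_topic_info().drop("Count", axis=1)
    return df.merge(topic_info, on="Topic", how="left")

document_info = document_info_for_new_docs(
    topic_model11, combined_df['filtered_words'], new_topics, new_probs
)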