MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License
6.16k stars 764 forks

bug with custom_labels in _hierarchy.py #1503

Open rcprati opened 1 year ago

rcprati commented 1 year ago

Hi,

it seems there is a bug in _hierarchy.py when using aspects as custom labels. In line 151, fig.layout[axis]["ticktext"] returns a list of strings, but topics are indexed by ints in topic_model.topic_aspects_. I fixed it locally by casting x to int in line 151:

    new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][int(x)] for x in fig.layout[axis]["ticktext"]]
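For anyone who wants to see the mismatch in isolation, here is a minimal sketch that needs no BERTopic install. The `aspect` dict and `tick_texts` list are hypothetical stand-ins for `topic_model.topic_aspects_[custom_labels]` and `fig.layout[axis]["ticktext"]`, and the words and scores are made up. The `lookup` helper is only for illustration and is not part of the library.

```python
# Hypothetical stand-in for topic_model.topic_aspects_["Aspect1"]:
# keys are int topic ids, values are lists of (word, score) tuples.
aspect = {98: [("space", 0.42), ("orbit", 0.31)]}

# Hypothetical stand-in for fig.layout[axis]["ticktext"]: tick labels are strings.
tick_texts = ["98"]


def lookup(aspect, tick, cast):
    """Look up a tick label in the aspect dict, optionally casting it to int."""
    key = int(tick) if cast else tick
    return aspect[key]


# Without the cast, the string "98" is not a key of the int-keyed dict.
try:
    lookup(aspect, tick_texts[0], cast=False)
    raised = False
except KeyError as err:
    raised = True
    message = str(err)

# With the cast, the lookup succeeds.
words = lookup(aspect, tick_texts[0], cast=True)
```

The string `"98"` and the int `98` hash to different keys, which is why the uncast lookup fails even though topic 98 exists.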

MaartenGr commented 1 year ago

Could you show a full example of the bug? It is difficult for me to reproduce or provide support without it.

rcprati commented 1 year ago

Sure, here it is:

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.representation import MaximalMarginalRelevance

from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))['data']

# Additional ways of representing a topic
aspect_model1 = MaximalMarginalRelevance()
aspect_model2 = KeyBERTInspired()

# Add all models together to be run in a single `fit`
representation_model = {
    "Aspect1": aspect_model1,
    "Aspect2": aspect_model2,
}
topic_model = BERTopic(representation_model=representation_model, verbose=True).fit(docs)

hierarchical_topics = topic_model.hierarchical_topics(docs)

topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, custom_labels="Aspect1")

and I get the error:

KeyError                                  Traceback (most recent call last)

<ipython-input-3-0e4010dde1ae> in <cell line: 23>()
     21 hierarchical_topics = topic_model.hierarchical_topics(docs)
     22 
---> 23 topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics,custom_labels="Aspect1")

2 frames

/usr/local/lib/python3.10/dist-packages/bertopic/plotting/_hierarchy.py in <listcomp>(.0)
    149     axis = "yaxis" if orientation == "left" else "xaxis"
    150     if isinstance(custom_labels, str):
--> 151         new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]]
    152         new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
    153         new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]

KeyError: '98'
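The three lines shown in the traceback (151 to 153) can be sketched as a standalone function with the int cast applied. This is only an illustration under stated assumptions, not the library's code. `build_aspect_labels` is a hypothetical name, and the sample aspect data is made up. The sketch assumes each aspect entry is a list of (word, score) tuples keyed by int topic id, as the traceback suggests.

```python
from typing import Dict, List, Tuple

# Assumed shape of one aspect: int topic id -> list of (word, score) tuples.
Aspect = Dict[int, List[Tuple[str, float]]]


def build_aspect_labels(tick_texts: List[str], aspect: Aspect, max_len: int = 30) -> List[str]:
    """Build 'id_word1_word2_word3' labels, truncating long ones with '...'."""
    labels = []
    for tick in tick_texts:
        topic_id = int(tick)  # tick labels are strings, aspect keys are ints
        # Topic id plus the first three words, mirroring labels[:4] in the traceback.
        parts = [tick] + [word for word, _ in aspect[topic_id][:3]]
        label = "_".join(parts)
        if len(label) >= max_len:
            label = label[: max_len - 3] + "..."
        labels.append(label)
    return labels


# Made-up sample data for illustration only.
sample_aspect = {
    98: [("space", 0.5), ("nasa", 0.4), ("orbit", 0.3), ("launch", 0.2)],
    3: [("encryption", 0.5), ("clipper", 0.4), ("government", 0.3)],
}
sample_labels = build_aspect_labels(["98", "3"], sample_aspect)
```

With the cast in place, string tick labels such as `"98"` resolve to the int-keyed entries instead of raising `KeyError`.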

MaartenGr commented 1 year ago

Yep, can definitely reproduce this issue, thanks! If you want, a PR would be much appreciated with #1504 if you have the time 😄 Otherwise, I have no problem creating a PR myself. Either way, thanks for sharing the issue and proposing the fix!