rcprati opened this issue 1 year ago.
Could you show a full example of the bug? It is difficult for me to reproduce or provide support without it.
Sure, here it is:
```python
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.representation import MaximalMarginalRelevance
from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))['data']

# Additional ways of representing a topic
aspect_model1 = MaximalMarginalRelevance()
aspect_model2 = KeyBERTInspired()

# Add all models together to be run in a single `fit`
representation_model = {
    "Aspect1": aspect_model1,
    "Aspect2": aspect_model2,
}
topic_model = BERTopic(representation_model=representation_model, verbose=True).fit(docs)

hierarchical_topics = topic_model.hierarchical_topics(docs)
topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, custom_labels="Aspect1")
```
And I get the error:
```
KeyError                                  Traceback (most recent call last)
<ipython-input-3-0e4010dde1ae> in <cell line: 23>()
     21 hierarchical_topics = topic_model.hierarchical_topics(docs)
     22 
---> 23 topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics,custom_labels="Aspect1")

2 frames

/usr/local/lib/python3.10/dist-packages/bertopic/plotting/_hierarchy.py in <listcomp>(.0)
    149 axis = "yaxis" if orientation == "left" else "xaxis"
    150 if isinstance(custom_labels, str):
--> 151 new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]]
    152 new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
    153 new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]

KeyError: '98'
```
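The `KeyError: '98'` fits a type mismatch: the tick labels read back from the Plotly figure are strings, while the aspect dictionary is keyed by integer topic ids. Below is a minimal, BERTopic-free sketch of the failing lookup and of an `int` cast on the tick labels. The dictionary is a hypothetical stand-in for `topic_model.topic_aspects_`, assumed to map each aspect name to `{topic_id: [(word, score), ...]}`; the words are toy placeholders, and `build_labels` is a hypothetical helper that mirrors lines 151-153 from the traceback, not the library's actual patch.

```python
# Hypothetical stand-in for topic_model.topic_aspects_:
# aspect name -> {topic id (int) -> [(word, score), ...]}; words are toy placeholders.
topic_aspects = {
    "Aspect1": {
        98: [("alpha", 0.9), ("beta", 0.8), ("gamma", 0.7), ("delta", 0.6)],
        3: [("one", 0.9), ("two", 0.8), ("three", 0.7), ("four", 0.6)],
    }
}

# Plotly tick labels come back as strings, like fig.layout[axis]["ticktext"].
ticktext = ["98", "3"]

# Unpatched lookup: a string key on an int-keyed dict fails as in the traceback.
try:
    topic_aspects["Aspect1"][ticktext[0]]
except KeyError as err:
    print("KeyError:", err)  # KeyError: '98'


def build_labels(ticktext, topic_aspects, custom_labels):
    """Hypothetical helper mirroring lines 151-153, with tick labels cast to int."""
    new_labels = [
        [[str(x), None]] + topic_aspects[custom_labels][int(x)]  # int(x) matches the dict keys
        for x in ticktext
    ]
    # Join the topic id and the first three aspect words with underscores.
    new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
    # Labels of 30+ characters are cut to 27 characters plus "...".
    new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
    return new_labels


print(build_labels(ticktext, topic_aspects, "Aspect1"))
# ['98_alpha_beta_gamma', '3_one_two_three']
```

Casting with `int(x)` should also cover the outlier topic, since `int("-1") == -1`.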
Yep, I can definitely reproduce this issue, thanks! If you have the time, a PR for #1504 would be much appreciated 😄 Otherwise, I have no problem creating one myself. Either way, thanks for sharing the issue and proposing the fix!
Hi,

It seems there is a bug in `_hierarchy.py` when using aspects as custom labels. In line 151, `fig.layout[axis]["ticktext"]` returns a list of strings, but topics in `topic_model.topic_aspects_` are indexed by ints. I did a local fix by introducing a type cast in line 151 to