MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License

visualize_hierarchy mix up of labels / hover over #923

Closed wrighematthey closed 1 year ago

wrighematthey commented 1 year ago

I've found that for some datasets I'm getting different hover over text from the labels when using the visualize_hierarchy function. Tracking up the tree, the hover over labels seem to be in the correct order but the axis labels don't match at the lowest level. I've tried both with and without updating the topic names and this does not appear to make a difference. I think that the two lists are somehow being ordered differently but I can't track down where this occurs.

MaartenGr commented 1 year ago

Thanks for sharing this. Can you create a reproducible example of this? Also, can you share your code? Perhaps there is something going on there.

wrighematthey commented 1 year ago

I've just tried running this with the newsgroup example from the docs but with a different model and can see the same effect. I used the 'all-MiniLM-L12-v2' model from sentence transformers. I don't seem to be able to upload the html version with the hover over tips but an example can be seen in the attached image.

wrighematthey commented 1 year ago

Example code that produces output with described error

from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer
from umap import UMAP

docs = fetch_20newsgroups(subset='all', remove=('headers','footers','quotes'))['data']

models = {'small':'all-MiniLM-L6-v2',
          'medium':'all-MiniLM-L12-v2',
          'large':'all-mpnet-base-v2'}

sentence_model = SentenceTransformer(models['medium'])
topic_model = BERTopic(embedding_model=sentence_model)
embeddings = sentence_model.encode(docs,show_progress_bar=True)

from sklearn.feature_extraction.text import CountVectorizer
topics,probs = topic_model.fit_transform(docs, embeddings=embeddings)
vectorizer_model = CountVectorizer(stop_words="english", ngram_range=(1, 5))
topic_model.update_topics(docs, vectorizer_model=vectorizer_model)

reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)

doc_plot = topic_model.visualize_documents(docs,reduced_embeddings=reduced_embeddings)
doc_plot.show()

hierarchical_topics = topic_model.hierarchical_topics(docs)
htm_plot = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics,orientation='bottom', custom_labels=False)
htm_plot.show()
htm_plot.write_html('news_example.html')

MaartenGr commented 1 year ago

Strange, I'll have to test this a bit further but it might be a result of one of the dependencies being updated. Could you do a pip freeze and share the packages that you currently have in your environment?

hbeelee commented 1 year ago

@MaartenGr I'm having exactly the same issue with .visualize_hierarchy and .get_topic_tree. The keywords of the topic on the graph (as a result of .visualize_hierarchy) and the keywords that appear when I hover over the labels do not match, as @wrighematthey explained above. In addition, the hierarchy shown by .visualize_hierarchy and the one from .get_topic_tree are also totally different.

FYI, here's my pip freeze:

absl-py==1.3.0
aeppl==0.0.33
aesara==2.7.9
aiohttp==3.8.3
aiosignal==1.3.1
alabaster==0.7.12
albumentations==1.2.1
altair==4.2.0
appdirs==1.4.4
arviz==0.12.1
astor==0.8.1
astropy==4.3.1
astunparse==1.6.3
async-timeout==4.0.2
atari-py==0.2.9
atomicwrites==1.4.1
attrs==22.2.0
audioread==3.0.0
autograd==1.5
Babel==2.11.0
backcall==0.2.0
beautifulsoup4==4.6.3
bertopic==0.13.0
bleach==5.0.1
blis==0.7.9
bokeh==2.3.3
branca==0.6.0
bs4==0.0.1
CacheControl==0.12.11
cachetools==5.2.1
catalogue==2.0.8
certifi==2022.12.7
cffi==1.15.1
cftime==1.6.2
chardet==4.0.0
charset-normalizer==2.1.1
click==7.1.2
clikit==0.6.2
cloudpickle==2.2.0
cmake==3.22.6
cmdstanpy==1.0.8
colorcet==3.0.1
colorlover==0.3.0
community==1.0.0b1
confection==0.0.3
cons==0.4.5
contextlib2==0.5.5
convertdate==2.4.0
crashtest==0.3.1
crcmod==1.7
cufflinks==0.17.3
cvxopt==1.3.0
cvxpy==1.2.3
cycler==0.11.0
cymem==2.0.7
Cython==0.29.33
daft==0.0.4
dask==2022.2.1
datascience==0.17.5
db-dtypes==1.0.5
dbus-python==1.2.16
debugpy==1.0.0
decorator==4.4.2
defusedxml==0.7.1
descartes==1.1.0
dill==0.3.6
distributed==2022.2.1
dlib==19.24.0
dm-tree==0.1.8
dnspython==2.2.1
docutils==0.16
dopamine-rl==1.0.5
earthengine-api==0.1.335
easydict==1.10
ecos==2.0.12
editdistance==0.5.3
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
entrypoints==0.4
ephem==4.1.4
et-xmlfile==1.1.0
etils==1.0.0
etuples==0.3.8
fa2==0.3.5
fastai==2.7.10
fastcore==1.5.27
fastdownload==0.0.7
fastdtw==0.3.4
fastjsonschema==2.16.2
fastprogress==1.0.3
fastrlock==0.8.1
feather-format==0.4.1
filelock==3.9.0
firebase-admin==5.3.0
fix-yahoo-finance==0.0.22
Flask==1.1.4
flatbuffers==1.12
folium==0.12.1.post1
frozenlist==1.3.3
fsspec==2022.11.0
future==0.16.0
gast==0.4.0
GDAL==3.0.4
gdown==4.4.0
gensim==3.6.0
geographiclib==1.52
geopy==1.17.0
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-api-core==2.11.0
google-api-python-client==2.70.0
google-auth==2.16.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-bigquery==3.4.1
google-cloud-bigquery-storage==2.17.0
google-cloud-core==2.3.2
google-cloud-datastore==2.11.1
google-cloud-firestore==2.7.3
google-cloud-language==2.6.1
google-cloud-storage==2.7.0
google-cloud-translate==3.8.4
google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.4.0
googleapis-common-protos==1.58.0
googledrivedownloader==0.4
graphviz==0.10.1
greenlet==2.0.1
grpcio==1.51.1
grpcio-status==1.48.2
gspread==3.4.2
gspread-dataframe==3.0.8
gym==0.25.2
gym-notices==0.0.8
h5py==3.1.0
hdbscan==0.8.29
HeapDict==1.0.1
hijri-converter==2.2.4
holidays==0.18
holoviews==1.14.9
html5lib==1.0.1
httpimport==0.5.18
httplib2==0.17.4
httpstan==4.6.1
huggingface-hub==0.11.1
humanize==0.5.1
hyperopt==0.1.2
idna==2.10
imageio==2.9.0
imagesize==1.4.1
imbalanced-learn==0.8.1
imblearn==0.0
imgaug==0.4.0
importlib-metadata==6.0.0
importlib-resources==5.10.2
imutils==0.5.4
inflect==2.1.0
intel-openmp==2023.0.0
intervaltree==2.1.0
ipykernel==5.3.4
ipython==7.9.0
ipython-genutils==0.2.0
ipython-sql==0.3.9
ipywidgets==7.7.1
itsdangerous==1.1.0
jax==0.3.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.25+cuda11.cudnn805-cp38-cp38-manylinux2014_x86_64.whl
jieba==0.42.1
Jinja2==2.11.3
joblib==1.2.0
jpeg4py==0.1.4
JPype1==1.4.1
jsonschema==4.3.3
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter_core==5.1.3
jupyterlab-widgets==3.0.5
kaggle==1.5.12
kapre==0.3.7
keras==2.9.0
Keras-Preprocessing==1.1.2
keras-vis==0.4.1
kiwisolver==1.4.4
konlpy==0.6.0
korean-lunar-calendar==0.3.1
langcodes==3.3.0
libclang==15.0.6.1
librosa==0.8.1
lightgbm==2.2.3
llvmlite==0.39.1
lmdb==0.99
locket==1.0.0
logical-unification==0.4.5
LunarCalendar==0.0.9
lxml==4.9.2
Markdown==3.4.1
MarkupSafe==2.0.1
marshmallow==3.19.0
matplotlib==3.2.2
matplotlib-venn==0.11.7
mecab-python===0.996-ko-0.9.2
miniKanren==1.0.3
missingno==0.5.1
mistune==0.8.4
mizani==0.7.3
mkl==2019.0
mlxtend==0.14.0
more-itertools==9.0.0
moviepy==0.2.3.5
mpmath==1.2.1
msgpack==1.0.4
multidict==6.0.4
multipledispatch==0.6.0
multitasking==0.0.11
murmurhash==1.0.9
music21==5.5.0
natsort==5.5.0
nbconvert==5.6.1
nbformat==5.7.1
netCDF4==1.6.2
networkx==3.0
nibabel==3.0.2
nltk==3.7
notebook==5.7.16
numba==0.56.4
numexpr==2.8.4
numpy==1.21.6
oauth2client==4.1.3
oauthlib==3.2.2
okgrade==0.4.3
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
opencv-python-headless==4.7.0.68
openpyxl==3.0.10
opt-einsum==3.3.0
osqp==0.6.2.post0
packaging==21.3
palettable==3.3.0
pandas==1.3.5
pandas-datareader==0.9.0
pandas-gbq==0.17.9
pandas-profiling==1.4.1
pandocfilters==1.5.0
panel==0.12.1
param==1.12.3
parso==0.8.3
partd==1.3.0
pastel==0.2.1
pathlib==1.0.1
pathy==0.10.1
patsy==0.5.3
pep517==0.13.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.1.2
pip-tools==6.6.2
platformdirs==2.6.2
plotly==5.5.0
plotnine==0.8.0
pluggy==0.7.1
pooch==1.6.0
portpicker==1.3.9
prefetch-generator==1.0.3
preshed==3.0.8
prettytable==3.6.0
progressbar2==3.38.0
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==2.0.10
prophet==1.1.1
proto-plus==1.22.2
protobuf==3.19.6
psutil==5.4.8
psycopg2==2.9.5
ptyprocess==0.7.0
py==1.11.0
pyarrow==9.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycocotools==2.0.6
pycparser==2.21
pyct==0.4.8
pydantic==1.10.4
pydata-google-auth==1.5.0
pydot==1.3.0
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
pyemd==0.5.1
pyerfa==2.0.0.1
Pygments==2.6.1
PyGObject==3.36.0
pylev==1.4.0
pymc==4.1.4
PyMeeus==0.5.12
pymongo==4.3.3
pymystem3==0.2.0
pynndescent==0.5.8
PyOpenGL==3.1.6
pyparsing==3.0.9
pyrsistent==0.19.3
pysimdjson==3.2.0
PySocks==1.7.1
pystan==3.3.0
pytest==3.6.4
python-apt==2.0.1
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==7.0.0
python-utils==3.4.5
pytz==2022.7
pyviz-comms==2.2.1
PyWavelets==1.4.1
PyYAML==5.4.1
pyzmq==23.2.1
qdldl==0.1.5.post2
qudida==0.0.4
regex==2022.6.2
requests==2.25.1
requests-oauthlib==1.3.1
requests-unixsocket==0.2.0
resampy==0.4.2
rpy2==3.5.5
rsa==4.9
scikit-image==0.18.3
scikit-learn==1.0.2
scipy==1.7.3
screen-resolution-extra==0.0.0
scs==3.2.2
seaborn==0.11.2
Send2Trash==1.8.0
sentence-transformers==2.2.2
sentencepiece==0.1.97
setuptools-git==1.2
shapely==2.0.0
six==1.15.0
sklearn-pandas==1.8.0
smart-open==6.3.0
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.11.0
soynlp==0.0.493
spacy==3.4.4
spacy-legacy==3.0.11
spacy-loggers==1.0.4
Sphinx==3.5.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib.applehelp==1.0.3
SQLAlchemy==1.4.46
sqlparse==0.4.3
srsly==2.4.5
statsmodels==0.12.2
sympy==1.7.1
tables==3.7.0
tabulate==0.8.10
tblib==1.7.0
tenacity==8.1.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.2
tensorflow-datasets==4.8.1
tensorflow-estimator==2.9.0
tensorflow-gcs-config==2.9.1
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.17.0
termcolor==2.2.0
terminado==0.13.3
testpath==0.6.0
text-unidecode==1.3
textblob==0.15.3
thinc==8.1.6
threadpoolctl==3.1.0
tifffile==2022.10.10
tokenizers==0.13.2
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchsummary==1.5.1
torchtext==0.14.1
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
tornado==6.0.4
tqdm==4.64.1
traitlets==5.7.1
transformers==4.25.1
tweepy==3.10.0
typeguard==2.7.1
typer==0.7.0
typing_extensions==4.4.0
tzlocal==1.5.1
umap-learn==0.5.3
uritemplate==4.1.1
urllib3==1.24.3
vega-datasets==0.9.0
wasabi==0.10.1
wcwidth==0.2.5
webargs==8.2.0
webencodings==0.5.1
Werkzeug==1.0.1
widgetsnbextension==3.6.1
wordcloud==1.8.2.2
wrapt==1.14.1
xarray==2022.12.0
xarray-einstats==0.4.0
xgboost==0.90
xkit==0.0.0
xlrd==1.2.0
xlwt==1.3.0
yarl==1.8.2
yellowbrick==1.5
zict==2.2.0
zipp==3.11.0

MaartenGr commented 1 year ago

It took a while to figure out and I should do a bit more testing but I believe the visualization should work if you remove the following lines:

https://github.com/MaartenGr/BERTopic/blob/06dcd47a019854185a39146ed693d16dc3a27651/bertopic/plotting/_hierarchy.py#L232-L234

Hopefully, this should serve as a quick fix for now. I might change/update the plotting backend at some point so this might just be a temporary fix but I am not sure yet if I want to switch over to something like Bokeh.

hbeelee commented 1 year ago

Thank you very much for your feedback @MaartenGr, but it seems the problem persists. Would you please look into it anytime soon? I don't mean to rush you at all, though.

MaartenGr commented 1 year ago

@hbeelee Just to be sure, did you manually remove the lines of code in _hierarchy.py and use the updated visualization function as a result? For me, this process seems to be working.

sharma-n commented 1 year ago

Hi! I had the same issue on my side as well, and going from 0.13.0 to 0.12.0 seems to fix it for me (even though manually making the change in _hierarchy.py doesn't solve the issue). In my case, it wasn't just the labels that were wrong: the hierarchy visualized did not match the one returned by topic_model.hierarchical_topics(). Hope this helps!

MaartenGr commented 1 year ago

@sharma-n Could you create a reproducible example? With manually doing the change I mentioned, I cannot seem to reproduce the issue which makes it difficult to fix.

sharma-n commented 1 year ago

Hey @MaartenGr, I was using a private dataset so couldn't add that. However, I found another public one where the issue can be reproduced (from Kaggle; I have attached the CSV file, Corona_NLP_test.csv).

I use the following code to process the data:

import pandas as pd
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
from bertopic import BERTopic

data = pd.read_csv('Corona_NLP_test.csv')

sentence_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
embeddings = sentence_model.encode(data.OriginalTweet, show_progress_bar=False)
embeddings = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', random_state=42).fit_transform(embeddings)
hdbscan_model = HDBSCAN(min_cluster_size=341, min_samples=3)

topic_model = BERTopic(hdbscan_model=hdbscan_model ,n_gram_range=(1,3), min_topic_size=341)
topics, probs = topic_model.fit_transform(data.OriginalTweet, embeddings)
hierarchical_topics = topic_model.hierarchical_topics(data.OriginalTweet)
topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, orientation='left', color_threshold=0)

Without any changes to version 0.13.0, I get the following output (I used custom labels, if that is of importance): [image: 13_nochange]

Based on the hover text, you can see it doesn't correspond to what it should be (it matches the top topic better than the current topic). If I make the changes you mentioned, I get different results (that are still not correct): [image: 13_afterchange]

Finally, with version 0.12.0 I do get the expected results: [image: 12]

MaartenGr commented 1 year ago

Thanks for sharing the example! I think there needs to be one additional change that might solve the issue, namely additionally removing the following lines:

https://github.com/MaartenGr/BERTopic/blob/d665d3f8d8c7c1736dc82b1df8839ced56a2adb6/bertopic/_bertopic.py#L869-L871

Based on what you mentioned, the main thing that changed between v0.12 and v0.13 was the introduction of a 1-D condensed distance matrix as that is what scipy expects as an input. However, and I have no clue why, this is what messes up the order in the visualization.

zilch42 commented 1 year ago

Hi Maarten, removing those two lines from both _bertopic.py and _hierarchy.py worked for me with regards to #1063. I have tested both visualize_hierarchy and visualize_hierarchical_documents and both are looking sensible.

elashrry commented 1 year ago

Hello @MaartenGr , I have the same issue. Removing the two lines you mentioned from both _bertopic.py and _hierarchy.py will solve the problem of matching labels and hover annotations. However, the clustering itself might not be correct. I think it lies in these two lines from Plotly's API.

d = distfun(X)
Z = linkagefun(d)

SciPy's documentation says the following about the first argument of linkage:

The input y may be either a 1-D condensed distance matrix or a 2-D array of observation vectors.

So, in your code, you assume the distance function returns a symmetric 2-D array of pairwise distances while Plotly assumes it returns a 1-D array.
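
To make the mismatch concrete, here is a minimal sketch (not BERTopic's or Plotly's actual code) that uses random data as a stand-in for the topic embeddings. It assumes only NumPy and SciPy; `points` and the other variable names are made up for illustration. When `linkage` receives a square 2-D distance matrix, it treats each row as an observation vector and computes Euclidean distances between the rows, so the resulting tree differs from the one built on the condensed 1-D distances:

```python
import warnings

import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform

# Hypothetical stand-in for topic embeddings: 6 "topics", 4 features each.
rng = np.random.default_rng(0)
points = rng.normal(size=(6, 4))

condensed = pdist(points, metric="cosine")  # 1-D, length 6 * 5 / 2 = 15
square = squareform(condensed)              # 2-D, shape (6, 6), zero diagonal

# Condensed input: the values are used directly as pairwise distances.
Z_condensed = linkage(condensed, method="ward")

# Square input: each row is treated as an observation vector. SciPy may warn
# that the input looks like an uncondensed distance matrix, so warnings are
# silenced here to keep the output clean.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    Z_square = linkage(square, method="ward")

# The square case is equivalent to clustering the rows of the distance matrix.
Z_rows = linkage(pdist(square, metric="euclidean"), method="ward")

print(np.allclose(Z_square, Z_rows))                   # same tree as row clustering
print(np.allclose(Z_condensed[:, 2], Z_square[:, 2]))  # merge heights vs. condensed input
```

The first check compares the square-input tree with an explicit clustering of the matrix rows; the second compares merge heights between the condensed and square inputs, which is where the two code paths diverge.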

So I think the solution could be either

  1. to make sure the distance_function returns a 1-D array before using Plotly's create_dendrogram function in the plotting module. https://github.com/MaartenGr/BERTopic/blob/d665d3f8d8c7c1736dc82b1df8839ced56a2adb6/bertopic/plotting/_hierarchy.py#L126

  2. Or, add in the docstring that distance_function should return a 1-D array, and adapt the code for that, i.e. remove the lines you mentioned in hierarchical_topics _bertopic.py and also in _get_annotations in _hierarchy.py.

I would be happy to help, just let me know :)

MaartenGr commented 1 year ago

@elashrry Thanks for digging into this. I think another option that might fix this issue would be to check for 2D distances and convert them, like I did here, using scipy.spatial.distance.squareform. However, that would require running a test of the distance function to see whether it returns 2D or condensed 1D shapes.

The reason for suggesting the above is that the "hard" work is being done in BERTopic and that the user does not have to figure out what to pass as both will then be accepted.

elashrry commented 1 year ago

I am not sure this solves the real issue here. I think you are already doing that wherever you can, but the issue is that plotly doesn't do the same thing, so we end up with two different clusterings: one from hierarchical_topics in _bertopic.py (which is the same one as in _get_annotations in _hierarchy.py) and another one from plotly's create_dendrogram.

If you want minimal change and it won't affect your API, you can add these lines before calling plotly's create_dendrogram. I just tested it and it works.


    # convert distance_function so that it returns a 1-D array
    def distfun(X):
        X = distance_function(X)
        np.fill_diagonal(X, 0)
        return squareform(X)

and to be precise, I am assuming the correct clustering is the one we get from the following code:

import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity

embeddings = topic_model.c_tf_idf_  # no filtering, just for simplicity here

distance_function = lambda X: 1 - cosine_similarity(X)
linkage_function = lambda x: hierarchy.linkage(
    x, 'ward', optimal_ordering=True)

fig, ax = plt.subplots(figsize=(10, 15))
# compute distances
X = distance_function(embeddings)
# Make sure it is the 1-D condensed distance matrix with zeros on the diagonal
np.fill_diagonal(X, 0)
X = squareform(X)
# linkage
Z = linkage_function(X)
dn = hierarchy.dendrogram(Z, orientation='right', ax=ax)

MaartenGr commented 1 year ago

> I am not sure this solves the real issue here. I think you are already doing that wherever you can, but the issue is that plotly doesn't do the same thing, so we end up with two different clustering.

I was actually suggesting that the distance function that plotly gets is adapted based on whether its output is 1D or 2D. In other words, it would be very similar to what you suggest here:

# Convert distance_function so that it returns a 1-D array
def distfun(X):
    X = distance_function(X)
    np.fill_diagonal(X, 0)
    return squareform(X)

but only if the original distance function outputs a 2D matrix. Otherwise, there might be unintended consequences if somebody already passes a 1D condensed distance matrix, as is typically the output of scipy.spatial.distance.pdist. This would then also need to be implemented for .hierarchical_topics and ._get_annotations.
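
A minimal sketch of that check, assuming only NumPy and SciPy. The helper name `ensure_condensed` and the two toy distance functions are hypothetical and not part of BERTopic; the wrapper leaves 1-D output untouched and converts 2-D output with `squareform`:

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def ensure_condensed(distance_function):
    """Wrap a distance function so that it always returns a 1-D condensed array.

    Hypothetical helper for illustration; not part of BERTopic's API.
    """
    def wrapped(X):
        # np.array copies, so the caller's data is never modified in place.
        D = np.array(distance_function(X), dtype=float)
        if D.ndim == 2:
            # Square matrix: zero the diagonal (it can carry floating-point
            # noise) and keep only the upper triangle.
            np.fill_diagonal(D, 0)
            # checks=False skips SciPy's exact symmetry check, which tiny
            # rounding differences could otherwise trip.
            D = squareform(D, checks=False)
        return D
    return wrapped


def cosine_distance_2d(X):
    """Toy 2-D distance function: 1 - cosine similarity, shape (n, n)."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    return 1 - Xn @ Xn.T


def cosine_distance_1d(X):
    """Toy 1-D distance function: already condensed, as returned by pdist."""
    return pdist(X, metric="cosine")


rng = np.random.default_rng(42)
X = rng.normal(size=(5, 3))

from_2d = ensure_condensed(cosine_distance_2d)(X)
from_1d = ensure_condensed(cosine_distance_1d)(X)

print(from_2d.shape, from_1d.shape)
print(np.allclose(from_2d, from_1d))
```

With a wrapper along these lines, the user can pass either kind of distance function, and every code path that calls it (the plotting call as well as .hierarchical_topics and ._get_annotations) would see the same condensed input.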

elashrry commented 1 year ago

Yes, that would be more robust. Let me know if I can help

MaartenGr commented 1 year ago

If you have the time, then a PR would be appreciated! Otherwise, I might find some time in the upcoming weeks.