Closed: wrighematthey closed this issue 1 year ago
Thanks for sharing this. Can you create a reproducible example? Also, can you share your code? Perhaps something is going on there.
I've just tried running this with the newsgroup example from the docs but with a different model, and I can see the same effect. I used the 'all-MiniLM-L12-v2' model from sentence-transformers. I don't seem to be able to upload the HTML version with the hover-over tooltips, but an example can be seen in the attached image.
Example code that produces the output with the described error:
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer
from umap import UMAP

docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']

models = {'small': 'all-MiniLM-L6-v2',
          'medium': 'all-MiniLM-L12-v2',
          'large': 'all-mpnet-base-v2'}

sentence_model = SentenceTransformer(models['medium'])
topic_model = BERTopic(embedding_model=sentence_model)
embeddings = sentence_model.encode(docs, show_progress_bar=True)
topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)

vectorizer_model = CountVectorizer(stop_words="english", ngram_range=(1, 5))
topic_model.update_topics(docs, vectorizer_model=vectorizer_model)

reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
doc_plot = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
doc_plot.show()

hierarchical_topics = topic_model.hierarchical_topics(docs)
htm_plot = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, orientation='bottom', custom_labels=False)
htm_plot.show()
htm_plot.write_html('news_example.html')
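As an additional manual cross-check (not part of the original snippet), the model's own topic representations can be printed and compared by eye against the dendrogram's hover text; this only relies on BERTopic's public get_topic_info and get_topic methods:

# Print the top words per topic so they can be compared against the hover text
# shown by visualize_hierarchy (purely a manual cross-check).
for topic_id in topic_model.get_topic_info().Topic.head(10):
    if topic_id != -1:  # skip the outlier topic
        words = [word for word, _ in topic_model.get_topic(topic_id)]
        print(topic_id, ", ".join(words[:5]))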
Strange, I'll have to test this a bit further but it might be a result of one of the dependencies being updated. Could you do a pip freeze and share the packages that you currently have in your environment?
@MaartenGr I'm having exactly the same issue with .visualize_hierarchy and .get_topic_tree. The keywords of the topics on the graph (produced by .visualize_hierarchy) and the keywords that appear when I hover over the labels do not match, as @wrighematthey explained above. In addition, the hierarchy shown by .visualize_hierarchy and the one returned by .get_topic_tree are completely different.
FYI, here's my pip freeze:
absl-py==1.3.0
aeppl==0.0.33
aesara==2.7.9
aiohttp==3.8.3
aiosignal==1.3.1
alabaster==0.7.12
albumentations==1.2.1
altair==4.2.0
appdirs==1.4.4
arviz==0.12.1
astor==0.8.1
astropy==4.3.1
astunparse==1.6.3
async-timeout==4.0.2
atari-py==0.2.9
atomicwrites==1.4.1
attrs==22.2.0
audioread==3.0.0
autograd==1.5
Babel==2.11.0
backcall==0.2.0
beautifulsoup4==4.6.3
bertopic==0.13.0
bleach==5.0.1
blis==0.7.9
bokeh==2.3.3
branca==0.6.0
bs4==0.0.1
CacheControl==0.12.11
cachetools==5.2.1
catalogue==2.0.8
certifi==2022.12.7
cffi==1.15.1
cftime==1.6.2
chardet==4.0.0
charset-normalizer==2.1.1
click==7.1.2
clikit==0.6.2
cloudpickle==2.2.0
cmake==3.22.6
cmdstanpy==1.0.8
colorcet==3.0.1
colorlover==0.3.0
community==1.0.0b1
confection==0.0.3
cons==0.4.5
contextlib2==0.5.5
convertdate==2.4.0
crashtest==0.3.1
crcmod==1.7
cufflinks==0.17.3
cvxopt==1.3.0
cvxpy==1.2.3
cycler==0.11.0
cymem==2.0.7
Cython==0.29.33
daft==0.0.4
dask==2022.2.1
datascience==0.17.5
db-dtypes==1.0.5
dbus-python==1.2.16
debugpy==1.0.0
decorator==4.4.2
defusedxml==0.7.1
descartes==1.1.0
dill==0.3.6
distributed==2022.2.1
dlib==19.24.0
dm-tree==0.1.8
dnspython==2.2.1
docutils==0.16
dopamine-rl==1.0.5
earthengine-api==0.1.335
easydict==1.10
ecos==2.0.12
editdistance==0.5.3
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
entrypoints==0.4
ephem==4.1.4
et-xmlfile==1.1.0
etils==1.0.0
etuples==0.3.8
fa2==0.3.5
fastai==2.7.10
fastcore==1.5.27
fastdownload==0.0.7
fastdtw==0.3.4
fastjsonschema==2.16.2
fastprogress==1.0.3
fastrlock==0.8.1
feather-format==0.4.1
filelock==3.9.0
firebase-admin==5.3.0
fix-yahoo-finance==0.0.22
Flask==1.1.4
flatbuffers==1.12
folium==0.12.1.post1
frozenlist==1.3.3
fsspec==2022.11.0
future==0.16.0
gast==0.4.0
GDAL==3.0.4
gdown==4.4.0
gensim==3.6.0
geographiclib==1.52
geopy==1.17.0
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-api-core==2.11.0
google-api-python-client==2.70.0
google-auth==2.16.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-bigquery==3.4.1
google-cloud-bigquery-storage==2.17.0
google-cloud-core==2.3.2
google-cloud-datastore==2.11.1
google-cloud-firestore==2.7.3
google-cloud-language==2.6.1
google-cloud-storage==2.7.0
google-cloud-translate==3.8.4
google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.4.0
googleapis-common-protos==1.58.0
googledrivedownloader==0.4
graphviz==0.10.1
greenlet==2.0.1
grpcio==1.51.1
grpcio-status==1.48.2
gspread==3.4.2
gspread-dataframe==3.0.8
gym==0.25.2
gym-notices==0.0.8
h5py==3.1.0
hdbscan==0.8.29
HeapDict==1.0.1
hijri-converter==2.2.4
holidays==0.18
holoviews==1.14.9
html5lib==1.0.1
httpimport==0.5.18
httplib2==0.17.4
httpstan==4.6.1
huggingface-hub==0.11.1
humanize==0.5.1
hyperopt==0.1.2
idna==2.10
imageio==2.9.0
imagesize==1.4.1
imbalanced-learn==0.8.1
imblearn==0.0
imgaug==0.4.0
importlib-metadata==6.0.0
importlib-resources==5.10.2
imutils==0.5.4
inflect==2.1.0
intel-openmp==2023.0.0
intervaltree==2.1.0
ipykernel==5.3.4
ipython==7.9.0
ipython-genutils==0.2.0
ipython-sql==0.3.9
ipywidgets==7.7.1
itsdangerous==1.1.0
jax==0.3.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.25+cuda11.cudnn805-cp38-cp38-manylinux2014_x86_64.whl
jieba==0.42.1
Jinja2==2.11.3
joblib==1.2.0
jpeg4py==0.1.4
JPype1==1.4.1
jsonschema==4.3.3
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter_core==5.1.3
jupyterlab-widgets==3.0.5
kaggle==1.5.12
kapre==0.3.7
keras==2.9.0
Keras-Preprocessing==1.1.2
keras-vis==0.4.1
kiwisolver==1.4.4
konlpy==0.6.0
korean-lunar-calendar==0.3.1
langcodes==3.3.0
libclang==15.0.6.1
librosa==0.8.1
lightgbm==2.2.3
llvmlite==0.39.1
lmdb==0.99
locket==1.0.0
logical-unification==0.4.5
LunarCalendar==0.0.9
lxml==4.9.2
Markdown==3.4.1
MarkupSafe==2.0.1
marshmallow==3.19.0
matplotlib==3.2.2
matplotlib-venn==0.11.7
mecab-python===0.996-ko-0.9.2
miniKanren==1.0.3
missingno==0.5.1
mistune==0.8.4
mizani==0.7.3
mkl==2019.0
mlxtend==0.14.0
more-itertools==9.0.0
moviepy==0.2.3.5
mpmath==1.2.1
msgpack==1.0.4
multidict==6.0.4
multipledispatch==0.6.0
multitasking==0.0.11
murmurhash==1.0.9
music21==5.5.0
natsort==5.5.0
nbconvert==5.6.1
nbformat==5.7.1
netCDF4==1.6.2
networkx==3.0
nibabel==3.0.2
nltk==3.7
notebook==5.7.16
numba==0.56.4
numexpr==2.8.4
numpy==1.21.6
oauth2client==4.1.3
oauthlib==3.2.2
okgrade==0.4.3
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
opencv-python-headless==4.7.0.68
openpyxl==3.0.10
opt-einsum==3.3.0
osqp==0.6.2.post0
packaging==21.3
palettable==3.3.0
pandas==1.3.5
pandas-datareader==0.9.0
pandas-gbq==0.17.9
pandas-profiling==1.4.1
pandocfilters==1.5.0
panel==0.12.1
param==1.12.3
parso==0.8.3
partd==1.3.0
pastel==0.2.1
pathlib==1.0.1
pathy==0.10.1
patsy==0.5.3
pep517==0.13.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.1.2
pip-tools==6.6.2
platformdirs==2.6.2
plotly==5.5.0
plotnine==0.8.0
pluggy==0.7.1
pooch==1.6.0
portpicker==1.3.9
prefetch-generator==1.0.3
preshed==3.0.8
prettytable==3.6.0
progressbar2==3.38.0
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==2.0.10
prophet==1.1.1
proto-plus==1.22.2
protobuf==3.19.6
psutil==5.4.8
psycopg2==2.9.5
ptyprocess==0.7.0
py==1.11.0
pyarrow==9.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycocotools==2.0.6
pycparser==2.21
pyct==0.4.8
pydantic==1.10.4
pydata-google-auth==1.5.0
pydot==1.3.0
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
pyemd==0.5.1
pyerfa==2.0.0.1
Pygments==2.6.1
PyGObject==3.36.0
pylev==1.4.0
pymc==4.1.4
PyMeeus==0.5.12
pymongo==4.3.3
pymystem3==0.2.0
pynndescent==0.5.8
PyOpenGL==3.1.6
pyparsing==3.0.9
pyrsistent==0.19.3
pysimdjson==3.2.0
PySocks==1.7.1
pystan==3.3.0
pytest==3.6.4
python-apt==2.0.1
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==7.0.0
python-utils==3.4.5
pytz==2022.7
pyviz-comms==2.2.1
PyWavelets==1.4.1
PyYAML==5.4.1
pyzmq==23.2.1
qdldl==0.1.5.post2
qudida==0.0.4
regex==2022.6.2
requests==2.25.1
requests-oauthlib==1.3.1
requests-unixsocket==0.2.0
resampy==0.4.2
rpy2==3.5.5
rsa==4.9
scikit-image==0.18.3
scikit-learn==1.0.2
scipy==1.7.3
screen-resolution-extra==0.0.0
scs==3.2.2
seaborn==0.11.2
Send2Trash==1.8.0
sentence-transformers==2.2.2
sentencepiece==0.1.97
setuptools-git==1.2
shapely==2.0.0
six==1.15.0
sklearn-pandas==1.8.0
smart-open==6.3.0
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.11.0
soynlp==0.0.493
spacy==3.4.4
spacy-legacy==3.0.11
spacy-loggers==1.0.4
Sphinx==3.5.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib.applehelp==1.0.3
SQLAlchemy==1.4.46
sqlparse==0.4.3
srsly==2.4.5
statsmodels==0.12.2
sympy==1.7.1
tables==3.7.0
tabulate==0.8.10
tblib==1.7.0
tenacity==8.1.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.2
tensorflow-datasets==4.8.1
tensorflow-estimator==2.9.0
tensorflow-gcs-config==2.9.1
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.17.0
termcolor==2.2.0
terminado==0.13.3
testpath==0.6.0
text-unidecode==1.3
textblob==0.15.3
thinc==8.1.6
threadpoolctl==3.1.0
tifffile==2022.10.10
tokenizers==0.13.2
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchsummary==1.5.1
torchtext==0.14.1
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
tornado==6.0.4
tqdm==4.64.1
traitlets==5.7.1
transformers==4.25.1
tweepy==3.10.0
typeguard==2.7.1
typer==0.7.0
typing_extensions==4.4.0
tzlocal==1.5.1
umap-learn==0.5.3
uritemplate==4.1.1
urllib3==1.24.3
vega-datasets==0.9.0
wasabi==0.10.1
wcwidth==0.2.5
webargs==8.2.0
webencodings==0.5.1
Werkzeug==1.0.1
widgetsnbextension==3.6.1
wordcloud==1.8.2.2
wrapt==1.14.1
xarray==2022.12.0
xarray-einstats==0.4.0
xgboost==0.90
xkit==0.0.0
xlrd==1.2.0
xlwt==1.3.0
yarl==1.8.2
yellowbrick==1.5
zict==2.2.0
zipp==3.11.0
It took a while to figure out, and I should do a bit more testing, but I believe the visualization should work if you remove the following lines:
Hopefully, this serves as a quick fix for now. I might change/update the plotting backend at some point, so this might just be a temporary fix, but I am not sure yet whether I want to switch over to something like Bokeh.
Thank you very much for your feedback, @MaartenGr, but it seems the problem persists. Would you please look into it when you get a chance? I don't mean to rush you at all.
@hbeelee Just to be sure, did you manually remove those lines of code in _hierarchy.py and then use the updated visualization function? For me, this process seems to work.
Hi!
I had the same issue on my side as well, and going from 0.13.0 back to 0.12.0 seems to fix it for me (even though manually making the change in _hierarchy.py doesn't solve the issue).
In my case, it wasn't just the labels that were wrong: the hierarchy visualized did not match the one returned by topic_model.hierarchical_topics().
Hope this helps!
@sharma-n Could you create a reproducible example? Even after manually making the change I mentioned, I cannot seem to reproduce the issue, which makes it difficult to fix.
Hey @MaartenGr, I was using a private dataset, so I couldn't share that. However, I found a public one where the issue can be reproduced (from Kaggle); I have attached the CSV file Corona_NLP_test.csv.
I use the following code to process the data:
import pandas as pd
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
from bertopic import BERTopic

data = pd.read_csv('Corona_NLP_test.csv')
sentence_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
embeddings = sentence_model.encode(data.OriginalTweet, show_progress_bar=False)
embeddings = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', random_state=42).fit_transform(embeddings)

hdbscan_model = HDBSCAN(min_cluster_size=341, min_samples=3)
topic_model = BERTopic(hdbscan_model=hdbscan_model, n_gram_range=(1, 3), min_topic_size=341)
topics, probs = topic_model.fit_transform(data.OriginalTweet, embeddings)

hierarchical_topics = topic_model.hierarchical_topics(data.OriginalTweet)
topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, orientation='left', color_threshold=0)
Without any changes to version 0.13.0, I get the following output (I used custom labels, if that is of importance):
Based on the hover text, you can see it doesn't correspond to what it should be (it matches the top topic better than the current topic).
If I make the changes you mentioned, I get different results (which are still not correct):
Finally, with version 0.12.0, I do get the expected results:
Thanks for sharing the example! I think there needs to be one additional change that might solve the issue, namely also removing the following lines:
Based on what you mentioned, the main thing that changed between v0.12 and v0.13 was the introduction of a 1-D condensed distance matrix, as that is what scipy expects as input. However, and I have no clue why, this is what messes up the order in the visualization.
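To make the distinction concrete, here is a small standalone sketch with toy data (nothing BERTopic-specific): scipy.cluster.hierarchy.linkage treats a 2-D input as observation vectors rather than as pairwise distances, so the square matrix and its condensed form can produce different trees.

import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(42)
vectors = rng.random((6, 4))                     # toy "topic" vectors
condensed = pdist(vectors, metric='cosine')      # 1-D condensed distances
square = squareform(condensed)                   # the same distances as a 2-D square matrix

Z_condensed = hierarchy.linkage(condensed, 'ward')  # interpreted as pairwise distances
Z_square = hierarchy.linkage(square, 'ward')        # interpreted as observation vectors(!)
print(np.allclose(Z_condensed, Z_square))            # generally False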
Hi Maarten, removing those two lines from both _bertopic.py and _hierarchy.py worked for me with regards to #1063. I have tested both visualize_hierarchy and visualize_hierarchical_documents and both are looking sensible.
Hello @MaartenGr,
I have the same issue. Removing the two lines you mentioned from both _bertopic.py and _hierarchy.py solves the problem of the labels and hover annotations not matching. However, the clustering itself might not be correct. I think the cause lies in these two lines from Plotly's API:
d = distfun(X)
Z = linkagefun(d)
SciPy's documentation says about the first argument:
The input y may be either a 1-D condensed distance matrix or a 2-D array of observation vectors.
So, in your code, you assume the distance function returns a symmetric 2-D array of pairwise distances, while Plotly assumes it returns a 1-D array.
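To spell out the consequence, here is a rough sketch of what those two lines end up doing when the distance function returns a square matrix (variable names and sizes are only illustrative):

import numpy as np
from scipy.cluster import hierarchy
from sklearn.metrics.pairwise import cosine_similarity

distfun = lambda X: 1 - cosine_similarity(X)          # returns a 2-D square matrix
linkagefun = lambda d: hierarchy.linkage(d, 'ward', optimal_ordering=True)

X = np.random.rand(5, 10)         # stand-in for the topics' c-TF-IDF rows
d = distfun(X)                    # shape (5, 5) instead of the expected (10,) condensed form
Z = linkagefun(d)                 # silently clusters the rows of d as observation vectors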
So I think the solution could be either to make sure the distance_function returns a 1-D array before using Plotly's create_dendrogram function in the plotting module (https://github.com/MaartenGr/BERTopic/blob/d665d3f8d8c7c1736dc82b1df8839ced56a2adb6/bertopic/plotting/_hierarchy.py#L126), or to state in the docstring that distance_function should return a 1-D array and adapt the code accordingly, i.e. remove the lines you mentioned in hierarchical_topics in _bertopic.py and also in _get_annotations in _hierarchy.py.
I would be happy to help, just let me know :)
@elashrry Thanks for digging into this. I think another option would be to check for 2-D distances and convert them, like I did here using scipy.spatial.distance.squareform, which might also fix this issue. However, that would require testing the distance function to see whether it returns a 2-D matrix or a condensed 1-D array.
The reason for suggesting the above is that the "hard" work is then done inside BERTopic, and the user does not have to figure out what to pass, as both forms would be accepted.
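Something along these lines is what I have in mind; a rough, untested sketch (the helper name is illustrative, not an actual BERTopic internal) that accepts either form and always passes on a condensed 1-D array:

import numpy as np
from scipy.spatial.distance import squareform

def _ensure_condensed(distance_function):
    """Wrap a user-supplied distance function so that it always returns a
    1-D condensed distance matrix, whether it produces 1-D or 2-D output."""
    def wrapped(X):
        distances = np.asarray(distance_function(X))
        if distances.ndim == 2:               # square pairwise matrix -> condense it
            np.fill_diagonal(distances, 0)
            distances = squareform(distances)
        return distances                      # already condensed: pass through untouched
    return wrapped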
I am not sure this solves the real issue here. I think you are already doing that wherever you can, but the issue is that Plotly doesn't do the same thing, so we end up with two different clusterings: one from hierarchical_topics in _bertopic.py (and it is the same one in _get_annotations in _hierarchy.py) and another one from Plotly's create_dendrogram.
If you want a minimal change that won't affect your API, you can add these lines before calling Plotly's create_dendrogram. I just tested it and it works.
import numpy as np
from scipy.spatial.distance import squareform

# Convert distance_function so that it returns a 1-D condensed array
def distfun(X):
    X = distance_function(X)
    np.fill_diagonal(X, 0)
    return squareform(X)
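For reference, Plotly's create_dendrogram exposes distfun and linkagefun keyword arguments, so the wrapped function can be passed straight in; the call below is only a sketch, with embeddings standing in for whatever _hierarchy.py feeds to create_dendrogram:

import plotly.figure_factory as ff
from scipy.cluster import hierarchy

# Sketch of how the wrapper above slots into Plotly's API.
fig = ff.create_dendrogram(
    embeddings,
    distfun=distfun,
    linkagefun=lambda d: hierarchy.linkage(d, 'ward', optimal_ordering=True),
)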
And to be precise, I am assuming the correct clustering is the one we get from the following code:
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity

embeddings = topic_model.c_tf_idf_  # no filtering, just for simplicity here
distance_function = lambda X: 1 - cosine_similarity(X)
linkage_function = lambda x: hierarchy.linkage(x, 'ward', optimal_ordering=True)

fig, ax = plt.subplots(figsize=(10, 15))

# Compute distances
X = distance_function(embeddings)

# Make sure it is a 1-D condensed distance matrix with zeros on the diagonal
np.fill_diagonal(X, 0)
X = squareform(X)

# Linkage and dendrogram
Z = linkage_function(X)
dn = hierarchy.dendrogram(Z, orientation='right', ax=ax)
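To compare this reference against what Plotly draws, one quick check (continuing from the snippet above, and only illustrative) is the left-to-right leaf order:

# Left-to-right leaf order of the reference dendrogram; comparing this against the
# axis order produced by visualize_hierarchy is a quick way to spot the mismatch.
print(dn['leaves'])
print(list(hierarchy.leaves_list(Z)))  # should match the line above for this reference plot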
I am not sure this solves the real issue here. I think you are already doing that wherever you can, but the issue is that Plotly doesn't do the same thing, so we end up with two different clusterings.
I was actually suggesting that the distance function that Plotly receives is adapted based on whether its output is 1-D or 2-D. In other words, it would be very similar to what you suggest here:
# Convert distance_function so that it returns a 1-D array
def distfun(X):
X = distance_function(X)
np.fill_diagonal(X, 0)
return squareform(X)
but only if the original distance function outputs a 2-D matrix. Otherwise, there might be unintended consequences if somebody already passes a distance function that returns a 1-D condensed distance matrix, as is typically the output of scipy.spatial.distance.pdist. This would then also need to be implemented for .hierarchical_topics and ._get_annotations.
Yes, that would be more robust. Let me know if I can help.
If you have the time, then a PR would be appreciated! Otherwise, I might find some time in the upcoming weeks.
I've found that for some datasets I'm getting hover-over text that differs from the labels when using the visualize_hierarchy function. Tracking up the tree, the hover-over labels seem to be in the correct order, but the axis labels don't match at the lowest level. I've tried both with and without updating the topic names, and this does not appear to make a difference. I think the two lists are somehow being ordered differently, but I can't track down where this occurs.