MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/

Incorrect result from topic_model.get_topic_info() when zeroshot_topic_list is set #2102

Closed PhPv closed 1 month ago

PhPv commented 1 month ago


Describe the bug

We have a DataFrame with 5 clusters (df.csv) and a zeroshot_topic_list with 3 topics (zeroshot_topic_list.csv).

There are 150 holiday, 100 work, and 50 code messages in the dataset. In the output of topic_model.get_topic_info(), Name does not match Representation: Name follows the order in which the topics were given in zeroshot_topic_list (1 code, 2 work, 3 holidays), while Representation is sorted by cluster size in descending order (3 holidays, 2 work, 1 code).


Reproduction

import pandas as pd
import os
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
from umap import UMAP
from bertopic.representation import KeyBERTInspired

root_dir = os.path.abspath(os.path.dirname(__file__))

representation_model = KeyBERTInspired()

sentence_model = SentenceTransformer(
    model_name_or_path=f"{root_dir}/models/paraphrase-multilingual-mpnet-base-v2"
) 

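# Load the zero-shot topic labels (semicolon-separated CSV, first column as index)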
zeroshot_topic_list = pd.read_csv(
    f"{root_dir}/zeroshot_topic_list.csv",
    sep=";",
    index_col=0,
)

df = pd.read_csv(
    f"{root_dir}/df.csv",
    sep=";",
    index_col=0,
)

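# Keep only the label part of names like "1_code" by stripping the leading "<number>_" prefix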
zeroshot_topic_list = list(
    zeroshot_topic_list["Name"].str.replace(r"^\d+_", "", regex=True)
)

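# Precompute the document embeddings once so they can be passed to fit_transform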
embeddings = sentence_model.encode(list(df["text"]), show_progress_bar=False)
umap_model = UMAP(
    n_neighbors=15,
    n_components=5,
    min_dist=0.0,
    metric="cosine",
    random_state=42,
)
topic_model = BERTopic(
    language="multilingual",
    umap_model=umap_model,
    embedding_model=sentence_model,
    top_n_words=20,
    n_gram_range=(1, 2),
    min_topic_size=50,
    nr_topics=None,  
    vectorizer_model=None,
    representation_model=representation_model,
    zeroshot_topic_list=zeroshot_topic_list,
    zeroshot_min_similarity=0.7,
)
topics, probs = topic_model.fit_transform(list(df["text"]), embeddings)
print(topic_model.get_topic_info())

BERTopic Version

0.16.3

MaartenGr commented 1 month ago

@ianrandman It seems that zero-shot topic modeling doesn't work properly in its current state. The topic representations for the zero-shot topics do not match their respective labels if I run the following:

from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from bertopic import BERTopic
from hdbscan import HDBSCAN
from umap import UMAP

# Extract abstracts to train on and corresponding titles
dataset = load_dataset("CShorten/ML-ArXiv-Papers")["train"]
abstracts = dataset["abstract"][:10_000]

# Zero-shot
zeroshot_topic_list = ["differential privacy", "clustering", "topic modeling", "anomaly detection"]

# Pre-calculate embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(abstracts, show_progress_bar=True)

# Use sub-models
umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0, random_state=42)
hdbscan_model = HDBSCAN(min_samples=5, gen_min_span_tree=True, prediction_data=True)

# Pass the above models to be used in BERTopic
topic_model = BERTopic(
    embedding_model=embedding_model,
    umap_model=umap_model, 
    hdbscan_model=hdbscan_model, 
    verbose=True,
    zeroshot_topic_list=zeroshot_topic_list,
    zeroshot_min_similarity=.5
)
topic_model = topic_model.fit(abstracts, embeddings)

[screenshot: get_topic_info() output for the zero-shot topics]

When I run topic_model.get_topic(0) for the "differential privacy" topic, I get unrelated representations:

[('matrix', 0.024265323538848613),
 ('rank', 0.022604252685197664),
 ('pca', 0.01504805050826997),
 ('completion', 0.01464524649146218),
 ('low', 0.014005813331418024),
 ('norm', 0.010319682346157882),
 ('entries', 0.008740108085038154),
 ('matrices', 0.008468837267613435),
 ('principal', 0.007673173499110779),
 ('nuclear', 0.0068647403068057425)]

Any idea what could be the issue?

ianrandman commented 1 month ago

I believe I understand the issue you are both describing. However, I am struggling to reproduce it. When running your (@MaartenGr) code, topic_model.get_topic_info() returns

[screenshot: topic_model.get_topic_info() output]

And topic_model.get_topic(0) returns

[screenshot: topic_model.get_topic(0) output]

Checking the other topics also shows the names and representations in alignment with what topic_model.get_topic_info() reports. I also ran it a handful of times without the issue appearing.

I have made sure I am on 0.16.3.

PhPv commented 1 month ago

@ianrandman So you did reproduce the problem :) Look at row #3 of your screenshot: Name: clustering; Representation: matrix, rank, ... And at row #5: Name: anomaly detection; Representation: clustering, clusters, ... I would give row #5 the name shown on row #3. Can you run my example? It uses a simple dataset with the same words, so the problem is more clearly visible.

MaartenGr commented 1 month ago

@ianrandman You indeed reproduced at least part of the issue. The wrong representation is used for the zero-shot topics. Words like "gradient" and "stochastic" are unlikely for the zero-shot topic "differential privacy"; I would expect words like "privacy" to be in there.

@PhPv I don't think this is an issue with the order so much as the entire representation being used for the wrong topic. For instance, in the example you give, the representations of clustering and anomaly detection are not simply switched, since words like "matrix" and "rank" are not typical for an "anomaly detection" topic either.

As such, I believe that the wrong representations are taken for the zero-shot topic labels and that the cause is not necessarily an inverted sort.

PhPv commented 1 month ago

@MaartenGr I think the representations are correct, but the topic names from the zero-shot list are not assigned to them. Look at the per-document info (topic_doc_info), where each document is assigned a topic name and a representation: the representation is correct there, but the name no longer matches, because the names are simply applied in the order in which zeroshot_topic_list was passed.

ianrandman commented 1 month ago

I believe I found the issue.

In .fit_transform, the topics are remapped on this line: https://github.com/MaartenGr/BERTopic/blob/2353f4c21d74e33e34e30dbae938304bff094792/bertopic/_bertopic.py#L474

In doing so, self.topic_mapper_ is updated, but those updated topic IDs do not make their way into self._topic_id_to_zeroshot_topic_idx. I think this issue did not come up during my testing because the zero-shot topics were always the most frequent, so the order either did not change or did not change significantly after sorting by frequency, unlike in @PhPv's example.

During topic_model.get_topic_info(), topic_labels_ is queried. The topic IDs for the zero-shot topics must be mapped using self.topic_mapper_, so https://github.com/MaartenGr/BERTopic/blob/2353f4c21d74e33e34e30dbae938304bff094792/bertopic/_bertopic.py#L313-L318 becomes:

if self._is_zeroshot():
    # Need to correct labels from zero-shot topics
    topic_id_to_zeroshot_label = {
        self.topic_mapper_.get_mappings()[topic_id]: self.zeroshot_topic_list[zeroshot_topic_idx]
        for topic_id, zeroshot_topic_idx in self._topic_id_to_zeroshot_topic_idx.items()
    }

Everywhere else that self._topic_id_to_zeroshot_topic_idx is referenced (i.e., in _reduce_to_n_topics) must receive the same fix.

However, we already change self._topic_id_to_zeroshot_topic_idx here: https://github.com/MaartenGr/BERTopic/blob/2353f4c21d74e33e34e30dbae938304bff094792/bertopic/_bertopic.py#L4446. So really, either self._topic_id_to_zeroshot_topic_idx should never change (and always be used in conjunction with self.topic_mapper_.get_mappings() when its values are referenced), or self._topic_id_to_zeroshot_topic_idx should be updated anytime self.topic_mapper_.add_mappings is used.

I think that, to stay true to the purpose of the TopicMapper, the former is the better option. This would mean changing its usages in _reduce_to_n_topics along with the one in .topic_labels_.

@PhPv, you can try changing the source code of your bertopic installation with my suggested fix in topic_labels_, or you can monkey-patch it as below; let me know if the result is as expected:

import pandas as pd
import os
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
from umap import UMAP
from bertopic.representation import KeyBERTInspired

def fixed_topic_labels_(self):
    """Map topic IDs to their labels.
    A label is the topic ID, along with the first four words of the topic representation, joined using '_'.
    Zeroshot topic labels come from self.zeroshot_topic_list rather than the calculated representation.

    Returns:
        topic_labels: a dict mapping a topic ID (int) to its label (str)
    """
    topic_labels = {
        key: f"{key}_" + "_".join([word[0] for word in values[:4]])
        for key, values in self.topic_representations_.items()
    }
    if self._is_zeroshot():
        # Need to correct labels from zero-shot topics
        topic_id_to_zeroshot_label = {
            self.topic_mapper_.get_mappings()[topic_id]: self.zeroshot_topic_list[zeroshot_topic_idx]
            for topic_id, zeroshot_topic_idx in self._topic_id_to_zeroshot_topic_idx.items()
        }
        topic_labels.update(topic_id_to_zeroshot_label)
    return topic_labels
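
# Monkey-patch the property so get_topic_info() picks up the corrected labels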
BERTopic.topic_labels_ = property(fixed_topic_labels_)

root_dir = os.path.abspath(os.path.dirname(__file__))

representation_model = KeyBERTInspired()

sentence_model = SentenceTransformer(
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
)

zeroshot_topic_list = pd.read_csv(
    f"{root_dir}/zeroshot_topic_list.csv",
    sep=";",
    index_col=0,
)

df = pd.read_csv(
    f"{root_dir}/df.csv",
    sep=";",
    index_col=0,
)

zeroshot_topic_list = list(
    zeroshot_topic_list["Name"].str.replace(r"^\d+_", "", regex=True)
)

embeddings = sentence_model.encode(list(df["text"]), show_progress_bar=False)
umap_model = UMAP(
    n_neighbors=15,
    n_components=5,
    min_dist=0.0,
    metric="cosine",
    random_state=42,
)
topic_model = BERTopic(
    language="multilingual",
    umap_model=umap_model,
    embedding_model=sentence_model,
    top_n_words=20,
    n_gram_range=(1, 2),
    min_topic_size=50,
    nr_topics=None,
    vectorizer_model=None,
    representation_model=representation_model,
    zeroshot_topic_list=zeroshot_topic_list,
    zeroshot_min_similarity=0.7,
)
topics, probs = topic_model.fit_transform(list(df["text"]), embeddings)
print(topic_model.get_topic_info())

Thoughts?

MaartenGr commented 1 month ago

@ianrandman Thanks for responding so quickly and diving into this! I agree: based on the necessary changes, the former option you mentioned seems the most straightforward, requires minimal changes, and is in line with the current state of the TopicMapper (although that might change in the future, that's still uncertain).

If you have the time, a PR would be appreciated.

PhPv commented 1 month ago

Yeah, that works! Thanks so much for the fix and for the entire BERTopic library as a whole.

ianrandman commented 1 month ago

@MaartenGr I am not sure I can get around to making a PR this week. After thinking about it some more, I am less certain that the proper way is to use self.topic_mapper_.get_mappings() whenever we access self._topic_id_to_zeroshot_topic_idx. I am not quite familiar with why the history of mappings is stored in the first place with the TopicMapper.

Don't other instance variables maintain the current state of the topics rather than relying on the TopicMapper (such as topics_, topic_labels_, topic_sizes_, topic_representations_, etc.)? It feels wrong to need to use self.topic_mapper_ whenever we read values from self._topic_id_to_zeroshot_topic_idx. Maybe you can help explain the overall purpose of the TopicMapper a bit better, and when and why it is used? I am thinking of the use case where, after fitting, we call the topic reduction function; won't those other instance variables change accordingly?

MaartenGr commented 1 month ago

@ianrandman

I am not sure I can get around to making a PR this week.

No problem, I can do it if the change is small. Just to be sure: all that needs to change is adding self.topic_mapper_.get_mappings()[topic_id] wherever we loop over topics from self._topic_id_to_zeroshot_topic_idx, right?

After thinking about it some more, I am less certain that the proper way is to use self.topic_mapper_.get_mappings() whenever we access self._topic_id_to_zeroshot_topic_idx. I am not quite familiar with why the history of mappings is stored in the first place with the TopicMapper.

Don't other instance variables maintain the current state of the topics rather than relying on the TopicMapper (such as topics_, topic_labels_, topic_sizes_, topic_representations_, etc.)? It feels wrong to need to use self.topic_mapper_ whenever we read values from self._topic_id_to_zeroshot_topic_idx. Maybe you can help explain the overall purpose of the TopicMapper a bit better, and when and why it is used?

The TopicMapper is needed to keep track of the mappings between (for instance) unsorted and sorted topics. Oftentimes, only the topics themselves (the list of integers) are re-sorted for specific purposes, and not all other information is immediately available at that time. Therefore, it is necessary to keep track of the mappings between different stages of sorting (which can happen multiple times). During .partial_fit, for instance, topics might need to be re-sorted multiple times to keep track of the ever-increasing number of topics while still being usable during inference (so the mapping between the original prediction and the sorted order).

So it's used for finding the mapping between the original topics (the output of the cluster model, without sorting) and the current state of the topics (after sorting, merging, updating, etc.). It used to also be used to find the mapping between the current state of the topics and the previous state, but I don't think that's being used anymore.

But mostly, the TopicMapper allows for tracking topics in different states, which (hopefully, in a future version) will allow for additional representations across different levels of topics.
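To make that concrete, here is a minimal, hypothetical sketch of the idea (an illustration only, not BERTopic's actual TopicMapper implementation): each re-sort appends a new {old: new} step, and get_mappings() composes all steps so an original topic ID can always be translated to its current ID.

class MiniTopicMapper:
    def __init__(self, topics):
        # Start with the identity mapping over the original topic IDs
        self.steps = [{t: t for t in set(topics)}]

    def add_mappings(self, mapping):
        # Record one re-sorting step as {previous_id: new_id}
        self.steps.append(mapping)

    def get_mappings(self):
        # Compose all steps: original topic ID -> current topic ID
        current = {t: t for t in self.steps[0]}
        for step in self.steps[1:]:
            current = {orig: step.get(cur, cur) for orig, cur in current.items()}
        return current

# Example: the cluster model emits topics 0/1/2; sorting by frequency renames them
mapper = MiniTopicMapper([0, 1, 2])
mapper.add_mappings({0: 2, 1: 1, 2: 0})  # topic 0 was smallest, topic 2 largest
print(mapper.get_mappings())  # {0: 2, 1: 1, 2: 0}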

I am thinking of the use case where, after fitting, we call the topic reduction function; won't those other instance variables change accordingly?

They should, so I'm also a bit surprised that this isn't the case here and that an additional mapping is needed. Mostly, I would guess that self._topic_id_to_zeroshot_topic_idx isn't updated whenever the topics are re-sorted. I agree that ideally everything related to the topics should be updated whenever they are re-sorted, but it seems that wasn't the case here.

Thinking about it, it might be as simple as updating self._topic_id_to_zeroshot_topic_idx whenever we run ._sort_mappings_by_frequency, according to the newly created sorted topics.
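As a rough sketch of that update (remap_zeroshot_topics is a hypothetical helper, not BERTopic API, assuming the re-sort yields an {old_id: new_id} mapping):

def remap_zeroshot_topics(topic_id_to_zeroshot_topic_idx, old_to_new):
    # old_to_new is the {old_topic_id: new_topic_id} mapping produced by a re-sort;
    # re-keying the zero-shot lookup keeps it aligned with the current topic IDs.
    return {
        old_to_new[topic_id]: zeroshot_idx
        for topic_id, zeroshot_idx in topic_id_to_zeroshot_topic_idx.items()
    }

# Example: topics 0 and 2 swap positions after sorting by frequency
print(remap_zeroshot_topics({0: 1, 2: 0}, {0: 2, 1: 1, 2: 0}))  # {2: 1, 0: 0}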

ianrandman commented 1 month ago

Thinking about it, it might be as simple as updating self._topic_id_to_zeroshot_topic_idx whenever we run ._sort_mappings_by_frequency, according to the newly created sorted topics.

I believe this would be a correct solution, and I like it better. I think it is best if, whenever self._topic_id_to_zeroshot_topic_idx is used, one can assume it is aligned with the current state of the topics.

@PhPv's example is good for verifying correctness in the fitting case. It would also be good to add a unit test (or at the very least, test manually) for topic reduction in this scenario.
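Such a test could look roughly like this, assuming a fitted_topic_model fixture trained on data like df.csv above, where the zero-shot "holidays" topic should surface holiday-related words (the fixture name and expected word are illustrative):

def test_zeroshot_names_match_representations(fitted_topic_model):
    # After fitting (and optionally reduce_topics), the topic named after a
    # zero-shot label should carry that label's own representation.
    info = fitted_topic_model.get_topic_info()
    holidays_row = info[info["Name"] == "holidays"].iloc[0]
    top_words = [word for word, _ in fitted_topic_model.get_topic(holidays_row["Topic"])]
    assert any("holiday" in word for word in top_words)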

MaartenGr commented 1 month ago

@ianrandman It took a while to get a working PR out, since a number of issues needed to be resolved after more thorough exploration. Having said that, I think #2105 should solve the issues seen here.

@PhPv Could you test whether that PR works for you? I believe it should, but many of the tests had to be done manually to check whether the output was correct.