MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License

Reduce outliers embeddings strategy throws errors when setting Custom labels #1490

Open manmax31 opened 1 year ago

manmax31 commented 1 year ago

I have been following your tutorial on how to use Llama to get better topic names.

The only difference between your setup and mine is that I am using Alibaba's Qwen-7B model, which in my experience beats any other 7B or 13B model. I set the labels after reducing outliers with the embeddings strategy.

The issue: if I reduce outliers with the embeddings strategy, the -1 topic disappears, and I then get the error "Make sure that topic_labels contains the same number of labels as that there are topics."

If I use the c-TF-IDF or distributions strategy to reduce outliers, there is no issue.
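The following library-free sketch illustrates why that error can appear. It assumes, based on the error text, that a list of custom labels must have exactly one entry per unique topic ID in the model, including -1 while outliers still exist; check_topic_labels is a hypothetical stand-in, not BERTopic's own code.

```python
def check_topic_labels(topics, topic_labels):
    """Hypothetical stand-in for the list-length check behind the error.

    topics: the topic ID assigned to each document (-1 = outlier).
    topic_labels: one custom label per unique topic ID.
    """
    n_unique = len(set(topics))
    if len(topic_labels) != n_unique:
        raise ValueError(f"Got {len(topic_labels)} labels for {n_unique} topics.")


# Labels generated while outliers still existed: one each for -1, 0 and 1.
labels = ["Outliers", "Eating meat", "Renewable energy"]

topics_before = [-1, 0, 1, 0, -1]
check_topic_labels(topics_before, labels)  # passes: 3 labels, 3 unique topics

# In this toy example every outlier has been reassigned to a real topic,
# so -1 no longer appears and the old label list is one entry too long.
topics_after = [1, 0, 1, 0, 0]
try:
    check_topic_labels(topics_after, labels)
except ValueError as err:
    print(err)  # Got 3 labels for 2 topics.
```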

Would you have any suggestions?

Here is the code:

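from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, TextGeneration
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from umap import UMAP
import transformers

# docs: list of document strings, loaded elsewhere

## Embedding model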
embedding_model = SentenceTransformer(
    "BAAI/bge-large-en"
)

embeddings = embedding_model.encode(
    docs, normalize_embeddings=True, device="cuda:0", show_progress_bar=True
)

## Representation Model
# MMR
mmr = MaximalMarginalRelevance(diversity=0.7)

# KeyBert inspired
kbi = KeyBERTInspired()

# Generative model
model_id = "Qwen/Qwen-7B"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
generation_config = transformers.GenerationConfig.from_pretrained(model_id)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="cuda:2",
)
model.eval()
model.tie_weights()

generator = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.01,
    max_new_tokens=50,
    repetition_penalty=1.15,
    top_p=0.95,
    generation_config=generation_config,
)

# Prompt
system_prompt = """
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant for labeling topics.
<</SYS>>
"""

# Example prompt demonstrating the output we are looking for
example_prompt = """
I have a topic that contains the following documents:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the worst food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.

The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.

Based on the information about the topic above, please create a short label of this topic. Make sure to only return the label and nothing more.
[/INST] Environmental impacts of eating meat
"""

main_prompt = """
[INST]
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: '[KEYWORDS]'.
Based on the information about the topic above, please create a short label of this topic. Make sure to only return the label and nothing more.
[/INST]
"""
prompt = system_prompt + example_prompt + main_prompt

# Text generation with Qwen
qwen = TextGeneration(generator, prompt=prompt)

# All representation models
representation_model = {
    "KeyBERT": kbi,
    "Qwen": qwen,
    "MMR": mmr,
}

# UMAP
umap_model = UMAP(
    n_neighbors=15, n_components=5, min_dist=0.0, metric="cosine", random_state=50
)

# HDBSCAN
hdbscan_model = HDBSCAN(
    core_dist_n_jobs=-1,
    min_cluster_size=20,
    metric="euclidean",
    cluster_selection_method="leaf",
    prediction_data=True,
)

## Topic Model
topic_model = BERTopic(
    # Sub-models
    embedding_model=embedding_model,
    umap_model=umap_model,
    hdbscan_model=hdbscan_model,
    representation_model=representation_model,
    # Hyperparameters
    top_n_words=10,
    verbose=True,
    nr_topics="auto",
)

# Train model
topics, probs = topic_model.fit_transform(docs, embeddings)

# Reduce outliers
new_topics = topic_model.reduce_outliers(
    docs, topics, probabilities=probs, strategy="embeddings"
)

topic_model.update_topics(docs, topics=new_topics)

# Set LLM labels
qwen_labels = [
    label[0][0].split("\n")[0].strip()
    for label in topic_model.get_topics(full=True)["Qwen"].values()
]

topic_model.set_topic_labels(qwen_labels)
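One way to sidestep the length check is to key the generated labels by topic ID and keep only the IDs the model still has. This sketch assumes that get_topics(full=True)["Qwen"] returns a dict mapping each topic ID to a list of (label, weight) pairs, as the label[0][0] indexing above suggests, and that your BERTopic version's set_topic_labels also accepts a {topic_id: label} dict; labels_by_topic is a hypothetical helper.

```python
def labels_by_topic(aspect, current_topic_ids):
    """Map each still-existing topic ID to the first line of its generated label.

    aspect: {topic_id: [(label, weight), ...]}, e.g. the "Qwen" aspect (assumed shape).
    current_topic_ids: topic IDs the model currently has, e.g. topic_model.topic_labels_.
    """
    current = set(current_topic_ids)
    labels = {}
    for topic_id, words in aspect.items():
        if topic_id not in current:
            continue  # stale entry, e.g. -1 left over from before outlier reduction
        labels[topic_id] = words[0][0].split("\n")[0].strip()
    return labels


# Toy data shaped like the assumed aspect output.
qwen_aspect = {
    -1: [("Miscellaneous", 1.0)],
    0: [("Environmental impacts of eating meat\nExtra text", 1.0)],
    1: [(" Renewable energy policy ", 1.0)],
}
current_topics = {0: "0_meat_beef_eat", 1: "1_energy_solar_wind"}
print(labels_by_topic(qwen_aspect, current_topics))
# {0: 'Environmental impacts of eating meat', 1: 'Renewable energy policy'}
```

With a fitted model, the call would presumably look like topic_model.set_topic_labels(labels_by_topic(topic_model.get_topics(full=True)["Qwen"], topic_model.topic_labels_)), provided the dict form is supported in the installed version.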

MaartenGr commented 1 year ago

The issue might come from the -1 class "disappearing" during outlier reduction. I would advise doing the following instead:

# Reduce outliers
new_topics = topic_model.reduce_outliers(
    docs, topics, probabilities=probs, strategy="embeddings"
)
topic_model.update_topics(docs, topics=new_topics)

# Update the attribute that checks whether there are still outliers
topic_model._outliers = 0

# Set LLM labels
qwen_labels = [
    label[0][0].split("\n")[0].strip()
    for label in topic_model.get_topics(full=True)["Qwen"].values()
]

topic_model.set_topic_labels(qwen_labels)

I believe this is a known issue; there is a PR available for it that I still need to review in more depth.

manmax31 commented 1 year ago

Thank you, but it still throws the same error.

MaartenGr commented 1 year ago

Could you check whether qwen_labels indeed contains fewer labels than there are in topic_model.topic_labels_?

manmax31 commented 1 year ago

It is the other way around: qwen_labels has 1 more label than topic_model.topic_labels_
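To find which topic the extra label belongs to, the topic ID sets can be compared directly instead of only their lengths. A minimal sketch, assuming both objects are dicts keyed by topic ID (topic_labels_ is such a mapping; the shape of the "Qwen" aspect is assumed as above); compare_topic_ids is a hypothetical helper.

```python
def compare_topic_ids(aspect, topic_labels):
    """Report topic IDs that appear in only one of the two mappings."""
    aspect_ids, model_ids = set(aspect), set(topic_labels)
    return {
        "only_in_aspect": sorted(aspect_ids - model_ids),
        "only_in_model": sorted(model_ids - aspect_ids),
    }


# Toy data: the aspect still carries the -1 entry from before outlier reduction.
qwen_aspect = {-1: [("Miscellaneous", 1.0)], 0: [("Eating meat", 1.0)], 1: [("Energy", 1.0)]}
topic_labels = {0: "0_meat_beef_eat", 1: "1_energy_solar_wind"}
print(compare_topic_ids(qwen_aspect, topic_labels))
# {'only_in_aspect': [-1], 'only_in_model': []}
```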

MaartenGr commented 1 year ago

In that case, I would advise checking whether the order of qwen_labels matches that of topic_model.topic_labels_ and topic_model.custom_labels_. I expect qwen_labels to contain one label too many, which should be removed. I think the extra label might belong to the outlier class, but you will have to check.
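Because a plain list is matched to topics by position, a single extra leading label shifts every label onto the wrong topic. This sketch pairs each (topic_id, current_label) entry of topic_labels_, assumed to be in the same order in which the list is applied, with the new label at the same position, which makes such an off-by-one visible; side_by_side is a hypothetical helper.

```python
from itertools import zip_longest


def side_by_side(new_labels, topic_labels):
    """Pair (topic_id, current_label) with the new label at the same position.

    Entries missing on either side show up as None.
    """
    return list(zip_longest(topic_labels.items(), new_labels))


topic_labels = {0: "0_meat_beef_eat", 1: "1_energy_solar_wind"}
qwen_labels = ["Miscellaneous", "Eating meat", "Renewable energy"]
for row in side_by_side(qwen_labels, topic_labels):
    print(row)
# ((0, '0_meat_beef_eat'), 'Miscellaneous')
# ((1, '1_energy_solar_wind'), 'Eating meat')
# (None, 'Renewable energy')
```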

Keamww2021 commented 1 year ago

Hello,

I faced the same problem.

How can I remove the outlier class from qwen_labels?

MaartenGr commented 1 year ago

@Keamww2021 Simply remove the outlier label, which should be the first one in the list, and I believe it should work. Do note, though, that it is difficult to say without seeing your exact code, versions, environment, etc.
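Rather than dropping the first element unconditionally, it may be safer to drop the label whose topic ID is -1, so that nothing is removed by mistake if the outlier label is not actually first. A minimal sketch, assuming the labels were built from the aspect dict in key order, as in the list comprehension used above; drop_outlier_label is a hypothetical helper.

```python
def drop_outlier_label(topic_ids, labels):
    """Remove the label that belongs to topic -1, keeping the rest in order."""
    topic_ids = list(topic_ids)
    if len(topic_ids) != len(labels):
        raise ValueError("topic_ids and labels must have the same length")
    return [label for topic_id, label in zip(topic_ids, labels) if topic_id != -1]


# Toy data shaped like the assumed "Qwen" aspect.
qwen_aspect = {-1: [("Miscellaneous", 1.0)], 0: [("Eating meat", 1.0)], 1: [("Energy", 1.0)]}
qwen_labels = [words[0][0] for words in qwen_aspect.values()]
print(drop_outlier_label(qwen_aspect.keys(), qwen_labels))
# ['Eating meat', 'Energy']
```

The filtered list could then be passed to topic_model.set_topic_labels, as in the original code.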