MaartenGr / BERTopic

Leveraging BERT and c-TF-IDF to create easily interpretable topics.
https://maartengr.github.io/BERTopic/
MIT License
6.2k stars, 765 forks

Running time is too long when using Zero-Shot Classification #1919

Closed: syGOAT closed this issue 7 months ago

syGOAT commented 7 months ago

For testing, I set n_clusters=5 in KMeans and passed in only 20 documents.

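from bertopic import BERTopic
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP

# abstracts: a list of document strings (loaded elsewhere)
umap_model = UMAP(n_neighbors=20, n_components=15, min_dist=0.0, metric='cosine', random_state=42)
# KMeans is passed in place of HDBSCAN to force exactly 5 topics
cluster_model = KMeans(n_clusters=5, random_state=42)
vectorizer_model = CountVectorizer(stop_words="english")

model = BERTopic(embedding_model='/root/autodl-tmp/fhy/bertopic_topic/paraphrase-MiniLM-L6-v2',
                 umap_model=umap_model,
                 hdbscan_model=cluster_model,
                 vectorizer_model=vectorizer_model,
                 )
topics, probabilities = model.fit_transform(abstracts[:20], nr_repr_docs=1)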

The code above finished in a few seconds. The code below, however:

from bertopic.representation import ZeroShotClassification

candidate_topics = ['Material Science: General', 'Physical Chemistry', 'Chemistry: General', ......]
# about 70 elements

representation_model = ZeroShotClassification(candidate_topics, model="./bart-large-mnli")
model.update_topics(abstracts[:20], representation_model=representation_model)

It ran for more than 20 minutes without finishing, which is far too long. With only 3 candidate topics (ZeroShotClassification(candidate_topics[:3], model="./bart-large-mnli")), the code finished after 4 minutes. So the problem may be the number of candidate topics, but I don't think 70 is a lot. If the transformers pipeline itself is not the problem, could you improve the parallel batch processing?
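
A rough way to see why the number of candidate topics matters: an NLI-based zero-shot pipeline (such as the transformers one backed by bart-large-mnli) scores every input text against every candidate label as a separate premise/hypothesis pair, so the work grows with topics × labels. The helpers below are a hypothetical back-of-the-envelope sketch, not part of BERTopic or transformers, and they ignore sequence length and hardware.

```python
import math


def nli_pair_count(n_topics: int, n_labels: int) -> int:
    """Hypothetical estimate: one premise/hypothesis pair per (topic, label)."""
    return n_topics * n_labels


def forward_batches(n_pairs: int, batch_size: int) -> int:
    """Number of model calls needed if the pairs are grouped into batches."""
    return math.ceil(n_pairs / batch_size)


if __name__ == "__main__":
    # 5 KMeans topics, as in the test above
    for n_labels in (3, 70):
        print(n_labels, "labels ->", nli_pair_count(5, n_labels), "pairs")
    # A larger run with 100 topics and 70 labels
    pairs = nli_pair_count(100, 70)
    print(pairs, "pairs;",
          forward_batches(pairs, 1), "calls at batch_size=1,",
          forward_batches(pairs, 32), "calls at batch_size=32")
```

Going from 3 to 70 labels multiplies the number of pairs by more than 20, and batching reduces the number of separate model calls rather than the number of pairs.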

MaartenGr commented 7 months ago

> It ran for more than 20 minutes without finishing, which is far too long. With only 3 candidate topics (ZeroShotClassification(candidate_topics[:3], model="./bart-large-mnli")), the code finished after 4 minutes.

It is difficult to say whether this is short or long without knowing a bit more about your environment. Are you using a GPU? If not, then that might explain the problem you are facing.

> So the problem may be the number of candidate topics, but I don't think 70 is a lot. If the transformers pipeline itself is not the problem, could you improve the parallel batch processing?

Most likely, the compute time comes from the transformers pipeline, and parallel batch processing is a bit more involved when you are dealing with GPUs. Generally, I would expect this to be reasonably fast on a GPU.

syGOAT commented 7 months ago

@MaartenGr Thank you for the reply! I used a GPU. I think this is where the problem lies: https://github.com/MaartenGr/BERTopic/blob/6c9eb6e72a881077ac59c35752d26e391bfe4c49/bertopic/representation/_zeroshot.py#L72C9-L74C104

In my case, topic_descriptions had 100 topics and a lot of words, and it was passed directly to the ZeroShotClassificationPipeline without batching. I think it was the large amount of data processed by the pipeline at once that led to the long running time. Maybe adding a batch_size argument to ZeroShotClassification would help?

MaartenGr commented 7 months ago

Which GPU are you using? The specific card makes quite a bit of difference.

> I think it was the large amount of data processed by the pipeline at once that led to the long running time. Maybe adding a batch_size argument to ZeroShotClassification would help?

That could be the cause, but have you tested whether that is indeed the issue? You can test this by adapting the code you referenced to supply custom batches. BERTopic is designed to be modular, so adapting it should be straightforward.
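
A minimal sketch of supplying custom batches, assuming `classifier` is any callable with the zero-shot call shape `classifier(texts, candidate_labels, **kwargs)`, such as a transformers zero-shot pipeline. The function `classify_in_batches` is hypothetical and not part of BERTopic; `stub_classifier` exists only so the sketch runs without downloading a model.

```python
from typing import Any, Callable, Dict, List, Sequence


def classify_in_batches(
    classifier: Callable[..., Any],
    texts: Sequence[str],
    candidate_labels: Sequence[str],
    batch_size: int = 16,
    **kwargs: Any,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: send texts to `classifier` in chunks of `batch_size`."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    results: List[Dict[str, Any]] = []
    for start in range(0, len(texts), batch_size):
        chunk = list(texts[start:start + batch_size])
        output = classifier(chunk, list(candidate_labels), **kwargs)
        # Some pipelines return a bare dict for a single input; normalize to a list
        results.extend(output if isinstance(output, list) else [output])
    return results


def stub_classifier(texts, candidate_labels, **kwargs):
    """Stand-in for a real pipeline: returns labels in the given order, uniform scores."""
    n = len(candidate_labels)
    return [
        {"sequence": t, "labels": list(candidate_labels), "scores": [1.0 / n] * n}
        for t in texts
    ]


if __name__ == "__main__":
    descriptions = [f"topic {i} keywords" for i in range(10)]
    out = classify_in_batches(stub_classifier, descriptions,
                              ["Physics", "Chemistry"], batch_size=4)
    print(len(out), out[0]["labels"])
```

Timing the real pipeline with a few `batch_size` values in this wrapper would show whether batching is indeed the bottleneck before changing any library code.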

syGOAT commented 7 months ago

@MaartenGr I found the problem. The __call__ method of the Pipeline class in transformers has a batch_size argument: https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/pipelines/base.py#L1157C3-L1170C46

We could consider passing it through pipeline_kwargs: https://github.com/MaartenGr/BERTopic/blob/de7376d3d42960e787a1115e5fe69fb726a7a33d/bertopic/representation/_zeroshot.py#L55C1-L74C104
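
To illustrate why pipeline_kwargs is a natural place for batch_size: a representation model that stores a kwargs dict and unpacks it into the pipeline call forwards any extra option unchanged, which is roughly the pattern in the linked lines. `ZeroShotSketch` is a simplified, hypothetical stand-in, not BERTopic's actual class, and `RecordingPipeline` replaces a real transformers pipeline.

```python
from typing import Any, Callable, Dict, List, Mapping, Optional


class ZeroShotSketch:
    """Hypothetical, simplified stand-in for a zero-shot representation model."""

    def __init__(self, candidate_topics: List[str], pipeline: Callable[..., Any],
                 pipeline_kwargs: Optional[Mapping[str, Any]] = None):
        self.candidate_topics = candidate_topics
        self.pipeline = pipeline
        # Copy so later edits to the caller's dict do not leak in
        self.pipeline_kwargs: Dict[str, Any] = dict(pipeline_kwargs or {})

    def classify(self, topic_descriptions: List[str]) -> Any:
        # Every entry in pipeline_kwargs (e.g. batch_size) reaches the pipeline call
        return self.pipeline(topic_descriptions, self.candidate_topics,
                             **self.pipeline_kwargs)


class RecordingPipeline:
    """Stub that records the keyword arguments it receives."""

    def __init__(self) -> None:
        self.last_kwargs: Dict[str, Any] = {}

    def __call__(self, texts, labels, **kwargs):
        self.last_kwargs = kwargs
        return [{"sequence": t, "labels": list(labels)} for t in texts]


if __name__ == "__main__":
    pipe = RecordingPipeline()
    model = ZeroShotSketch(["Physics", "Chemistry"], pipe,
                           pipeline_kwargs={"batch_size": 32})
    model.classify(["atoms bonds energy"])
    print(pipe.last_kwargs)
```

Because the dict is unpacked as-is, supporting batch_size needs no new constructor argument, only an extra key in pipeline_kwargs.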

MaartenGr commented 7 months ago

Ah, in that case it is already supported, right? Just do something like this:

representation_model = ZeroShotClassification(candidate_topics, model="./bart-large-mnli", pipeline_kwargs={"batch_size": 32})

syGOAT commented 7 months ago

@MaartenGr Yes. The problem has been solved. Thank you so much!