MilaNLProc / contextualized-topic-models

A Python package to run contextualized topic modeling. CTMs combine contextualized embeddings (e.g., BERT) with topic models to get coherent topics. Published at EACL and ACL 2021 (Bianchi et al.).
MIT License

Could I do inference when only having bert text? #40

Closed: lin-sung closed this issue 3 years ago

lin-sung commented 3 years ago

Description

Could I do inference when only having bert text? For example, I have a trained CombinedTM model, and I would like to know the topic distribution of the input "I like apples."

What I Did

I wonder whether we could achieve the goal with the code shown below.

bert_texts = ["I like apples."]

qt = QuickText("distiluse-base-multilingual-cased",
                text_for_bert=bert_texts,
                text_for_bow=bert_texts)

testing_dataset = qt.load_dataset()

# n_samples: how many times to sample the distribution (see the docs)
ctm.get_thetas(testing_dataset, n_samples=20)

However, QuickText will construct self.bow on the fly, and the vocabulary will be different from the one used for the training data. I think this mismatch will make the model produce wrong predictions. Do I understand correctly? Also, how could I achieve the goal with the current code?

vinid commented 3 years ago

Yes, your interpretation is totally correct.

I believe that if you want to do inference with BERT, the best thing is to use the ZeroShotTM. This model should take care of all of these problems for you (since it uses only BERT to build the representations and at test time it will "ignore" the BoW mismatch).

You can also do this with the CombinedTM model; there are two ways I can think of to do this.

1 - if you know the indexes of the test data in the entire dataset, you can use the torch.utils.data.Subset utility to get a subset of the dataset object for training and a subset of the dataset object for testing. This will allow your models to be consistent in BoW size.
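A rough sketch of this option (the index lists are just placeholders; full_dataset stands for the dataset built over the entire corpus, e.g. the output of qt.load_dataset()):

from torch.utils.data import Subset

full_dataset = qt.load_dataset()       # dataset built over the whole corpus
train_indices = list(range(0, 800))    # placeholder: indices of the training documents
test_indices = list(range(800, 1000))  # placeholder: indices of the held-out documents

training_subset = Subset(full_dataset, train_indices)  # both subsets share the full-corpus BoW vocabulary
testing_subset = Subset(full_dataset, test_indices)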

2 - if you want to build the bow ONLY with training documents and then remove what is missing at test time, you need to take care of the following things:

You can look at the code here to get an idea of what you need to modify.

Essentially, the object you need to create is the following:

training_dataset = CTMDataset(self.bow, self.data_bert, self.idx2token) 

self.data_bert comes from applying the function bert_embeddings_from_list as shown in the link above. So you will need to create the BERT embeddings for your test documents.
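For example, something along these lines should work for the test embeddings (the exact signature may differ slightly across package versions):

from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_list

# embed the test documents with the same sentence-transformers model used for training
test_corpus = ["I like apples."]
test_data_bert = bert_embeddings_from_list(test_corpus, "distiluse-base-multilingual-cased")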

self.idx2token is the same one you used during training. You can extract it from the QuickText training object.

The only thing that requires a bit more work is self.bow. For every test document, you need to create a BoW vector over the same vocabulary you used during training.

A high-level implementation looks like this (note that this might be very slow and have high RAM usage, but it should give you the general idea):

import numpy as np

testing_bow = []
for sentence in test_corpus:
    local_bow = np.zeros(len(vocab))  # one BoW vector with the training vocabulary size
    tokens = sentence.split()  # tokenize the input text
    for token in tokens:
        if token in vocab:  # only tokens seen during training count towards the BoW
            local_bow[vocab.index(token)] += 1
    testing_bow.append(local_bow)  # collect the BoW vectors for all test documents

You then need to turn this testing_bow into a sparse matrix.
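Roughly, the final assembly could look like this (import paths and variable names may differ slightly across versions; test_data_bert and idx2token come from the previous steps):

import numpy as np
from scipy.sparse import csr_matrix
from contextualized_topic_models.datasets.dataset import CTMDataset

testing_bow_sparse = csr_matrix(np.array(testing_bow))  # sparse BoW aligned with the training vocabulary
testing_dataset = CTMDataset(testing_bow_sparse, test_data_bert, idx2token)
thetas = ctm.get_thetas(testing_dataset, n_samples=20)  # topic distributions for the test documents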

Hope this helps :)

lin-sung commented 3 years ago

Thank you. Your explanations help a lot.

I have currently switched to ZeroShotTM to avoid using the BoW data, but I hope that in the future I can still use CombinedTM because it performs a bit better. I think the second way will be more practical because, most of the time, users will not know the test data during the training phase. About the second approach, would it be easier to save the vocabulary of the CountVectorizer and load it when testing?
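Something like this sketch is what I have in mind (train_vectorizer here is just a placeholder for the fitted CountVectorizer used to build the training BoW):

import pickle
from sklearn.feature_extraction.text import CountVectorizer

# At training time: persist the fitted vocabulary.
with open("vocab.pkl", "wb") as f:
    pickle.dump(train_vectorizer.vocabulary_, f)

# At test time: rebuild a vectorizer over the exact same vocabulary, so the
# test BoW columns line up with the training BoW columns.
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
test_vectorizer = CountVectorizer(vocabulary=vocab)
testing_bow = test_vectorizer.transform(["I like apples."])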

vinid commented 3 years ago

Which BERT model are you using? I see the multilingual one in the snippet you posted; maybe using an English-specific model can help (that is, if you are working with English text). In the experiments we often used `bert-base-nli-mean-tokens`, but we support all the sentence-transformers models and through that we inherit support for all HuggingFace models.

For the CountVectorizer thing, yes. Give me a few days, I'll try to prioritize this in the dev pipeline. :)

vinid commented 3 years ago

I have introduced a simple method that should allow you to pass the input you want directly to the QuickText object. You can generate the embeddings before this step and then fill a new QuickText object with your custom data.

It has not been heavily tested yet, but you can find it on the master branch (you'll thus need to pip install from the master repo).

Hope this helps a bit :)