webis-de / small-text

Active Learning for Text Classification in Python
https://small-text.readthedocs.io/
MIT License

EmbeddingKMeans sample() got an unexpected keyword argument 'embeddings_proba' #5

HannahKirk closed this issue 3 years ago

HannahKirk commented 3 years ago

Hi,

I'm trying to run the 01-active-learning-for-text-classification-with-small-text-intro.ipynb notebook with EmbeddingKMeans. I set the query strategy to EmbeddingKMeans and initialised the active learner:

transformer_model = TransformerModelArguments(transformer_model_name)
clf_factory = TransformerBasedClassificationFactory(transformer_model,
                                                    num_classes,
                                                    kwargs={'device': 'cuda',
                                                            'mini_batch_size': 32,
                                                            'early_stopping_no_improvement': -1})
query_strategy = EmbeddingKMeans()

active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, train)
labeled_indices = initialize_active_learner(active_learner, train.y)

According to strategies.py, embeddings can be passed as None to the query() method of class EmbeddingBasedQueryStrategy(QueryStrategy).

However, I am getting the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-9f2deb8154a9> in <module>()
     23 for i in range(num_queries):
     24     # ...where each iteration consists of labelling 20 samples
---> 25     q_indices = active_learner.query(num_samples=20)
     26 
     27     # Simulate user interaction here. Replace this for real-world usage.

2 frames
/usr/local/lib/python3.7/dist-packages/small_text/active_learner.py in query(self, num_samples, x, query_strategy_kwargs)
    186                                                          self.y,
    187                                                          n=num_samples,
--> 188                                                          **query_strategy_kwargs)
    189         return self.queried_indices
    190 

/usr/local/lib/python3.7/dist-packages/small_text/query_strategies/strategies.py in query(self, clf, x, x_indices_unlabeled, x_indices_labeled, y, n, pbar, embeddings, embed_kwargs)
    249                                                   n, embeddings)
    250                 else:
--> 251                     raise e
    252 
    253         return x_indices_unlabeled[sampled_indices]

/usr/local/lib/python3.7/dist-packages/small_text/query_strategies/strategies.py in query(self, clf, x, x_indices_unlabeled, x_indices_labeled, y, n, pbar, embeddings, embed_kwargs)
    241                     if embeddings is None else embeddings
    242                 sampled_indices = self.sample(clf, x, x_indices_unlabeled, x_indices_labeled,
--> 243                                               y, n, embeddings, embeddings_proba=proba)
    244             except TypeError as e:
    245                 if 'got an unexpected keyword argument \'return_proba\'' in e.args[0]:

TypeError: sample() got an unexpected keyword argument 'embeddings_proba'

Could you advise a solution?

chschroeder commented 3 years ago

Hi,

thank you for reporting this. This was a bug, caused by a recent change to the sample() function, which the tests unfortunately missed.
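
To illustrate the kind of mismatch involved, here is a minimal sketch. The classes and the try_query helper are hypothetical stand-ins, not the actual small-text source: the base class's query() passes a new embeddings_proba keyword, while one override of sample() still has the old signature.

```python
# Hypothetical stand-ins for illustration; NOT the actual small-text classes.
class BaseEmbeddingStrategy:
    def query(self, n, embeddings=None):
        # Fall back to computing embeddings when none are supplied.
        if embeddings is None:
            embeddings, proba = self.embed()
        else:
            proba = None
        # The caller was updated to pass a new keyword argument ...
        return self.sample(n, embeddings, embeddings_proba=proba)

    def embed(self):
        # Dummy embeddings and probabilities for three instances.
        return [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]], [0.9, 0.8, 0.5]


class OutdatedStrategy(BaseEmbeddingStrategy):
    # ... but this override still has the old signature.
    def sample(self, n, embeddings):
        return list(range(n))


class UpdatedStrategy(BaseEmbeddingStrategy):
    # Accepting the new keyword keeps caller and override in sync.
    def sample(self, n, embeddings, embeddings_proba=None):
        return list(range(n))


def try_query(strategy, n=2):
    """Return the sampled indices, or the TypeError message if the call fails."""
    try:
        return strategy.query(n)
    except TypeError as e:
        return str(e)
```

Calling try_query(OutdatedStrategy()) returns the "unexpected keyword argument 'embeddings_proba'" message, while try_query(UpdatedStrategy()) returns the sampled indices. A test that calls query() end to end on each concrete strategy, rather than testing sample() in isolation, would surface this kind of error.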

I just fixed this and released a new version. After reinstalling small-text, your example should work.
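
After reinstalling (for example with pip's --upgrade option), it can help to confirm which version the notebook runtime actually picks up. The sketch below only reports the installed version; the helper name installed_version is hypothetical, and the number of the fixed release is not stated in this thread, so nothing is compared against it.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# Prints the version seen by the current interpreter, or None if not installed.
print(installed_version("small-text"))
```

In a notebook, the runtime may need to be restarted after reinstalling before the new version is imported.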