webis-de / small-text

Active Learning for Text Classification in Python
https://small-text.readthedocs.io/
MIT License

fit() got an unexpected keyword argument 'validation_set' #7

Closed: HannahKirk closed this issue 2 years ago

HannahKirk commented 2 years ago

Hi,

I'm initializing an active learner for a scikit-learn model with specific validation indices. A minimal code example:

import numpy as np

def initialize_learner(learner, train, test_sets, init_n):
    print('\n----Initialising----\n')
    iter_results_dict = {}
    iter_preds_dict = {}
    # Initialize the model - this is required for model-based query strategies.
    indices_neg_label = np.where(train.y == 0)[0]
    indices_pos_label = np.where(train.y == 1)[0]
    # Only the init_n == 4 case is shown; x_indices_initial is undefined otherwise.
    if init_n == 4:
        # Sample a class-balanced initial set: init_n // 2 positive and init_n // 2 negative examples.
        x_indices_initial = np.concatenate([
            np.random.choice(indices_pos_label, init_n // 2, replace=False),
            np.random.choice(indices_neg_label, init_n // 2, replace=False)])
        x_indices_initial = x_indices_initial.astype(int)
        y_initial = train.y[x_indices_initial]
        val_indices = x_indices_initial[1:3]  # use half of the indices for validation
        learner.initialize_data(x_indices_initial, y_initial, x_indices_validation=val_indices)
    # evaluate() is the reporter's own helper and is not shown in the issue.
    iter_results_dict[0], iter_preds_dict[0] = evaluate(learner, train[x_indices_initial], test_sets, x_indices_initial)
    return learner, x_indices_initial, iter_results_dict, iter_preds_dict

The error I am getting is: fit() got an unexpected keyword argument 'validation_set'. Digging into the code, it seems that this shouldn't happen when x_indices_validation is passed (i.e., is not None).

Do you have any suggestions?
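To make the failure mode concrete, here is a heavily simplified sketch of the kind of dispatch the error message points to: a learner that forwards a validation_set keyword to fit() only when validation indices are supplied. SklearnStyleClassifier, TransformerStyleClassifier and retrain are hypothetical stand-ins, not small-text's actual code, and treating the validation indices as positions within the labeled set is an assumption made for illustration.

```python
# Hypothetical, simplified stand-ins for illustration; NOT small-text's real classes.


class SklearnStyleClassifier:
    """fit() accepts only the training data, like a scikit-learn-backed wrapper."""

    def fit(self, train_set):
        self.n_train = len(train_set)
        return self


class TransformerStyleClassifier:
    """fit() also accepts an optional validation set (e.g. for early stopping)."""

    def fit(self, train_set, validation_set=None):
        self.n_train = len(train_set)
        self.n_valid = 0 if validation_set is None else len(validation_set)
        return self


def retrain(clf, labeled, indices_validation=None):
    """Train clf on the labeled data, forwarding validation_set only if indices are given.

    Assumption: indices_validation are positions within `labeled`.
    """
    if indices_validation is None:
        return clf.fit(labeled)
    valid_positions = set(indices_validation)
    train = [x for i, x in enumerate(labeled) if i not in valid_positions]
    valid = [x for i, x in enumerate(labeled) if i in valid_positions]
    # This call fails for classifiers whose fit() has no validation_set parameter.
    return clf.fit(train, validation_set=valid)


labeled = ["doc0", "doc1", "doc2", "doc3"]
retrain(TransformerStyleClassifier(), labeled, indices_validation=[1, 2])  # works
retrain(SklearnStyleClassifier(), labeled)  # works: no validation_set is passed
try:
    retrain(SklearnStyleClassifier(), labeled, indices_validation=[1, 2])
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'validation_set'
```

In this sketch the error is triggered precisely by supplying validation indices to a classifier whose fit() does not declare a validation_set parameter.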

chschroeder commented 2 years ago

Hi,

Which scikit-learn model are you using exactly? Until now, I thought that scikit-learn classifiers either don't use a validation set or create their own from the training data that is provided. Therefore, the current intended use is to omit the validation set when using scikit-learn classifiers.

But maybe this assumption was wrong?
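Following this explanation, one defensive option is to pass validation indices only when the classifier's fit() can actually receive them. The sketch below inspects fit()'s signature with Python's standard inspect module. accepts_validation_set, validation_kwargs, LinearModel and TransformerModel are hypothetical names; only the keyword x_indices_validation is taken from the call in the original report. For context, some scikit-learn estimators, such as SGDClassifier with early_stopping=True, set aside their own fraction of the training data (validation_fraction) inside fit(), which fits the assumption stated above.

```python
import inspect


def accepts_validation_set(clf):
    """Return True if clf.fit() declares a validation_set parameter or **kwargs."""
    params = inspect.signature(clf.fit).parameters
    if "validation_set" in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def validation_kwargs(clf, val_indices):
    """Hypothetical helper: build the optional keyword for initialize_data().

    For classifiers without a validation_set parameter (e.g. scikit-learn-based
    ones), the validation indices are simply omitted.
    """
    if val_indices is not None and accepts_validation_set(clf):
        return {"x_indices_validation": val_indices}
    return {}


class LinearModel:  # hypothetical scikit-learn-style wrapper
    def fit(self, train_set):
        return self


class TransformerModel:  # hypothetical wrapper that supports early stopping
    def fit(self, train_set, validation_set=None):
        return self


print(validation_kwargs(LinearModel(), [1, 2]))       # {}
print(validation_kwargs(TransformerModel(), [1, 2]))  # {'x_indices_validation': [1, 2]}
```

The returned dict could then be unpacked into the call, e.g. learner.initialize_data(x_indices_initial, y_initial, **validation_kwargs(clf, val_indices)), assuming a classifier instance clf is available to inspect.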

HannahKirk commented 2 years ago

I am using the ConfidenceEnhancedSVM.

That makes sense. I guess this validation set is only useful for early stopping in the train scheduler of the transformers model. It might be worthwhile adding this to the documentation, so it's clear that fit() of a non-transformers model doesn't take a validation set. Thank you!
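To illustrate why a validation set matters mainly for early stopping, here is a generic, patience-based early-stopping rule in plain Python. It is not the implementation used by small-text or by any transformers library; the per-epoch validation losses are made-up inputs standing in for evaluating the model on the validation set after each epoch.

```python
def early_stopping_epoch(val_losses, patience=2):
    """Return the index of the epoch whose model would be kept.

    Training stops once the validation loss has not improved for
    `patience` consecutive epochs; later losses are never examined.
    """
    best_epoch, best_loss = 0, float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # stop training early
    return best_epoch


# Illustrative losses only: training stops after epoch 4, keeping epoch 2.
print(early_stopping_epoch([0.9, 0.7, 0.6, 0.65, 0.66, 0.5], patience=2))  # 2
```

Without held-out validation data there is no signal for a rule like this, so a classifier that does not stop early (or that carves out its own validation split) has little use for an externally supplied validation set.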