ThilinaRajapakse / simpletransformers

Transformers for Information Retrieval, Text Classification, NER, QA, Language Modelling, Language Generation, T5, Multi-Modal, and Conversational AI
https://simpletransformers.ai/
Apache License 2.0

Prediction always returns 0 #18

Closed x3p8z closed 5 years ago

x3p8z commented 5 years ago

First of all, thank you very much for creating this cool library!

I tried the Yelp example and it works fine. However, when I replace the Yelp data with my own, the evaluation produces only true negatives and false negatives (in other words, only class 0 is ever predicted). In my dataset, label 0 occurs approximately 80% of the time and label 1 about 20%.

I removed the output and cache folders every time, and I also tried varying some parameters, like the batch size.

I am using the most recent simpletransformers version from github.

I can do classification without any problems with the same dataset using custom NNs or classifiers from sklearn.

Am I missing something? I appreciate any pointers that you can give.

ThilinaRajapakse commented 5 years ago

Are the columns in the correct order for train_df? The first column should be the text and the second column should be the labels.
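
For reference, the expected train_df layout is something like this (a minimal sketch; the column names and contents are placeholders, only the order matters for a positional DataFrame):

```python
import pandas as pd

# Text in the first column, integer labels (0/1) in the second column.
train_df = pd.DataFrame([
    ["some example document from class 0", 0],
    ["some example document from class 1", 1],
])
```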

x3p8z commented 5 years ago

Thanks for your quick response!

The columns are in the correct order. I also tried switching them; when I do that, I get an error saying that ints do not have a split() function.

One more note: my texts are longer than the Yelp texts. But I assume they will be truncated according to the max_seq_length parameter, so that should not be a problem, right?

ThilinaRajapakse commented 5 years ago

As you said, the length shouldn't matter, unless all your texts end up looking similar once they are truncated, but that is unlikely.

Then again, if it were an issue with the training/evaluation code, it would be a problem with the Yelp data as well.

Can you try undersampling your class 0 so that the data is distributed evenly across both classes? Something like the pandas snippet below should work. I don't think the issue is with the class imbalance, but it doesn't hurt to rule it out.
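
```python
import pandas as pd

# Rough undersampling of class 0 with pandas; assumes the label column is the
# second column of train_df (adjust if your DataFrame uses named columns).
label_col = train_df.columns[1]
minority = train_df[train_df[label_col] == 1]
majority = train_df[train_df[label_col] == 0].sample(n=len(minority), random_state=42)
balanced_df = pd.concat([minority, majority]).sample(frac=1, random_state=42)  # shuffle
```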

x3p8z commented 5 years ago

I finished my experiments. It turns out that the label distribution does matter: if the labels are distributed evenly (50:50), the classification works as expected.

To be on the safe side, I ran the same code with the original data distribution (approximately 80:20) and it returned all 0s again.

I assume this has to do with the batch size of 8. What is your take on this?

ThilinaRajapakse commented 5 years ago

I've used Transformers on similarly imbalanced datasets but I've never had this issue. All predictions being the same class makes me suspect something odd is going on.

Is your model overfitting? How large is your dataset? What is your lr and num_train_epochs?

x3p8z commented 5 years ago

I agree. It seems rather strange that all predictions belong to the same class.

Here are some more details regarding my setup. I appreciate any input and suggestions!

The dataset consists of 5000 samples.

Regarding the loss during training: I no longer have the exact numbers, but I remember that it fluctuated somewhere between 0.65 and 0.75.

I mostly used the default parameters from your tutorial, with a few changes. You can find the full set below.

```python
params = {
    'output_dir': 'outputs/',
    'cache_dir': 'cache/',

    'fp16': False,
    'fp16_opt_level': 'O1',
    'max_seq_length': 128,
    'train_batch_size': 8,
    'eval_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'num_train_epochs': 1,
    'weight_decay': 0,
    'learning_rate': 4e-5,
    'adam_epsilon': 1e-8,
    'warmup_ratio': 0.06,
    'warmup_steps': 0,
    'max_grad_norm': 1.0,

    'logging_steps': 50,
    'evaluate_during_training': False,
    'save_steps': 2000,
    'eval_all_checkpoints': True,
    'use_tensorboard': True,

    'overwrite_output_dir': True,
    'reprocess_input_data': True,
}
```

ThilinaRajapakse commented 5 years ago

Try a train_batch_size of 4 and a learning rate of 1e-5 (see the sketch below). I'm trying to see whether the model is interpreting class 1 as noise in the data.
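
For example, something like this (a sketch only; depending on your simpletransformers version the class is ClassificationModel or the older TransformerModel, the model type/name are placeholders, and train_df/eval_df are your own DataFrames):

```python
from simpletransformers.classification import ClassificationModel

# Apply the two suggested changes to the params dict from above.
params.update({
    'train_batch_size': 4,
    'learning_rate': 1e-5,
})

# Model type and checkpoint name are placeholders; use whatever you trained with before.
model = ClassificationModel('roberta', 'roberta-base', args=params)
model.train_model(train_df)
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
```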

Also, how was the performance of your model when trained with 50/50 distribution?

x3p8z commented 5 years ago

Alright, thanks for your input. I will train the model with the proposed parameters and give you an update later.

Regarding your question: the performance with the 50/50 distribution (for training) was bad. I got an accuracy of only 0.57 (the eval dataset has about 1,000 samples with an 80:20 distribution) and an F1 score of 0.45.

ThilinaRajapakse commented 5 years ago

Do you remember what the scores were like when you tried sklearn and the other classifiers?

x3p8z commented 5 years ago

The results with sklearn were around 0.80 accuracy and a 0.65 F1 score.

Also, the new results with simpletransformers are in. Changing the batch size and learning rate led to an improvement: accuracy is now up to 0.73 and the F1 score to 0.54. However, I still undersampled class 0.

I will try the same settings now with the original data and I will also try more training epochs later.

Please let me know in case you have any further suggestions!

ThilinaRajapakse commented 5 years ago

I guess it's now safe to say that the issue is caused by something in the data rather than by an implementation error. So I think the next thing to try would be to increase the sequence length and see if it helps (see the small tweak below). I have a feeling it might.
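
For example (a small sketch, assuming the same params dict as before; 512 is the upper limit for BERT-style models and memory use grows quickly with longer sequences):

```python
params['max_seq_length'] = 256  # or up to 512, memory permitting
```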

AliOsm commented 5 years ago

I faced the same problem when I worked on a propaganda detection task: https://propaganda.qcri.org/nlp4if-shared-task/index.html. The dataset wasn't balanced, so I changed the prediction threshold from 0.5 to 0.25, which gave me much better results. Try doing the same thing. Note that I was training an RNN model on top of BERT features and I didn't use simpletransformers.
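
With simpletransformers, something along these lines should work (a sketch, assuming a trained ClassificationModel as above; eval_texts is a placeholder for your list of evaluation strings):

```python
from scipy.special import softmax

# predict() returns both the argmax predictions and the raw per-class outputs.
predictions, raw_outputs = model.predict(eval_texts)

probs_class_1 = softmax(raw_outputs, axis=1)[:, 1]  # probability of class 1
threshold = 0.25  # below the default 0.5, to favour the minority class
adjusted_predictions = (probs_class_1 >= threshold).astype(int)
```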

x3p8z commented 5 years ago

Hello everyone,

thanks for your comments, @ThilinaRajapakse and @AliOsm. I will try them out in the coming days.

@ThilinaRajapakse, I agree that the issue is very likely not with the library, but rather with the data. Thanks for your efforts supporting me!

ThilinaRajapakse commented 5 years ago

You're welcome! Feel free to bounce any ideas off me anytime.