weinman / cnn_lstm_ctc_ocr

Tensorflow-based CNN+LSTM trained with CTC-loss for OCR
GNU General Public License v3.0
498 stars 170 forks

Imbalanced classes #8

Closed skalinin closed 6 years ago

skalinin commented 6 years ago

Hi! I am new to TensorFlow and am trying to figure out how to work with your model. I need to feed it my own data, but my dataset is very imbalanced (for example, the class ‘q’ occurs about 100 times in the dataset, while the class ‘a’ may occur more than 10 thousand times). What should I do? How can I use class weights in your code? I think it might look like this: in the function `ctc_loss_layer` in `model.py` we have `rnn_logits`, the output from the RNN. What if I multiply it by class weights before passing it into the CTC loss? Then the CTC loss would give greater weight to rare classes, and that would affect backpropagation. Am I right? Could you please help me?

weinman commented 6 years ago

I don't think what you suggest is the right approach. Two issues come to mind. First, CTC ignores the location of the individual elements in the input image, so I cannot see how you would practically target individual sequence items. Second, if you could, one ex post facto way to adjust the prior in a learned conditional model is to use Bayes' rule:

```
P(class|obs)  = P(obs|class) P(class) / P(obs)
P'(class|obs) = P(obs|class) P'(class) / P(obs)
              = P(class|obs) * P'(class) / P(class)
```

That is, you can adjust your discriminant function (in a non-sequential model) by rescaling with the ratio of training data to test data distributions.
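For a non-sequential classifier, that rescaling can be sketched as follows. This is a minimal NumPy illustration of the Bayes'-rule adjustment described above; the function name, array shapes, and toy priors are my own assumptions, not code from this repo:

```python
import numpy as np

def rescale_posteriors(posterior, train_prior, test_prior):
    """Adjust P(class|obs) for a new class prior via Bayes' rule:
    P'(class|obs) ∝ P(class|obs) * P'(class) / P(class)."""
    adjusted = posterior * (test_prior / train_prior)
    # Renormalize so each row is a proper distribution again.
    return adjusted / adjusted.sum(axis=1, keepdims=True)

# Toy example with two classes; class 1 ('q') was rare in training
# but is expected to be common at test time.
posterior   = np.array([[0.9, 0.1]])   # model's output P(class|obs)
train_prior = np.array([0.99, 0.01])   # 'q' rare in training data
test_prior  = np.array([0.5, 0.5])     # balanced at test time

print(rescale_posteriors(posterior, train_prior, test_prior))
```

Note this only reweights the final discriminant; it does not change what the model learned, which is why it does not carry over cleanly to a sequence model trained with CTC.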

skalinin commented 6 years ago

Thanks for answering! So, if I understand you correctly: we may have 100 ‘q’ in the training set and we know that it is not enough, while the test data may contain 10,000 ‘q’. So let's say we have trained our model on what we have, and at the last step we run the model and just multiply the output probability of ‘q’ by P'(q)/P(q), where P'(q) is 10,000/total_number_of_symbols_in_test and P(q) is 100/total_number_of_symbols_in_train.

So, in that way we slightly increase or decrease the output probability of ‘q’. It sounds tricky to me. Could you please tell me if I understand you correctly? And what can you say about the imbalance in the MJSynth dataset? The classes/symbols in that dataset are not balanced either. Doesn't that affect the results when we test a model trained on MJSynth on other datasets?
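The ratios P'(symbol)/P(symbol) can be estimated directly from symbol counts in the two label sets. A small sketch; the label lists here are toy data, not MJSynth:

```python
from collections import Counter

def symbol_priors(transcriptions):
    """Estimate P(symbol) as relative frequency over all transcriptions."""
    counts = Counter()
    for word in transcriptions:
        counts.update(word)
    total = sum(counts.values())
    return {sym: n / total for sym, n in counts.items()}

# Toy label sets (illustrative only).
train_labels = ["aardvark", "banana", "cat"]
test_labels  = ["quiz", "quay", "aqua"]

p_train = symbol_priors(train_labels)
p_test  = symbol_priors(test_labels)

# Rescaling ratio P'(s)/P(s), for symbols seen in both sets.
# Symbols absent from training (like 'q' here) have no ratio at all,
# which is exactly the problem with a very imbalanced training set.
ratios = {s: p_test[s] / p_train[s] for s in p_test if s in p_train}
print(ratios)
```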

weinman commented 6 years ago

There are no easy solutions. Any sequence model, such as the one in this repo, will perform best on test data with the same statistics (as captured by the model) as the training data.