oxford-cs-deepnlp-2017 / practical-2

Oxford Deep NLP 2017 course - Practical 2: Text Classification
https://www.cs.ox.ac.uk/teaching/courses/2016-2017/dl/

my neural net only predicts 'ooo'. #2

Open stevenhao opened 7 years ago

stevenhao commented 7 years ago

I implemented the most basic neural net (following the instructions) and it is not performing very well. I'm using a Bag of Means document embedding built from a Word2Vec model trained on the TED text.

I suspect that I have some sort of bug, as I'm a beginner with PyTorch. If the instructors don't mind, I'd like to share my code: [removed] It is mostly modeled on this tutorial.
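For context, Bag of Means simply averages the word vectors of a document's tokens. Below is a minimal sketch, assuming the gensim 4.x Word2Vec API; the toy corpus and the function name are placeholders for illustration, not the removed code:

```python
import numpy as np
from gensim.models import Word2Vec

# Toy stand-in for the tokenized TED transcripts.
sentences = [["machine", "learning", "is", "fun"],
             ["deep", "nets", "learn", "features"]]
model = Word2Vec(sentences, vector_size=100, min_count=1)

def bag_of_means(tokens, model):
    """Average the Word2Vec vectors of all in-vocabulary tokens."""
    vecs = [model.wv[t] for t in tokens if t in model.wv]
    if not vecs:  # no known words: fall back to a zero vector
        return np.zeros(model.vector_size)
    return np.mean(vecs, axis=0)

doc_vec = bag_of_means(["deep", "learning", "is", "fun"], model)
print(doc_vec.shape)  # (100,)
```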

kwea123 commented 7 years ago

I also experienced this problem. I think it's because the class distribution is very unbalanced, something like:

[0.01629133, 0.01677048, 0.06947772, 0.1825587, 0.01149976, 0.08145664, 0.0785817, 0.53905129]

So if you sample batches naively, it's quite probable that they contain mostly instances of label 'ooo', and then the model won't learn anything else. This is just my personal thought and I don't know if it's correct, but I did the same thing as you and got ~54% on both the training and test sets, which is exactly the proportion of class 'ooo'.

I have since changed my batching so that each batch contains an equal number of each class (e.g. with a batch size of 48, it contains 6 instances of each class). This way I can reach 99% training accuracy, but the test accuracy is still only around 60%. I'm stuck and don't know how to do better: I tried multiple layers, different numbers of units, and dropout, but they don't help much. I also train the word vectors on this task, so I don't think the word representation is the problem.
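For readers who want to try the same fix: here is a minimal sketch of this kind of class-balanced batching. The function and variable names are my own, assuming integer labels 0–7 in a NumPy array, not code from this thread:

```python
import numpy as np

def balanced_batches(labels, batch_size=48, n_classes=8, rng=None):
    """Yield batches of indices with an equal number of examples per class.

    With batch_size=48 and 8 classes, each batch holds 6 indices per class.
    Minority classes are sampled with replacement (i.e. oversampled).
    """
    rng = rng or np.random.default_rng()
    per_class = batch_size // n_classes
    by_class = [np.flatnonzero(labels == c) for c in range(n_classes)]
    while True:  # the caller decides how many batches make an epoch
        batch = np.concatenate(
            [rng.choice(idx, size=per_class, replace=True) for idx in by_class]
        )
        rng.shuffle(batch)
        yield batch

# Usage with a hypothetical label array:
labels = np.random.randint(0, 8, size=1000)
batch = next(balanced_batches(labels))  # 48 indices, 6 from each class
```

PyTorch's built-in `torch.utils.data.WeightedRandomSampler` achieves the same balance in expectation, if you prefer to stay inside a `DataLoader`.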

iassael commented 7 years ago

@stevenhao that's good work, but we would kindly ask you not to share your code.

The dataset is unbalanced, so the baseline will be much higher than chance; however, you could still attempt to achieve higher accuracy. What accuracy are you getting?
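As a sanity check, the baseline that always predicts the majority class can be read straight off the class distribution kwea123 quoted above, and it matches the ~54% figures reported in this thread:

```python
import numpy as np

# Class distribution quoted earlier in the thread; 'ooo' is ~53.9%.
dist = np.array([0.01629133, 0.01677048, 0.06947772, 0.1825587,
                 0.01149976, 0.08145664, 0.0785817, 0.53905129])
# A constant "always predict 'ooo'" classifier scores the majority share.
print(f"majority-class baseline accuracy: {dist.max():.3f}")  # ~0.539
```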

Tommalla commented 7 years ago

I have a similar problem. No matter how many different methods I try, my network always learns to predict 'ooo' regardless of the data.

Is there an important trick we're missing here? Standardizing the data in some way, perhaps?
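One common remedy besides rebalancing batches (a standard technique, not something from the course materials) is to weight the loss inversely to class frequency via the `weight` argument of PyTorch's `CrossEntropyLoss`. A minimal sketch with hypothetical dummy tensors:

```python
import torch
import torch.nn as nn

# Hypothetical: a LongTensor of training labels in [0, 7].
labels = torch.randint(0, 8, (1000,))
counts = torch.bincount(labels, minlength=8).float()
class_weights = counts.sum() / (8 * counts)  # inverse-frequency weights

# Per-class weights make misclassifying a rare class cost more
# than misclassifying the dominant 'ooo' class.
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(16, 8, requires_grad=True)  # dummy model outputs
targets = torch.randint(0, 8, (16,))
loss = criterion(logits, targets)
loss.backward()
```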