yangarbiter / multilabel-learn

multilabel-learn: Multilabel-Classification Algorithms

Why are the outputs of RethinkNet all close to zero? #1

Status: Open. Tenyn opened this issue 4 years ago.

Tenyn commented 4 years ago

I trained the network on the bibtex dataset, using binary cross-entropy as the loss function.

Thank you.

yangarbiter commented 4 years ago

Probably not "all" of them are close to zero? The bibtex dataset has a large number of labels, so many of the labels may simply be unrelated to a given input's features, and their outputs will therefore be close to zero.
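A quick way to see how sparse the labels are is to compute the label cardinality (average number of positive labels per sample) and label density (fraction of all label entries that are 1) of the training label matrix. For a label that is rarely positive, a sigmoid output that tracks its positive rate will naturally sit near zero. This is a minimal NumPy sketch; `label_stats` is a hypothetical helper and the toy matrix stands in for the real bibtex labels.

```python
import numpy as np

def label_stats(Y):
    """Summarize how sparse a binary label matrix Y (n_samples x n_labels) is.
    Hypothetical helper for illustration."""
    Y = np.asarray(Y, dtype=float)
    cardinality = Y.sum(axis=1).mean()   # average number of positive labels per sample
    density = cardinality / Y.shape[1]   # fraction of all label entries that are 1
    per_label_rate = Y.mean(axis=0)      # positive rate of each individual label
    return cardinality, density, per_label_rate

# Toy matrix (not bibtex): 4 samples, 5 labels, mostly zeros.
Y = np.array([[1, 0, 0, 0, 0],
              [0, 1, 0, 0, 0],
              [1, 0, 0, 0, 0],
              [0, 0, 0, 0, 1]])
card, dens, rates = label_stats(Y)
print(card, dens, rates)
```

Running the same function on the real training labels shows how many labels are positive for only a tiny fraction of samples.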

Tenyn commented 4 years ago

Thanks for your reply.

The outputs for all labels are below 0.1, so I wonder whether I built the wrong model.

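The model is as follows:

```python
from tensorflow.keras.layers import Input, RepeatVector, Dense, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras.regularizers import l2


def RethinkNet(input_shape, n_labels):
    inputs = Input(shape=input_shape[1:])
    # Repeat the feature vector input_shape[0] times to form a sequence.
    x = RepeatVector(input_shape[0])(inputs)

    x = Dense(128, kernel_regularizer=l2(0.0001), activation='relu')(x)

    x = LSTM(128, return_sequences=True,
             recurrent_regularizer=l2(0.0001),
             kernel_regularizer=l2(0.0001),
             recurrent_dropout=0.25,
             activation='sigmoid')(x)

    # One sigmoid output per label at every time step.
    outputs = Dense(n_labels, kernel_regularizer=l2(0.0001), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(loss='binary_crossentropy',
                  optimizer=Nadam(learning_rate=0.001),
                  metrics=['accuracy'])
    return model
```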

yangarbiter commented 4 years ago

Have you trained the model, and has the loss converged?

Tenyn commented 4 years ago

I trained it for 300 epochs, and the loss converged to 0.0746.
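One way to judge a converged loss of 0.0746 is to compare it with the loss of a trivial model that outputs the same probability for every label entry. For plain (unweighted) binary cross-entropy averaged over all entries, the best such constant is the label density itself. This minimal sketch assumes a density of 1.5%, which is an assumed value and not taken from this thread; compute it from your own labels. Note also that the loss Keras reports includes the L2 penalty terms, so the comparison is only rough.

```python
import numpy as np

def constant_baseline_bce(density):
    """Mean binary cross-entropy of predicting probability `density` for every
    label entry, when a fraction `density` of the entries are 1."""
    p = density
    return -(p * np.log(p) + (1 - p) * np.log(1 - p))

# Assumed label density of 1.5%; replace with the value from your training labels.
print(round(float(constant_baseline_bce(0.015)), 4))  # prints 0.0779
```

If the trained model's loss is close to this baseline, it has learned little beyond "predict the base rate", which produces exactly the uniformly small outputs described above.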

yangarbiter commented 4 years ago

Which cost function are you training the model with? If you trained RethinkNet with Hamming loss (without the reweighting), such a result is possible because of the nature of Hamming loss: when only a few labels are positive, predicting 0 for every label already gives a low Hamming loss. One thing to check is whether the model's Hamming loss is small; you could also try other cost functions, such as the F1 score.
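The point about Hamming loss can be made concrete: with one positive label out of 100, predicting all zeros gives a Hamming loss of only 0.01 while recovering no positive labels at all. This is a minimal NumPy sketch; the instance-wise F1 below is one of several F1 variants, and treating a sample with no true and no predicted labels as F1 = 1 is a convention chosen here for illustration.

```python
import numpy as np

def hamming_loss(Y_true, Y_pred):
    """Fraction of label entries that are predicted incorrectly."""
    return np.mean(Y_true != Y_pred)

def instance_f1(Y_true, Y_pred):
    """Average per-sample F1 score; a sample with no true and no predicted
    labels counts as F1 = 1 (a convention, not the only one)."""
    tp = np.sum(Y_true & Y_pred, axis=1)
    denom = np.sum(Y_true, axis=1) + np.sum(Y_pred, axis=1)
    f1 = np.where(denom == 0, 1.0, 2 * tp / np.maximum(denom, 1))
    return f1.mean()

# 2 samples, 100 labels, exactly one positive label per sample.
Y_true = np.zeros((2, 100), dtype=bool)
Y_true[0, 3] = True
Y_true[1, 42] = True
Y_all_zero = np.zeros_like(Y_true)

print(hamming_loss(Y_true, Y_all_zero))  # 0.01: looks excellent
print(instance_f1(Y_true, Y_all_zero))   # 0.0: no positive label recovered
```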

Tenyn commented 4 years ago

The cost function is binary cross-entropy, so I don't understand why none of the outputs are close to 1.

yangarbiter commented 4 years ago

RethinkNet is in fact trained with a weighted binary cross-entropy loss.

You can check the implementation here. https://github.com/yangarbiter/multilabel-learn/blob/master/mlearn/models/rethinknet/rethinkNet.py

If you use plain binary cross-entropy without changing the weights, the model can take a shortcut. Suppose a sample has 1 label set to 1 and the other 99 labels set to 0: the easiest solution for the model to learn is to predict 0 for every label, which already gives 99% accuracy on the labels. This is why weighting the binary cross-entropy is important.
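The 1-positive/99-negative example can be sketched numerically. The weighting below is a generic positive-class reweighting (errors on positive entries scaled by a factor `pos_weight`), chosen here only for illustration; it is not the cost-sensitive reweighting implemented in `rethinkNet.py`, and `weighted_bce` is a hypothetical helper.

```python
import numpy as np

def weighted_bce(Y_true, Y_prob, pos_weight, eps=1e-7):
    """Binary cross-entropy where errors on positive entries are scaled by
    pos_weight (scalar or one value per label). pos_weight = 1 is the plain loss."""
    p = np.clip(Y_prob, eps, 1 - eps)
    loss = -(pos_weight * Y_true * np.log(p) + (1 - Y_true) * np.log(1 - p))
    return loss.mean()

# 1 sample, 100 labels: label 0 is positive, the other 99 are negative.
Y_true = np.zeros((1, 100))
Y_true[0, 0] = 1.0
Y_near_zero = np.full((1, 100), 0.01)  # a model that predicts "almost 0" everywhere

# Thresholding at 0.5 gets 99 of the 100 labels right.
print(np.mean((Y_near_zero > 0.5) == (Y_true == 1)))  # 0.99

plain = weighted_bce(Y_true, Y_near_zero, pos_weight=1.0)
# Illustrative weight: ratio of negative to positive entries (99 / 1).
weighted = weighted_bce(Y_true, Y_near_zero, pos_weight=99.0)
print(plain, weighted)
```

With a positive fraction of 1% and this particular weight, the best constant prediction under the weighted loss moves from 0.01 to 0.5 (it solves 99 × 0.01 × (1 − p) = 0.99 × p), so predicting near-zero everywhere is no longer the easy optimum.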

Tenyn commented 4 years ago

Thanks, I will try the weighted binary cross-entropy loss.