tkipf / gcn

Implementation of Graph Convolutional Networks in TensorFlow
MIT License

multi-label classification #119

Open Chengmeng94 opened 5 years ago

Chengmeng94 commented 5 years ago

Hello! I am working on a multi-label classification problem. I changed the loss to `loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=preds, labels=labels)` and the accuracy check to `correct_prediction = tf.equal(preds, labels)` in metrics.py, and I use `tf.nn.sigmoid(self.outputs)` to make predictions in models.py. Do you think this is OK, or do you have other suggestions? Thank you!

tkipf commented 5 years ago

Looks good to me at first glance. Make sure that the function tf.nn.sigmoid_cross_entropy_with_logits receives logits and not the outputs of a sigmoid.

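As a minimal sketch of that distinction (assuming TF 1.x; the placeholder shapes and variable names below are illustrative, not taken from the repo):

```python
import tensorflow as tf

# Illustrative shapes: N nodes, 7 possible labels per node.
logits = tf.placeholder(tf.float32, shape=[None, 7])   # raw GCN outputs, no sigmoid applied
labels = tf.placeholder(tf.float32, shape=[None, 7])   # binary multi-label targets (0/1)

# Correct: the loss op applies the sigmoid internally, so it must receive logits.
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

# The sigmoid is only applied afterwards, to turn logits into per-label probabilities.
probs = tf.nn.sigmoid(logits)

# Wrong: passing probs (sigmoid outputs) as `logits` would apply the sigmoid twice
# and distort the loss.
```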

mrano commented 5 years ago

It seems that correct_prediction will only contain False values, because preds is always a decimal (a probability) while each label is 0 or 1 (a binary classification per label)?

Chengmeng94 commented 5 years ago

> It seems that correct_prediction will only contain False values, because preds is always a decimal (a probability) while each label is 0 or 1 (a binary classification per label)?

I think that in a multi-label task, a node should only count as correctly classified if all of its labels are predicted correctly. The following is the code I modified; feedback is welcome.

```python
from scipy import sparse

def gettoplist(y_test):
    """Return the number of true labels for each node."""
    node_labels = [[] for _ in range(y_test.shape[0])]
    cy = sparse.coo_matrix(y_test)
    for i, j in zip(cy.row, cy.col):
        node_labels[i].append(j)
    toplist = [len(l) for l in node_labels]
    return toplist
```

```python
import numpy as np

def makeCC(top_k_list, A, probs):
    """Binarize probs: for node i, keep its top_k_list[i] highest-scoring labels."""
    CC = np.zeros([A.shape[0], A.shape[1]], dtype='float32')
    for i, k in enumerate(top_k_list):
        probs_ = probs[i, :]
        labels = probs_.argsort()[-k:].tolist()
        for l in labels:
            CC[i][l] = 1
    return CC
```

```python
def masked_accuracy(preds, labels, mask):
    """Accuracy with masking."""
    # correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1))
    toplist = gettoplist(labels)
    CC = makeCC(toplist, labels, preds)
    correct_prediction = tf.equal(CC, labels)
    # A node counts as correct only if all of its label entries match.
    c = np.empty((correct_prediction.shape[0],), dtype=np.bool)
    for i in range(correct_prediction.shape[0]):
        if False in correct_prediction[i]:
            c[i] = False
        else:
            c[i] = True
    accuracy_all = tf.cast(c, tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    accuracy_all *= mask
    return tf.reduce_mean(accuracy_all)
```

tkipf commented 5 years ago

You should threshold the output of the sigmoid (preds) before comparing it to the binary labels.

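A minimal sketch of that suggestion (assuming TF 1.x, `preds` already being sigmoid outputs, and an illustrative 0.5 threshold); it replaces the Python loop in the code above with tensor ops:

```python
import tensorflow as tf

def masked_multilabel_accuracy(preds, labels, mask, threshold=0.5):
    """Exact-match accuracy with masking: a node counts as correct only if
    every one of its labels matches after thresholding the sigmoid outputs."""
    # Binarize the probabilities instead of comparing raw decimals to 0/1 labels.
    pred_binary = tf.cast(preds > threshold, tf.float32)
    # A node is correct only if all of its label entries agree.
    all_correct = tf.reduce_all(tf.equal(pred_binary, tf.cast(labels, tf.float32)), axis=1)
    accuracy_all = tf.cast(all_correct, tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    accuracy_all *= mask
    return tf.reduce_mean(accuracy_all)
```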

monk1337 commented 4 years ago

You can use this if you are dealing with multi-label classification:

```python
class Optimizers(object):

    @staticmethod
    def multilabel_optimizer(logits, ground_truth, learning_rate):

        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=tf.cast(ground_truth, tf.float32))
        loss = tf.reduce_mean(tf.reduce_sum(cross_entropy, axis=1))
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

        logits_prob = tf.nn.sigmoid(logits, name='prob')
        predictions = tf.cast(tf.sigmoid(logits) > 0.5, tf.int32, name='predictions')

        return {
            'loss'      : loss,
            'optimizer' : optimizer,
            'prediction': predictions,
            'log_prob'  : logits_prob
        }
```

For more details and models for multi-label classification, check out this framework:

https://github.com/monk1337/MultiLab
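A quick usage sketch of the class above (the `logits` and `ground_truth` placeholders and the learning rate are hypothetical; in this repo the logits would come from the GCN model's `self.outputs`):

```python
import tensorflow as tf

# Hypothetical tensors standing in for the GCN output and the label matrix.
logits = tf.placeholder(tf.float32, shape=[None, 7])
ground_truth = tf.placeholder(tf.int32, shape=[None, 7])

ops = Optimizers.multilabel_optimizer(logits, ground_truth, learning_rate=1e-3)
# During training: sess.run([ops['optimizer'], ops['loss']], feed_dict=...)
# At test time: ops['prediction'] holds the thresholded 0/1 multi-label predictions.
```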