SongW-SW / TENT


Meta-test classification method #4

Closed: ntkien1904 closed this issue 1 year ago

ntkien1904 commented 1 year ago

Dear author,

Thank you for the research work and sharing your code.

I have a question about the meta-test classification method. According to the paper, the model uses a query matching method that matches each query node embedding against the class prototypes, but your code instead trains a classifier on the support node embeddings.

Could you provide code that follows the method described in the paper? Thank you so much!

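import torch
from sklearn import metrics
from sklearn.linear_model import LogisticRegression

# Evaluate one N-way K-shot task: at validation/test time, and every 250th
# training epoch. emb_features, pos_node_idx, target_idx, N, K, Q and the
# l2_normalize helper are defined elsewhere in the repository.
if mode=='valid' or mode=='test' or (mode=='train' and epoch%250==249):
    # Detached, L2-normalized embeddings of the support and query nodes
    support_features = l2_normalize(emb_features[pos_node_idx].detach().cpu()).numpy()
    query_features = l2_normalize(emb_features[target_idx].detach().cpu()).numpy()

    # Support labels: rows i*K .. (i+1)*K-1 belong to class i
    support_labels=torch.zeros(N*K,dtype=torch.long)
    for i in range(N):
        support_labels[i * K:(i + 1) * K] = i

    # Query labels: rows i*Q .. (i+1)*Q-1 belong to class i
    query_labels=torch.zeros(N*Q,dtype=torch.long)
    for i in range(N):
        query_labels[i * Q:(i + 1) * Q] = i

    # Multinomial logistic regression fitted only on the N*K support embeddings.
    # (multi_class is deprecated in recent scikit-learn releases; lbfgs already
    # fits a multinomial model for multiclass targets.)
    clf = LogisticRegression(penalty='l2',
                             random_state=0,
                             C=1.0,
                             solver='lbfgs',
                             max_iter=1000,
                             multi_class='multinomial')
    clf.fit(support_features, support_labels.numpy())
    query_ys_pred = clf.predict(query_features)

    # Accuracy of the predicted labels on the query nodes
    acc_train = metrics.accuracy_score(query_labels.numpy(), query_ys_pred)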
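For comparison with the prototype-based query matching the question refers to, here is a minimal, generic sketch of nearest-prototype classification (prototypical-network style), not the TENT authors' code. It assumes the support rows are grouped by class exactly as in the snippet above and uses squared Euclidean distance; the distance actually used in the paper may differ. The function name `prototype_predict` is hypothetical.

```python
import numpy as np

def prototype_predict(support_features, query_features, n_way, k_shot):
    """Nearest-prototype classification for one N-way K-shot task.

    Assumes class i occupies support rows i*k_shot .. (i+1)*k_shot - 1 and
    uses squared Euclidean distance (an assumption, not taken from the paper).
    """
    dim = support_features.shape[1]
    # Class prototype = mean of that class's K support embeddings -> (N, dim)
    prototypes = support_features.reshape(n_way, k_shot, dim).mean(axis=1)
    # Squared distance from every query to every prototype -> (num_query, N)
    diffs = query_features[:, None, :] - prototypes[None, :, :]
    dists = (diffs ** 2).sum(axis=-1)
    # Each query receives the label of its closest prototype
    return dists.argmin(axis=1)

# Toy 2-way 2-shot task with hand-made 2-D embeddings
support = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
query = np.array([[0.8, 0.2], [0.2, 0.7]])
print(prototype_predict(support, query, n_way=2, k_shot=2))  # [0 1]
```

In the quoted snippet, such a function would take the place of the `LogisticRegression` fit/predict pair, e.g. `query_ys_pred = prototype_predict(support_features, query_features, N, K)`; no parameters are fitted per task.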
SongW-SW commented 1 year ago

Thank you for your interest in our work!

In particular, we still use only the information within the support set; the matching step is simply parameterized by a linear layer (the logistic regression) that is fitted separately in each meta-task. We use this trick to speed up convergence, since it is more robust to outliers, which reduces the number of training epochs. It can also extract more information from the support set, so we can use a smaller hidden size for the GNN and further improve training speed; otherwise, training can be time-consuming. Nonetheless, since the classifier is only used for inference, it does not affect the training process itself.
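One way to see why a per-task linear layer is a small step away from prototype matching: under squared Euclidean distance (an assumption; the paper's distance may differ), nearest-prototype classification is itself a linear classifier whose weights come directly from the support means, since ||q - mu_c||^2 = ||q||^2 - 2 q·mu_c + ||mu_c||^2 and ||q||^2 is the same for every class. The sketch below checks this numerically on random embeddings; the logistic regression in the repository differs in that it learns W and b from the same support embeddings instead of fixing them. The helper `prototype_as_linear` is hypothetical.

```python
import numpy as np

def prototype_as_linear(support_features, n_way, k_shot):
    """Return (W, b) such that argmax(query @ W.T + b) reproduces
    nearest-prototype classification under squared Euclidean distance."""
    dim = support_features.shape[1]
    prototypes = support_features.reshape(n_way, k_shot, dim).mean(axis=1)
    # Minimizing ||q - mu||^2 is equivalent to maximizing 2 q.mu - ||mu||^2
    W = 2.0 * prototypes                  # (N, dim): one weight row per class
    b = -(prototypes ** 2).sum(axis=1)    # (N,): one bias per class
    return W, b

rng = np.random.default_rng(0)
n_way, k_shot, n_query, dim = 3, 5, 10, 8
support = rng.normal(size=(n_way * k_shot, dim))  # rows grouped by class
query = rng.normal(size=(n_query, dim))

# Prediction through the equivalent linear layer
W, b = prototype_as_linear(support, n_way, k_shot)
linear_pred = (query @ W.T + b).argmax(axis=1)

# Direct nearest-prototype prediction for comparison
prototypes = support.reshape(n_way, k_shot, dim).mean(axis=1)
dists = ((query[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
direct_pred = dists.argmin(axis=1)

print(np.array_equal(linear_pred, direct_pred))  # True
```

Both classifiers see only the N*K support embeddings; the difference lies in whether the linear weights are fixed by the class means or optimized per task.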