pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

How to determine unknown class /Random data using GCNconv ? #3250

Open shubham-v opened 2 years ago

shubham-v commented 2 years ago

❓ Questions & Help

I think this is more of a general classification question. I have n classes plus some random data belonging to other classes; this random data is not included in the training set. I trained the network on the n classes and it performs well. However, when an object from an unknown class comes in at prediction time, GCNConv assigns it to one of the n classes, which is clearly a misclassification. Its confidence is also around 0.97, which makes it hard to filter out, since objects from the n known classes are classified at the same confidence level.

One idea I found online is to add a sort of negatively sampled class, containing none of the feature patterns of the n classes, as an "unknown" class. But I suspect this could confuse the model, and I am still looking into it. Is there a way to handle this so that such samples do not fall into any known class and can instead be filtered out as "unknown"?

rusty1s commented 2 years ago

If an unknown class object comes in during evaluation, there is no way for a neural network to detect that. You have two options:

  1. Add an "unknown" class label during training (which requires that you have unknown-class samples in your training set).
  2. Try to keep the NN's confidence low for unknown classes, and use a threshold to filter them out (see the sketch below).

I don't have much experience with the latter approach, but the recently added RECT_L for zero-shot learning might be of interest to you.
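
A minimal sketch of option 2 could look like the following. The two-layer GCN and the threshold value of 0.9 are purely illustrative, not part of PyG; in practice you would tune the threshold on a validation set:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


@torch.no_grad()
def predict_with_rejection(model, data, threshold=0.9):
    # Softmax confidence over the known classes; anything below `threshold`
    # is mapped to -1 ("unknown") instead of being forced into a known class.
    model.eval()
    probs = F.softmax(model(data.x, data.edge_index), dim=-1)
    conf, pred = probs.max(dim=-1)
    pred[conf < threshold] = -1
    return pred, conf
```

Note that raw softmax scores tend to be overconfident (which matches the 0.97 values you observed), so this usually works better after some form of calibration, e.g. temperature scaling, or with an entropy-based score instead of the maximum probability.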

shubham-v commented 2 years ago

Thank you for your valuable insights. I will try both of the options you mentioned. Let's see how it goes.

shubham-v commented 2 years ago

Hi rusty, I am relatively new, so correct me if these questions don't make sense to you, but this is what I understood:

  1. In RECT_L, after obtaining the embeddings we use logistic regression, where we give each embedding its corresponding label. This means we know which embedding belongs to which class, and also that there are only, say, 7 classes in the whole process, since we provide y to the logistic regression. But if a new, 8th class comes in at the end, it will still be classified among the 7. So how can it handle unseen classes when the embeddings were trained for 7 classes only?
  2. When collecting the semantic embedding, we take a class centroid as its embedding via zs_data.y = model.get_semantic_labels(zs_data.x, zs_data.y, zs_data.train_mask), and after that we train a neural net by minimizing a loss only for the seen labels and predicting for the unseen labels. So my point is: why not take the centroid for all nodes and use that in the logistic regression?
  3. Can RECT_L be used for graph classification?

rusty1s commented 2 years ago

  1. Yes, you are right. The logistic regression is just an elegant way to evaluate how well classes the model has never seen during training can be linearly separated in its embedding space. However, for evaluation, all classes need to be considered, as is the case for any supervised learning model.
  2. I'm not sure I understand. As we are not training against "unseen labels", we do not need to compute their semantic labels.
  3. I think this is possible. @Fizyhsp Do you have any thoughts on this?
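
To make the workflow concrete, a condensed version of the zero-shot setup could look like the sketch below, loosely following the bundled RECT example. The Cora dataset, the held-out class ids, and the hyperparameters are illustrative assumptions; get_semantic_labels() is the call you quoted above, and embed() is assumed to return the frozen node embeddings used for the logistic regression:

```python
import torch
from sklearn.linear_model import LogisticRegression
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import RECT_L

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Simulate the zero-shot setting: remove some (hypothetical) classes from the
# training mask, so the model never trains on them.
unseen_classes = torch.tensor([1, 2, 3])
zs_data = data.clone()
zs_data.train_mask = data.train_mask & ~torch.isin(data.y, unseen_classes)

model = RECT_L(dataset.num_features, 200)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss(reduction='sum')

# Replace the class ids of the remaining training nodes by their semantic
# labels (class centroids in attribute space) and regress onto them.
semantic_y = model.get_semantic_labels(zs_data.x, zs_data.y, zs_data.train_mask)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(zs_data.x, zs_data.edge_index)
    loss = criterion(out[zs_data.train_mask], semantic_y)
    loss.backward()
    optimizer.step()

# Evaluation: fit a logistic regression on the frozen embeddings, now using
# the *original* labels of all classes (seen and unseen).
model.eval()
with torch.no_grad():
    h = model.embed(data.x, data.edge_index).cpu()

clf = LogisticRegression(max_iter=1000)
clf.fit(h[data.train_mask].numpy(), data.y[data.train_mask].numpy())
test_acc = clf.score(h[data.test_mask].numpy(), data.y[data.test_mask].numpy())
print(f'Zero-shot test accuracy: {test_acc:.4f}')
```

The logistic regression at the end is only an evaluation probe: the GNN itself is never trained on the held-out classes, and the probe measures how well those classes are still separable in the learned embedding space.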