ShaneTian / Att-Induction

Attention-based Induction Networks for Few-Shot Text Classification
Apache License 2.0

Question on getting prediction from model #6

Closed: starnicks1 closed this issue 3 years ago.

starnicks1 commented 3 years ago

I have tried Prototypical Networks on my data, and they seem to work: the eval loss decreases consistently. Do you have any code for getting a prediction from the model for a new sentence? For example, I want to pass a new text to the model and get its predicted label, as with a normal classifier.

ShaneTian commented 3 years ago

Sorry, I did not implement inference code: this repository is the companion code for the paper and mainly computes its evaluation metrics.

But I think it is easy: build the `support`, `support_mask`, `query` and `query_mask` tensors, then call the model's forward pass to get `relation_score` and `predict_label`.

Model forward pass: https://github.com/ShaneTian/Att-Induction/blob/9ae11ee30485181b9014b15ac927eb8b2a4170be/src/train.py#L471

You can imitate this code to create the input tensors: https://github.com/ShaneTian/Att-Induction/blob/9ae11ee30485181b9014b15ac927eb8b2a4170be/src/data_loader.py#L167
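A minimal sketch of the input side, in plain Python so it runs without the repo. It assumes a hypothetical whitespace tokenizer and vocabulary (`encode`, `PAD_ID`, `UNK_ID`, `build_episode` and `decode` are illustrative names, not the repo's API) and an assumed layout of `[N * K, max_len]` for `support` and `[Q, max_len]` for `query`. Check `data_loader.py` for the real tokenizer and shapes, then wrap each list with `torch.tensor(...)` (adding whatever batch dimension the model expects) before the forward pass. `decode` assumes `relation_score` holds one row of N scores per query, so the predicted label is the argmax of each row.

```python
from typing import Dict, List, Tuple

PAD_ID = 0  # hypothetical padding id (use the repo's real one)
UNK_ID = 1  # hypothetical unknown-token id (use the repo's real one)


def encode(sentence: str, vocab: Dict[str, int], max_len: int) -> Tuple[List[int], List[int]]:
    """Whitespace-tokenize (an assumption), map to ids, truncate/pad to max_len, build a 0/1 mask."""
    ids = [vocab.get(tok, UNK_ID) for tok in sentence.lower().split()][:max_len]
    mask = [1] * len(ids) + [0] * (max_len - len(ids))
    ids = ids + [PAD_ID] * (max_len - len(ids))
    return ids, mask


def build_episode(support_set: Dict[str, List[str]], queries: List[str],
                  vocab: Dict[str, int], max_len: int):
    """Return (labels, support, support_mask, query, query_mask) as nested lists.

    Assumed layout: support / support_mask are [N * K, max_len], ordered class by class;
    query / query_mask are [Q, max_len].
    """
    labels = list(support_set)  # class index i corresponds to labels[i]
    support, support_mask = [], []
    for label in labels:
        for sentence in support_set[label]:
            ids, mask = encode(sentence, vocab, max_len)
            support.append(ids)
            support_mask.append(mask)
    query, query_mask = [], []
    for sentence in queries:
        ids, mask = encode(sentence, vocab, max_len)
        query.append(ids)
        query_mask.append(mask)
    return labels, support, support_mask, query, query_mask


def decode(relation_score: List[List[float]], labels: List[str]) -> List[str]:
    """Map each query's row of N scores to the class name with the highest score."""
    return [labels[max(range(len(row)), key=row.__getitem__)] for row in relation_score]
```

Each new sentence is classified relative to the N labeled classes in the support set, so at inference time you still need K labeled examples for every candidate class.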

starnicks1 commented 3 years ago

Thanks for the answer. You also reported accuracy for MAML, but I could not find the MAML code in the repository.

ShaneTian commented 3 years ago

Sorry, I didn’t report anything about MAML. Maybe you saw it in another paper?

starnicks1 commented 3 years ago

Thanks Shane,

I have one more question. For the Relation Network and the Induction Network, the loss stays constant and accuracy never increases; it stays at 20%. My dataset has about 200 samples: 20 classes with 10 examples each. What could be the problem? Matching Networks reach about 40%, and Prototypical Networks reach as high as 80%.

(Sent as an email reply to ShaneTian's comment of Fri, 8 Jan 2021, 21:59.)

ShaneTian commented 3 years ago

I can't tell what is happening without more details about your Relation Network and Induction Network runs. You could print more logs to debug.
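One way to make those logs more informative, sketched in plain Python with hypothetical helper names (`episode_report`, `loss_is_flat`): for each evaluation episode, print accuracy next to the chance level 1/N and how many distinct labels the model actually predicts. If the runs were 5-way (the thread does not say), 20% is exactly chance, and a constant loss together with a single predicted label would suggest the model produces nearly the same scores for every query. The `pred` and `gold` lists could come from `predict_label` and the episode labels via `.tolist()`; the flatness window and tolerance are arbitrary choices.

```python
from collections import Counter
from typing import List


def episode_report(step: int, loss: float, pred: List[int], gold: List[int], n_way: int) -> str:
    """Summarize one episode: accuracy vs. chance and how spread out the predictions are."""
    acc = sum(p == g for p, g in zip(pred, gold)) / len(gold)
    chance = 1.0 / n_way  # accuracy of uniform random guessing in an N-way episode
    counts = Counter(pred)
    top_label, top_count = counts.most_common(1)[0]
    collapsed = top_count == len(pred)  # every query received the same label
    return (f"step={step} loss={loss:.4f} acc={acc:.2%} chance={chance:.2%} "
            f"distinct_preds={len(counts)}/{n_way} most_common={top_label}x{top_count}"
            + (" [COLLAPSED]" if collapsed else ""))


def loss_is_flat(losses: List[float], window: int = 50, tol: float = 1e-4) -> bool:
    """True if the last `window` losses span less than `tol` (max - min)."""
    recent = losses[-window:]
    return len(recent) == window and max(recent) - min(recent) < tol
```

If the loss is flat, printing gradient norms of the relation and induction modules can also show whether they receive any gradient at all.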

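In fact, for a small dataset I think Prototypical Networks are more appropriate. They are easier to train because they have fewer parameters: the classifier on top of the encoder just compares queries with class-mean prototypes, while Relation and Induction Networks add learned modules that need more data.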
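To illustrate the "fewer parameters" point, here is the standard prototypical-network classification step in plain Python (not this repo's implementation; the function names are illustrative): each class prototype is the mean of its support embeddings, and a query is assigned to the nearest prototype under squared Euclidean distance. Nothing in this step is learned, so the only trainable parameters are in the encoder that produces the embeddings.

```python
from typing import Dict, List

Vector = List[float]


def prototypes(support: Dict[str, List[Vector]]) -> Dict[str, Vector]:
    """Class prototype = element-wise mean of that class's K support embeddings."""
    protos = {}
    for label, vecs in support.items():
        dim = len(vecs[0])
        protos[label] = [sum(v[d] for v in vecs) / len(vecs) for d in range(dim)]
    return protos


def sq_euclidean(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two embeddings of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def classify(query: Vector, protos: Dict[str, Vector]) -> str:
    """Predict the class whose prototype is nearest (negative distance acts as a logit)."""
    return min(protos, key=lambda label: sq_euclidean(query, protos[label]))
```

During training, the negative distances are typically fed to a softmax with cross-entropy loss, so gradients flow only into the encoder.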