frederick0329 / TracIn

Implementation of Estimating Training Data Influence by Tracing Gradient Descent (NeurIPS 2020)
Apache License 2.0

Test TracIn's effectiveness in text classification #8

Open yananchen1989 opened 3 years ago

yananchen1989 commented 3 years ago

Hello,

I adapted the code from https://github.com/frederick0329/TracIn/blob/master/imagenet/resnet50_imagenet_proponents_opponents.ipynb for text classification.

The primary goal of my task is to rank the training samples by their positive or negative impact on a clean validation set. The core metric can be accuracy or cross-entropy loss; quite straightforward. The training set contains 100-200 samples and the validation set no more than 100, so this is a low-data regime.

The validation set contains no errors; it is clean.

Labels include politics, business, tech, entertainment, etc. It is just a public news topic classification task: AG News.

As for the classifier, similar to the ResNet in your image example, I am using CMLM from TensorFlow Hub and vectorize all samples into 1024-dimensional sentence embeddings. The classifier is therefore quite simple: a single-layer network.
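For reference, a minimal sketch of such a classifier, assuming the 1024-dimensional CMLM embeddings are precomputed elsewhere; the class count and optimizer are illustrative, not taken from the linked script:

```python
import tensorflow as tf

NUM_CLASSES = 4  # assumed for AG News; adjust to your label set

# A single dense softmax layer on top of frozen, precomputed
# sentence embeddings of shape (num_samples, 1024).
inputs = tf.keras.Input(shape=(1024,))
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(inputs)
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```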

Here is my implementation: https://github.com/yananchen1989/topic_classification_augmentation/blob/main/cmlm_proponents_opponents.py

Finally, I use AUC to test effectiveness: a high AUC indicates that samples without labeling noise receive higher influence scores, while wrongly labeled samples receive lower, negative scores. However, the AUC is only 0.55. Quite woeful.

I am not sure whether there is a bug in my implementation or I am not using TracIn in an appropriate manner.
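For comparison, here is a minimal sketch of the TracInCP score from the paper (sum over checkpoints of the learning rate times the dot product of per-example loss gradients), independent of the linked script; `checkpoints`, `lrs`, and the gradient helper are illustrative assumptions, and it reuses the softmax model sketched above:

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

def example_grad(model, x, y):
    """Flattened loss gradient w.r.t. all trainable weights for one example."""
    with tf.GradientTape() as tape:
        logits = model(x[None, :])              # add a batch dimension
        loss = loss_fn(tf.constant([y]), logits)
    grads = tape.gradient(loss, model.trainable_variables)
    return tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0)

def tracin_score(checkpoints, lrs, model, x_train, y_train, x_val, y_val):
    """TracInCP: sum_i lr_i * <grad L(z_train; theta_i), grad L(z_val; theta_i)>."""
    score = 0.0
    for ckpt_path, lr in zip(checkpoints, lrs):
        model.load_weights(ckpt_path)           # theta_i from checkpoint i
        g_tr = example_grad(model, x_train, y_train)
        g_va = example_grad(model, x_val, y_val)
        score += lr * tf.tensordot(g_tr, g_va, axes=1)
    return float(score)
```

To rank a training sample against the whole validation set, one would sum (or average) this score over all validation examples.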

yananchen1989 commented 3 years ago

I intentionally fabricated some samples with incorrect labels, df_noise. It can be regarded as a honeypot to test TracIn's ability to score them. TracIn is expected to give them significantly negative influence scores and, on the other hand, higher positive scores to samples from df_train, since that set is absolutely clean.

I assume that df_noise should act as the opponents with regard to df_valid as a whole, which is also a clean set.
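A sketch of this honeypot setup and the AUC check, assuming integer labels and a per-sample influence array; `make_noisy`, the noise fraction, and the class count are hypothetical names, not from the linked code:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def make_noisy(labels, noise_frac=0.2, num_classes=4, seed=0):
    """Flip a fraction of labels to a different random class (the honeypot)."""
    rng = np.random.default_rng(seed)
    labels = labels.copy()
    idx = rng.choice(len(labels), size=int(noise_frac * len(labels)), replace=False)
    labels[idx] = (labels[idx] + rng.integers(1, num_classes, size=len(idx))) % num_classes
    is_clean = np.ones(len(labels), dtype=int)
    is_clean[idx] = 0
    return labels, is_clean

# influence: per-training-sample TracIn score against the clean validation set.
# If clean samples score higher than flipped ones, AUC should be well above 0.5:
# auc = roc_auc_score(is_clean, influence)
```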

yananchen1989 commented 3 years ago

As for checkpoint picking, I currently select the checkpoints that clearly reduce the loss on the validation set. I set the total number of epochs to 50; checkpoints from early epochs are more likely to be saved, because loss reductions in the early epochs are more likely and more tangible.

Generally, I collect 10-15 checkpoints for the subsequent influence calculation.
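One way to realize this selection rule, sketched with a standard Keras callback rather than the code from this thread; the file path pattern is an assumption:

```python
import tensorflow as tf

# Save weights each time validation loss improves; with 50 epochs this
# typically yields roughly the 10-15 improving checkpoints described above.
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='ckpts/epoch_{epoch:02d}.weights.h5',
    monitor='val_loss',
    save_best_only=True,      # only write when val_loss improves
    save_weights_only=True,
)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=50,
          callbacks=[ckpt_cb])
```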

Jiahaou commented 2 years ago

Where can I download the ckpt files referenced in the source code?