i13abe / Triplet-Loss-for-Knowledge-Distillation

Triplet Loss for Knowledge Distillation

the pre-trained teacher model did not converge in a single network #2

Closed samo313 closed 3 years ago

samo313 commented 3 years ago

Hi, I used the 'cnn_alex' model as below:

```python
model_t = Net_teacher().to(device)
model_t.load_state_dict(torch.load("cnn_alex.pkl"))

# Freeze model weights
for name, param in model_t.named_parameters():
    if name == 'fc2.weight':
        break
    else:
        param.requires_grad_(False)
```

Then I tried to train the last layer on MNIST, but it did not converge. Some results at epoch 10 look like:

```
Train Epoch: 10 [0/60000]     Loss: 2.9213 (2.9213)  Acc: 1.56% (1.56%)
Train Epoch: 10 [640/60000]   Loss: 2.9524 (2.9019)  Acc: 4.40% (3.80%)
Train Epoch: 10 [1280/60000]  Loss: 2.9556 (2.9212)  Acc: 4.24% (3.97%)
Train Epoch: 10 [1920/60000]  Loss: 3.0981 (2.9208)  Acc: 4.54% (4.18%)
```

I want to use this network as a teacher to train a student on MNIST. Would you please help me? How can I handle this problem? Thanks.

i13abe commented 3 years ago

I don't know why it does not learn, but if you want to train on MNIST, it usually works well without the pkl model file. First, please try training on MNIST without the pkl file (i.e., do not call `model_t.load_state_dict(torch.load("cnn_alex.pkl"))`). Please check it.
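For example, a minimal from-scratch sketch (the transforms are an assumption: if `Net_teacher` was built for CIFAR10's 3x32x32 inputs, MNIST images need resizing and channel replication):

```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms

# Assumption: Net_teacher (as in your snippet) expects 3x32x32 inputs since
# it was built for CIFAR10, so MNIST images are resized and repeated to 3 channels.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_t = Net_teacher().to(device)  # no load_state_dict: train from scratch
optimizer = torch.optim.SGD(model_t.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model_t(images)  # assuming the model returns class logits directly
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
```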

samo313 commented 3 years ago

> I don't know why it does not learn, but if you want to train on MNIST, it usually works well without the pkl model file. First, please try training on MNIST without the pkl file (i.e., do not call `model_t.load_state_dict(torch.load("cnn_alex.pkl"))`). Please check it.

I will check it. Could you please tell me how you trained 'cnn_alex', or share your 'cnn_alex' training and test code? Also, which dataset was 'cnn_alex' trained on?

i13abe commented 3 years ago

My "cnn_alex.pkl" is trained on CIFAR10. I share the training python file in this repository of github. But it is not clear, so I am sorry. Please read it.

i13abe commented 3 years ago

cnn_cifar10.py is available now.

samo313 commented 3 years ago

> cnn_cifar10.py is available now.

It is not available for me here!

i13abe commented 3 years ago

Sorry, I have uploaded it now.

samo313 commented 3 years ago

> Sorry, I have uploaded it now.

It's OK, thank you. If I want to use this model on another dataset like MNIST, should I fine-tune 'cnn_alex' on the new dataset?

i13abe commented 3 years ago

You can do that.
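For example, a rough sketch of that fine-tuning (reusing the names from your snippet; freezing by the `fc2` prefix is an assumption):

```python
import torch

# Net_teacher and device as in your snippet.
model_t = Net_teacher().to(device)
model_t.load_state_dict(torch.load("cnn_alex.pkl"))

# Freeze everything except the final classifier layer.
for name, param in model_t.named_parameters():
    param.requires_grad_(name.startswith("fc2"))

# Optimize only the unfrozen head parameters on the new dataset.
optimizer = torch.optim.SGD(
    (p for p in model_t.parameters() if p.requires_grad),
    lr=0.01, momentum=0.9,
)
```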

samo313 commented 3 years ago

> You can do that.

I trained my model and fine-tuned your model separately on the MNIST dataset. Unfortunately, I get 100% accuracy from the first epoch, which seems to be a problem. I noticed that the scale of the last encoding vector of both models as teachers is higher than that of the student vector. I think it comes from the difference between the loss functions of the two networks: the teacher uses cross-entropy, while the student in my model uses nn.MarginRankingLoss without any cross-entropy term like the one you added to your loss. What do you think?

i13abe commented 3 years ago

> I trained my model and fine-tuned your model separately on the MNIST dataset. Unfortunately, I get 100% accuracy from the first epoch, which seems to be a problem. I noticed that the scale of the last encoding vector of both models as teachers is higher than that of the student vector. I think it comes from the difference between the loss functions of the two networks: the teacher uses cross-entropy, while the student in my model uses nn.MarginRankingLoss without any cross-entropy term like the one you added to your loss. What do you think?

The network I prepared is deep and large for the MNIST dataset, so it can reach 100% accuracy quickly. Your training is fine, I think. If you want to try again, you can prepare a small network (3 conv and 2 fc layers) as a teacher on MNIST, and then use a smaller network (2 fc layers) as a student. If you see some problem with MNIST training, please check the loss value and the predicted labels. At the same time, consider the learning rate and weight decay, because the loss differs from my training; you can adjust them.
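A rough sketch of the sizes I mean (the layer widths are just examples, not from the repository):

```python
import torch.nn as nn
import torch.nn.functional as F

class SmallTeacher(nn.Module):
    """3 conv + 2 fc layers for 1x28x28 MNIST inputs."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 28 -> 14 -> 7 after two poolings
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv(x).flatten(1)
        return self.fc2(F.relu(self.fc1(x)))

class SmallStudent(nn.Module):
    """2 fc layers only."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x.flatten(1))))
```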

samo313 commented 3 years ago

> I trained my model and fine-tuned your model separately on the MNIST dataset. Unfortunately, I get 100% accuracy from the first epoch, which seems to be a problem. I noticed that the scale of the last encoding vector of both models as teachers is higher than that of the student vector. I think it comes from the difference between the loss functions of the two networks: the teacher uses cross-entropy, while the student in my model uses nn.MarginRankingLoss without any cross-entropy term like the one you added to your loss. What do you think?
>
> The network I prepared is deep and large for the MNIST dataset, so it can reach 100% accuracy quickly. Your training is fine, I think. If you want to try again, you can prepare a small network (3 conv and 2 fc layers) as a teacher on MNIST, and then use a smaller network (2 fc layers) as a student. If you see some problem with MNIST training, please check the loss value and the predicted labels. At the same time, consider the learning rate and weight decay, because the loss differs from my training; you can adjust them.

It's a good tip, but my loss calculation is a bit different from yours. I use nn.MarginRankingLoss(dist_aa, dist_an, target), where dist_aa is the distance between the encoding vectors of the teacher and student on the anchor image, and dist_an is the distance between the student's encoding vectors on the anchor and negative, plus a term combining the three encoding vectors.
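Roughly, the loss part looks like this (a minimal sketch; the embedding names, the distance function, and the target sign are illustrative, and the extra combination term is omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder embeddings standing in for real model outputs (batch of 8, 64-dim).
e_t = torch.randn(8, 64)  # teacher embedding of the anchor
e_a = torch.randn(8, 64)  # student embedding of the anchor
e_n = torch.randn(8, 64)  # student embedding of the negative

dist_aa = F.pairwise_distance(e_t, e_a)  # teacher vs. student on the anchor
dist_an = F.pairwise_distance(e_a, e_n)  # student anchor vs. student negative

# MarginRankingLoss(x1, x2, y) = max(0, -y * (x1 - x2) + margin).
# target = -1 asks for dist_aa < dist_an; use +1 for the opposite ordering.
loss_fn = nn.MarginRankingLoss(margin=0.0)
target = -torch.ones_like(dist_aa)
loss = loss_fn(dist_aa, dist_an, target)
```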

Then I use the accuracy like below:

```python
margin = 0.0
pred = (distaa - distan - margin).cpu().data
return (pred > 0).sum() * 1.0 / distaa.size()[0]
```

which just compares these two distances, dist_aa and dist_an, so I get 100% acc. But when I use your accuracy code like below:

```python
_, predicted = out2_s.max(1)
correct = (predicted == label2_s).sum().item()
```

I achieve low accuracy, between 9% and 19%. But I don't know whether that accuracy (i.e., from your code) is reasonable when I don't use any cross-entropy in my loss.

i13abe commented 3 years ago

Your accuracy only computes whether the triplet loss is positive or negative, so you can get 100% accuracy. My accuracy, in contrast, compares the 10-dim output vector against the ground-truth labels from 0 to 9. Accuracy is not related to the loss function. In my case, I focused on classification, so the index of the max of the output vector is the predicted label (`_, predicted = out2_s.max(1)`), and the accuracy is the number of matches between predictions and labels (`correct = (predicted == label2_s).sum().item()`).

samo313 commented 3 years ago

> Your accuracy only computes whether the triplet loss is positive or negative, so you can get 100% accuracy. My accuracy, in contrast, compares the 10-dim output vector against the ground-truth labels from 0 to 9. Accuracy is not related to the loss function. In my case, I focused on classification, so the index of the max of the output vector is the predicted label (`_, predicted = out2_s.max(1)`), and the accuracy is the number of matches between predictions and labels (`correct = (predicted == label2_s).sum().item()`).

Do you mean my acc calculation is wrong?

i13abe commented 3 years ago

Your acc calculation is just `pred = (distaa - distan - margin).cpu().data`, which is the triplet loss. After this calculation, you do `(pred > 0).sum() * 1.0 / distaa.size()[0]`, which only checks whether the triplet loss is positive or negative. It is not related to the ground-truth labels; this is the wrong point. In your implementation, your acc only denotes whether distaa is higher than distan or not. I want to ask you: do you want to implement classification acc (predicting the 0 to 9 label from an image)?
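For comparison, this is a minimal sketch of the classification accuracy I mean (the loader and model names are placeholders):

```python
import torch

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:  # hypothetical MNIST test DataLoader
        out2_s = model_s(images)        # 10-dim logits per sample
        _, predicted = out2_s.max(1)    # index of the max logit = predicted label
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
acc = 100.0 * correct / total
```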

samo313 commented 3 years ago

> Your acc calculation is just `pred = (distaa - distan - margin).cpu().data`, which is the triplet loss. After this calculation, you do `(pred > 0).sum() * 1.0 / distaa.size()[0]`, which only checks whether the triplet loss is positive or negative. It is not related to the ground-truth labels; this is the wrong point. In your implementation, your acc only denotes whether distaa is higher than distan or not. I want to ask you: do you want to implement classification acc (predicting the 0 to 9 label from an image)?

My aim is to measure how similar the learned features in the student's last embedding vector are to those learned by the teacher. I don't want to do classification.

i13abe commented 3 years ago

In that case, you should evaluate the loss value. The acc is a comparison between prediction and ground truth.

samo313 commented 3 years ago

> In that case, you should evaluate the loss value. The acc is a comparison between prediction and ground truth.

What should I do if I need Acc?

i13abe commented 3 years ago

Before considering that, you should decide what counts as correct: when you get the learned output of a student model, what result would you consider correct? Then you can decide how to calculate acc. For example, if you want to evaluate the similarity between the outputs of the teacher and student, you can calculate the MSE between them; the lower it is, the more similar they are. I don't know what you want to do in your research, so I cannot help with everything, only with what I know. Usually, we learn how to evaluate from related research. If there are no related works, you can define your own evaluation. You can try it.
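For example, a minimal MSE-similarity sketch (the model and batch names are placeholders):

```python
import torch
import torch.nn.functional as F

# model_t, model_s, and images stand in for your trained models and a data batch.
with torch.no_grad():
    e_t = model_t(images)              # teacher embedding of the batch
    e_s = model_s(images)              # student embedding of the same batch
    similarity = F.mse_loss(e_s, e_t)  # lower value = more similar features
```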

samo313 commented 3 years ago

> Before considering that, you should decide what counts as correct: when you get the learned output of a student model, what result would you consider correct? Then you can decide how to calculate acc. For example, if you want to evaluate the similarity between the outputs of the teacher and student, you can calculate the MSE between them; the lower it is, the more similar they are. I don't know what you want to do in your research, so I cannot help with everything, only with what I know. Usually, we learn how to evaluate from related research. If there are no related works, you can define your own evaluation. You can try it.

Thanks.