jl749 / knowledge-distillation-pytorch

A PyTorch implementation for exploring deep and shallow knowledge distillation (KD) experiments with flexibility

cross-entropy (log loss) & KL Divergence #5

Open jl749 opened 2 years ago

jl749 commented 2 years ago

how good or bad are the predicted probabilities??

low probability --> high penalty

-log(1.0) = 0
-log(0.8) = 0.22314
-log(0.6) = 0.51082

y = -log(x)
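A tiny sketch to reproduce those penalty values (plain Python, natural log as used throughout these notes):

```python
import math

# penalty for the predicted probability of the correct class:
# the lower the probability, the higher the -log penalty
for p in (1.0, 0.8, 0.6):
    penalty = -math.log(p)   # 0.0, 0.22314, 0.51083
    print(p, penalty)
```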

binary cross-entropy (only 2 classes):

H_p(q) = -1/N * sum_{i=1..N} [ y_i * log(p(y_i)) + (1 - y_i) * log(1 - p(y_i)) ]
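A minimal sketch of that formula against PyTorch's built-in version (the labels and probabilities below are made up for illustration):

```python
import torch
import torch.nn.functional as F

y = torch.tensor([1.0, 0.0, 1.0, 1.0])  # true labels
p = torch.tensor([0.9, 0.2, 0.6, 0.8])  # predicted probability of class 1

# manual: -1/N * sum(y*log(p) + (1-y)*log(1-p))
manual = -(y * p.log() + (1 - y) * (1 - p).log()).mean()

builtin = F.binary_cross_entropy(p, y)

print(manual.item(), builtin.item())  # both ~ 0.2656
```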

entropy (expected negative log-likelihood)

is a measure of the uncertainty associated with a given distribution

if every ball in the box is green, you have zero uncertainty: you will never draw a red ball (0 entropy)

what if half of the balls are red and the other half blue?

H(q) = -(0.5log(0.5) + 0.5log(0.5)) ≈ 0.693

if the red:blue ratio is 20:80:

H(q) = -(0.2log(0.2) + 0.8log(0.8)) ≈ 0.5

the higher the entropy, the harder it is to predict
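The ball examples in code (a quick sketch; natural log, matching the numbers above):

```python
import torch

def entropy(q: torch.Tensor) -> torch.Tensor:
    """H(q) = -sum(q * log(q)), treating 0*log(0) as 0."""
    q = q[q > 0]                 # drop zero-probability entries to avoid log(0)
    return -(q * q.log()).sum()

print(entropy(torch.tensor([1.0, 0.0])).item())  # all green -> 0.0 (no uncertainty)
print(entropy(torch.tensor([0.5, 0.5])).item())  # 50:50     -> 0.6931
print(entropy(torch.tensor([0.2, 0.8])).item())  # 20:80     -> 0.5004
```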

cross-entropy

cross-entropy between the true distribution q and the predicted distribution p:

H_p(q) = -sum_i q(y_i) * log(p(y_i))

If we, somewhat miraculously, match p(y) to q(y) perfectly, the computed values for both cross-entropy and entropy will match as well.

Since that is unlikely to ever happen, the cross-entropy will have a BIGGER value than the entropy computed on the true distribution.

e.g. true probability (red, green, blue) = (0.8, 0.1, 0.1), predicted probability = (0.2, 0.2, 0.6)

H(q) = -(0.8log(0.8) + 0.1log(0.1) + 0.1log(0.1)) ≈ 0.64
H_p(q) = -(0.8log(0.2) + 0.1log(0.2) + 0.1log(0.6)) ≈ 1.50
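The same comparison in code (a sketch mirroring the example above):

```python
import torch

q = torch.tensor([0.8, 0.1, 0.1])  # true distribution (red, green, blue)
p = torch.tensor([0.2, 0.2, 0.6])  # predicted distribution

entropy       = -(q * q.log()).sum()  # H(q)   ~ 0.639
cross_entropy = -(q * p.log()).sum()  # H_p(q) ~ 1.500

print(entropy.item(), cross_entropy.item())
# cross-entropy >= entropy, with equality only when p == q
```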

Kullback-Leibler Divergence (KL Divergence)

a measure of dissimilarity between two distributions: the difference between cross-entropy and entropy

D_KL(q || p) = H_p(q) - H(q) = sum_i q(y_i) * log(q(y_i) / p(y_i))
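Continuing the same toy example, a sketch that checks the "cross-entropy minus entropy" identity against torch.nn.functional.kl_div (which expects log-probabilities for the prediction and plain probabilities for the target):

```python
import torch
import torch.nn.functional as F

q = torch.tensor([0.8, 0.1, 0.1])  # true distribution
p = torch.tensor([0.2, 0.2, 0.6])  # predicted distribution

cross_entropy = -(q * p.log()).sum()
entropy       = -(q * q.log()).sum()

kl_manual = cross_entropy - entropy                # D_KL(q || p)
kl_torch  = F.kl_div(p.log(), q, reduction='sum')  # input = log p, target = q

print(kl_manual.item(), kl_torch.item())           # both ~ 0.861
```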

jl749 commented 2 years ago

torch.nn.CrossEntropyLoss VS torch.nn.NLLLoss

same operation but different approaches

CrossEntropyLoss contains both LogSoftmax and NLLLoss (more descriptive)

NLLLoss, on the other hand, takes the log-probability output of F.log_softmax() (more imperative)
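A minimal check of that equivalence (the logits and targets are made up for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)           # raw, unnormalised scores
target = torch.tensor([0, 2, 1, 2])  # ground-truth class indices

# descriptive: one call does log-softmax + NLL internally
ce = torch.nn.CrossEntropyLoss()(logits, target)

# imperative: apply log-softmax yourself, then hand the log-probs to NLLLoss
nll = torch.nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

print(torch.allclose(ce, nll))  # True
```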

jl749 commented 2 years ago

conventional CrossEntropy

https://github.com/jl749/knowledge-distillation-pytorch/blob/4a7c13d090554961f04b1b5f4ecc324cdca94d66/%23playground/torch_cee_imperatively.py#L11-L14
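The linked lines aren't reproduced here; a rough sketch of what a "conventional" (one-hot, textbook-formula) cross-entropy looks like in PyTorch:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 2.5]])
target = torch.tensor([0, 2])

probs   = F.softmax(logits, dim=1)                   # p(y)
one_hot = F.one_hot(target, num_classes=3).float()   # q(y) as one-hot rows

# H_p(q) = -sum_i q(y_i) * log(p(y_i)), averaged over the batch
loss = -(one_hot * probs.log()).sum(dim=1).mean()
print(torch.allclose(loss, F.cross_entropy(logits, target)))  # True
```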

torch.nn.NLLLoss

takes (input, target); target must be a 1-D tensor of class indices, e.g. input = [[0.25, 0.25, 0.5], [0.1, 0.2, 0.7]], target = [2, 2]
https://github.com/jl749/knowledge-distillation-pytorch/blob/4a7c13d090554961f04b1b5f4ecc324cdca94d66/%23playground/torch_cee_imperatively.py#L17-L26
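Mechanically, NLLLoss just picks out input[i, target[i]], negates it, and averages; with the example above (in real use the input would be log-probabilities, e.g. from F.log_softmax):

```python
import torch

inp    = torch.tensor([[0.25, 0.25, 0.5],
                       [0.10, 0.20, 0.7]])
target = torch.tensor([2, 2])   # one class index per row

loss = torch.nn.NLLLoss()(inp, target)
print(loss)                     # -(0.5 + 0.7) / 2 = -0.6
```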

torch.nn.CrossEntropyLoss

torch.nn.CrossEntropyLoss consists of nn.LogSoftmax and nn.NLLLoss
https://github.com/jl749/knowledge-distillation-pytorch/blob/4a7c13d090554961f04b1b5f4ecc324cdca94d66/%23playground/torch_cee_imperatively.py#L29-L35
(see the implementation, the equations will make more sense)
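Spelling out those equations: the per-sample loss is just -log_softmax(logits)[target]; a sketch (not the linked file's exact code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])

log_probs  = F.log_softmax(logits, dim=1)                           # nn.LogSoftmax step
per_sample = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)   # nn.NLLLoss step
loss = per_sample.mean()

print(torch.allclose(loss, torch.nn.CrossEntropyLoss()(logits, target)))  # True
```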

wait, why log-softmax to calculate the cross-entropy error?

log-softmax is more numerically stable than taking log(softmax(x)) directly
https://datascience.stackexchange.com/questions/40714/what-is-the-advantage-of-using-log-softmax-instead-of-softmax
https://stats.stackexchange.com/questions/436766/cross-entropy-with-log-softmax-activation
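A quick illustration of the stability point from those links (extreme logits chosen on purpose):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, -1000.0, 0.0]])

naive  = torch.log(F.softmax(logits, dim=1))  # small probabilities underflow to 0, log(0) = -inf
stable = F.log_softmax(logits, dim=1)         # log-sum-exp trick keeps everything finite

print(naive)   # tensor([[0., -inf, -inf]])
print(stable)  # tensor([[0., -2000., -1000.]])
```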

TODO