LanXiaoPang613 opened 1 year ago
I also have a question about the code implementation of the line `kldiv_loss_per_pair = weighted_t_softmax * (jnp.log(weighted_t_softmax) - s_softmax_temp)  # [n, m]`. According to Eq. (3) in the paper and the KL divergence D_KL(A||B) = sum(A log A - A log B), does the code implementation swap the order of A and B relative to Eq. (3)? (Since weighted_t_softmax involves the knn similarities, it seems to be the B term in the paper, but it is used as the A term in the code.)
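To make my question concrete, here is a small illustrative sketch of the two orderings I mean, with toy tensors and stand-in distributions (not the repository's code):

```python
import jax
import jax.numpy as jnp

n, m = 4, 10          # n examples, m classes (placeholder sizes)
temperature = 2.0

# Toy stand-ins: student logits and a neighbor-weighted teacher distribution.
logits = jax.random.normal(jax.random.PRNGKey(0), (n, m))
weighted_t_softmax = jax.nn.softmax(
    jax.random.normal(jax.random.PRNGKey(1), (n, m)) / temperature)

# Quoted Line 44: log-probabilities of the temperature-scaled student logits.
s_softmax_temp = jax.nn.log_softmax(logits / temperature)

# Ordering used in the quoted code: D_KL(weighted_t_softmax || student).
kl_teacher_first = jnp.sum(
    weighted_t_softmax * (jnp.log(weighted_t_softmax) - s_softmax_temp), axis=1)

# The other ordering: D_KL(student || weighted_t_softmax).
s_softmax = jnp.exp(s_softmax_temp)
kl_student_first = jnp.sum(
    s_softmax * (s_softmax_temp - jnp.log(weighted_t_softmax)), axis=1)

print(kl_teacher_first, kl_student_first)  # generally different values
```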
Thanks for your attention and replies.
- For the KL-loss function `pairwise_klloss`: why is the loss computed as `kldiv_loss_per_pair = weighted_t_softmax * (jnp.log(weighted_t_softmax) - s_softmax_temp)  # [n, m]`? In my understanding, the KL divergence is D_KL(A||B) = sum(A log A - A log B), so according to Eq. (3) of the paper I think it should be `kldiv_loss_per_pair = weighted_t_softmax * (jnp.log(weighted_t_softmax) - jnp.log(s_softmax_temp))  # [n, m]`.
- Still in `pairwise_klloss`: for `kldiv_loss_per_example = (jnp.power(temperature, 2) * jnp.sum(kldiv_loss_per_pair, 1))  # [n, 1]`, why do we need to multiply by `jnp.power(temperature, 2)` again? This factor is not shown in the paper. (A sketch of how I read these two lines together follows below.) Thanks for your attention and replies.
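To make both questions concrete, here is how I read those two lines put together. This is only an illustrative reconstruction with a hypothetical function name, not the actual `pairwise_klloss` from the repository:

```python
import jax
import jax.numpy as jnp

def pairwise_klloss_sketch(logits, weighted_t_softmax, temperature):
    """Illustrative reconstruction of the quoted lines, not the repository code.

    logits:             [n, m] student logits
    weighted_t_softmax: [n, m] neighbor-weighted teacher distribution (rows sum to 1)
    """
    # Quoted Line 44: log-softmax of the temperature-scaled student logits.
    s_softmax_temp = jax.nn.log_softmax(logits / temperature)            # [n, m]

    # Quoted loss line: per-class KL terms with the teacher distribution first.
    kldiv_loss_per_pair = weighted_t_softmax * (
        jnp.log(weighted_t_softmax) - s_softmax_temp)                    # [n, m]

    # Quoted reduction: sum over classes, rescaled by temperature^2.
    kldiv_loss_per_example = (
        jnp.power(temperature, 2) * jnp.sum(kldiv_loss_per_pair, 1))     # [n]

    return kldiv_loss_per_example
```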
I am also trying to reproduce the NCR results. Which JAX and Flax versions did you use? I have been stuck at this step for weeks. Thank you for sharing.
Hi,
- The log operation is already applied to the logits on Line 44: `s_softmax_temp = jax.nn.log_softmax(logits / temperature)  # [n, d]`
No need to do it again.
- This is a common practice for knowledge distillation. See Section 2 of *Distilling the Knowledge in a Neural Network*, where the authors state: "Since the magnitudes of the gradients produced by the soft targets scale as 1/T^2, it is important to multiply them by T^2 when using both hard and soft targets." (A quick numerical check of both points is sketched after this list.)
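As a minimal sketch with toy tensors (not the repository code), the following checks both points: that `jax.nn.log_softmax` already contains the log, and that without the T^2 factor the gradient magnitude shrinks roughly as 1/T^2 while the rescaled loss stays on a comparable scale:

```python
import jax
import jax.numpy as jnp

# Toy logits, shape [1, 3].
logits = jnp.array([[2.0, 0.5, -1.0]])
teacher_logits = jnp.array([[1.5, 1.0, -0.5]])

# (a) log_softmax already contains the log, so no extra jnp.log is needed.
assert jnp.allclose(jax.nn.log_softmax(logits),
                    jnp.log(jax.nn.softmax(logits)), atol=1e-6)

# (b) Gradient scale with and without the T^2 factor.
def kl_loss(logits, T, rescale):
    p = jax.nn.softmax(teacher_logits / T)          # soft targets
    log_q = jax.nn.log_softmax(logits / T)
    kl = jnp.sum(p * (jnp.log(p) - log_q))
    return (T ** 2) * kl if rescale else kl

for T in (1.0, 2.0, 4.0):
    g_raw = jax.grad(kl_loss)(logits, T, False)
    g_scaled = jax.grad(kl_loss)(logits, T, True)
    print(T, jnp.abs(g_raw).max(), jnp.abs(g_scaled).max())
# The unscaled gradients fall off quickly as T grows;
# the T^2-scaled ones remain on a comparable scale.
```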
Let me know if you have any additional questions.
Hi, Ahmet Iscen,
I'm interested in the design of NCR, but I am confused about Eq. (3). Could you clarify the purpose of the second summation term, the sum over NN_k(v_i)? Could you also provide a more detailed explanation of the NCR loss formula? Thank you very much for your help; I look forward to your reply.
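For context, here is my current reading of that inner summation, based on the weighted_t_softmax / knn-similarity lines quoted earlier in this thread. It is an illustrative sketch with hypothetical names, not the repository's implementation, and the exact similarity measure and normalization may differ from the paper:

```python
import jax
import jax.numpy as jnp

def neighbor_weighted_targets(logits, features, k, temperature):
    """Sketch of the inner sum over NN_k(v_i): for each example i, aggregate the
    temperature-scaled softmax predictions of its k nearest neighbors, weighted
    by normalized feature similarity. All names here are hypothetical.
    """
    n = features.shape[0]

    # Cosine similarities between all pairs of feature vectors.          [n, n]
    feats = features / jnp.linalg.norm(features, axis=1, keepdims=True)
    sims = feats @ feats.T
    sims = sims - 1e9 * jnp.eye(n)   # exclude each example from its own neighbor set

    # Keep only the k most similar examples for each row.                [n, n]
    _, topk_idx = jax.lax.top_k(sims, k)
    mask = jnp.zeros_like(sims).at[jnp.arange(n)[:, None], topk_idx].set(1.0)
    weights = jnp.where(mask > 0, sims, 0.0)
    # Normalize over NN_k(v_i); the paper/code may normalize differently.
    weights = weights / jnp.sum(weights, axis=1, keepdims=True)

    # Inner summation of Eq. (3) as I read it: a similarity-weighted average of
    # the neighbors' softmax outputs, one target distribution per example. [n, m]
    t_softmax = jax.nn.softmax(logits / temperature)
    return weights @ t_softmax
```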