LTH14 / targeted-supcon

A PyTorch implementation of the paper Targeted Supervised Contrastive Learning for Long-tailed Recognition

tsc_cifar_loss #15

Closed · Haroon-Wahab closed this issue 8 months ago

Haroon-Wahab commented 8 months ago

Hi,

In tsc_cifar_loss.py, there is a mistake in the code at the end, where you mask out the targets as anchors so that the outer summation in equation (3) can be computed:

```python
if target_mask is not None:
    # mask out the loss with target as anchor
    target_mask = target_mask.repeat(anchor_count)
    mean_log_prob_pos = mean_log_prob_pos * target_mask
```

I think that to mask out the targets as anchors here, `1 - target_mask` should be multiplied with `mean_log_prob_pos` instead.
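A minimal sketch of the suggested change, reusing the names from the snippet above (this is the proposal under discussion, not the repository's code):

```python
if target_mask is not None:
    # proposed fix: zero out the terms whose anchor is a target, so that
    # only instance anchors contribute to the outer sum in equation (3)
    target_mask = target_mask.repeat(anchor_count)
    mean_log_prob_pos = mean_log_prob_pos * (1 - target_mask)
```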

Thanks!

LTH14 commented 8 months ago

Thanks for your interest. The only_t argument is deprecated, so target_mask is always set to None in the main script and it is not used here.

Haroon-Wahab commented 8 months ago

> Thanks for your interest. The only_t argument is deprecated, so target_mask is always set to None in the main script and it is not used here.

Thanks @LTH14 for your response. If target_mask is always set to None in the main script, I can see that target_index is set to None as well. Given that both of them are None, where have you implemented the second log term in equation (3), between the anchor and its corresponding class target? The KCL part is there, where you randomly select k positive IDs from the list of all positive IDs against an anchor in the batch.

Also, I cannot find the implementation of the moving-average-based class-center computation and the subsequent target assignment via the Hungarian algorithm. It seems you have hardwired randomly assigned target classes for CIFAR-10 and CIFAR-100. I can see it is implemented for MoCo, but all my queries are regarding the CIFAR-10 and CIFAR-100 code.

LTH14 commented 8 months ago

For TSC, we set args.use_target=True, so it will go to these two lines: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/main_supcon_imba.py#L441-L442

LTH14 commented 8 months ago

For CIFAR10 and CIFAR100, there is no need to perform moving average-based class center computation, as mentioned in section 3.2. This is because the number of classes is smaller than the feature dimension, and thus the targets are all symmetric: the distances between any two targets are the same. Therefore, any assignment is equivalent for CIFAR10 and CIFAR100.
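As a sanity check of this symmetry argument, here is a minimal sketch (a hypothetical construction, not from the repository): when the number of classes C is at most the feature dimension, the C class targets can sit at the vertices of a regular simplex on the unit sphere, so all pairwise target distances are equal and any class-to-target assignment is equivalent.

```python
import torch

C, d = 10, 128  # e.g. CIFAR-10 classes in a 128-d feature space

# Center C one-hot vectors and re-normalize: the result is a regular
# simplex inscribed in the unit sphere (all vertices pairwise equidistant).
targets = torch.eye(C, d)
targets = targets - targets.mean(dim=0, keepdim=True)
targets = torch.nn.functional.normalize(targets, dim=1)

# All off-diagonal pairwise distances agree up to floating-point error.
dists = torch.cdist(targets, targets)
off_diag = dists[~torch.eye(C, dtype=torch.bool)]
print(off_diag.min().item(), off_diag.max().item())
```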

Haroon-Wahab commented 8 months ago

> For TSC, we set args.use_target=True, so it will go to these two lines: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/main_supcon_imba.py#L441-L442

I appreciate your response. Repeated target vectors are appended to each multi-viewed batch in these two lines of code, which are later concatenated along dimension 1 and passed to the TSC loss (a sketch follows below). In the TSC loss, you then treat the targets just like other instances of the positive class. Given that target_index and target_mask are not used, mask and mask_copy contain 1s for the positives against an anchor, including the target positives (repeated opt.target_repeat times).

Now in this part of the code, https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L134-L146, the KCL implementation randomly selects k positive instances to populate the numerator of the first summation in equation (3). The targets are not masked out in this selection, so it seems you are including the repeated targets in the k-positive random selection as well. However, in the paper and in equation (3), the k positive instances should be selected excluding the positive targets.

Additionally, the current anchor i should form an exponential dot product with just its class target vector (the numerator of the second log term in equation (3)), which I could not find in the loss computation. The anchor-to-k-positives (first) and anchor-to-class-target (second) log numerator terms differ substantially in equation (3): stage 1 applies only KCL to populate the embeddings, the second term is added later, and lambda is used as a hyperparameter on that term. It seems the implementation in the tsc_cifar code is not exactly as in the paper.

In light of this, can you please explain the observed difference? I might be missing something here, since the code is a dirty version, but your reflection on this as the developer would definitely help me out. Cheers
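For reference, a minimal sketch of the batch construction described above, with hypothetical shapes and names (not the repository's exact code):

```python
import torch

batch_size, dim, n_classes = 64, 128, 10
target_repeat = 25                           # opt.target_repeat in the repo

features = torch.randn(batch_size, dim)      # features of one view
class_targets = torch.randn(n_classes, dim)  # one target vector per class

# Repeat each class target and append the copies to the batch, so the TSC
# loss sees the targets as additional samples of their respective classes.
repeated_targets = class_targets.repeat(target_repeat, 1)  # (250, dim)
features = torch.cat([features, repeated_targets], dim=0)  # (314, dim)
```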

LTH14 commented 8 months ago

> The targets are not masked out in this selection, so it seems you are including the repeated targets in the k-positive random selection as well.

This is not true. We exclude the targets when performing the k-positive selection, as shown here: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L126-L131. Therefore, the targets are not included in the k-positive selection.
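A minimal sketch of the mechanism described here, with assumed shapes and names (not the repository's exact code): the target columns of the positive mask are fixed to 1 up front, and the random k-positive (KCL) draw is restricted to the remaining non-target positives.

```python
import torch

n, n_targets, k = 8, 4, 2
labels = torch.randint(0, 4, (n + n_targets,))  # instance + target labels
is_target = torch.zeros(n + n_targets, dtype=torch.bool)
is_target[-n_targets:] = True                   # targets appended last

mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # same-class mask
mask_pos_view = torch.zeros_like(mask)
mask_pos_view[:, is_target] = mask[:, is_target]  # targets are always kept

for i in range(n):  # instance anchors only
    # candidate positives: same class, not a target, not the anchor itself
    pos = torch.nonzero(mask[i] * (~is_target).float()).squeeze(1)
    pos = pos[pos != i]
    keep = pos[torch.randperm(len(pos))[:k]]      # random k positives
    mask_pos_view[i, keep] = 1.0
```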

Haroon-Wahab commented 8 months ago

> > The targets are not masked out in this selection, so it seems you are including the repeated targets in the k-positive random selection as well.
>
> This is not true. We exclude the targets when performing the k-positive selection, as shown here: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L126-L131. Therefore, the targets are not included in the k-positive selection.

So target_index in these two lines should not be commented out? https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/main_supcon_imba.py#L451-L452

LTH14 commented 8 months ago

Oh, that's true -- I might have commented it out to run some ablations and forgot to put it back when uploading the code.

LTH14 commented 8 months ago

In the CIFAR implementation, there is no lambda in eq. (3), so the targets are treated as additional positives (except that they are not included in the k-selection). The exponential and log of the dot product are here: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L85-L93.
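For context, a minimal sketch of that computation with assumed names (a standard SupCon-style log-probability, not the repository's exact code):

```python
import torch

temperature = 0.1
anchor_features = torch.randn(6, 128)     # hypothetical anchors
contrast_features = torch.randn(12, 128)  # hypothetical contrast set

# Temperature-scaled dot products between anchors and all contrast features.
logits = anchor_features @ contrast_features.T / temperature
# Subtract the per-row max for numerical stability (softmax is unchanged).
logits = logits - logits.max(dim=1, keepdim=True).values.detach()
# log p_ij = logit_ij - log(sum_k exp(logit_ik)): a log-softmax per anchor.
log_prob = logits - torch.log(torch.exp(logits).sum(dim=1, keepdim=True))
# mean_log_prob_pos would then average log_prob over the selected positives.
```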

Haroon-Wahab commented 8 months ago

> In the CIFAR implementation, there is no lambda in eq. (3), so the targets are treated as additional positives (except that they are not included in the k-selection). The exponential and log of the dot product are here: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L85-L93.

Thanks a lot for clarifying this. Yes, you compute the exponential and log of the logits in the mentioned lines, but later, when we compute mean_log_prob_pos here, https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L146, we do this only for the k-positives using mask_pos_view. So I am now wondering where you have included the targets as additional positives.

Secondly, would you like to share the reason for setting opt.target_repeat to 25? This could be treated as a hyperparameter. Did you try other values as well? Surprisingly, I cannot find anything about this in the paper.

LTH14 commented 8 months ago

opt.target_repeat kind of serves a similar role to the lambda -- the more repeats, the larger the weight on the targets. In the paper we use lambda for simplicity in the loss formula.
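A rough reading of this equivalence, in hypothetical notation (not the paper's): if an anchor $i$'s positive set holds $k$ sampled positives plus $r$ = opt.target_repeat copies of its class target $t_{y_i}$, the mean log-probability becomes

$$
\frac{1}{k+r}\Bigg(\sum_{j \in \mathrm{pos}(i)} \log p_{ij} + r \, \log p_{i,\,t_{y_i}}\Bigg),
$$

so increasing $r$ up-weights the target term roughly the way $\lambda$ does in equation (3).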

LTH14 commented 8 months ago

mask_pos_view already contains all targets as 1: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L130. KCL just sets k additional elements to 1: https://github.com/LTH14/targeted-supcon/blob/master/cifar_dirty/tsc_cifar_loss.py#L143.

Haroon-Wahab commented 8 months ago

Now it makes sense. Thanks for your timely responses. The uncleaned version probably made it difficult to relate the code to the paper, but your responses have cleared up the misconceptions it created.

I really appreciate your work and hope this thread will help others comprehend the paper too. Cheers