guojiajeremy / ReContrast


About Forward Code #2

Open liwei33660 opened 11 months ago

liwei33660 commented 11 months ago

Thank you for your excellent work. I am having some trouble reading your code and would appreciate your answers!

```python
def forward(self, x):
    en = self.encoder(x)
    with torch.no_grad():
        en_freeze = self.encoder_freeze(x)
    en_2 = [torch.cat([a, b], dim=0) for a, b in zip(en, en_freeze)]
    de = self.decoder(self.bottleneck(en_2))
    de = [a.chunk(dim=0, chunks=2) for a in de]
    de = [de[0][0], de[1][0], de[2][0], de[3][1], de[4][1], de[5][1]]
    return en_freeze + en, de
```

This code is located in models/recontrast.py at line 62 and shows the forward processing of ReContrast. I'd like to get a better idea of how the process works.

1) Why is it necessary to concatenate `en` and `en_freeze` and feed them to the decoder? I would think that `en_freeze` should not need to go through the decoder.

2) Also, what does `de = [de[0][0], de[1][0], de[2][0], de[3][1], de[4][1], de[5][1]]` stand for?

3) In the loss calculation, what do `de[:3]` and `de[3:]` stand for? (code: `loss = global_cosine_hm(en[:3], de[:3], alpha=alpha, factor=0.) / 2 + global_cosine_hm(en[3:], de[3:], alpha=alpha, factor=0.) / 2`)

Thanks again for your excellent work, it really inspires me a lot.

guojiajeremy commented 10 months ago

Sorry for responding late!

  1. As in the paper, the decoder and bottleneck take the features of the domain-specific encoder to reconstruct the frozen encoder; vice versa, they simultaneously take the features of the frozen encoder to reconstruct the domain-specific encoder, building a cross-reconstruction paradigm. So both go through the decoder.
  2. It is a little tricky. In the decoder, one group of 1x1 convs is used to reconstruct the domain-specific encoder, and another group of 1x1 convs is used to reconstruct the frozen encoder. After chunking, `de[0][0]`, `de[1][0]`, `de[2][0]` reconstruct the frozen encoder, and `de[3][1]`, `de[4][1]`, `de[5][1]` reconstruct the domain-specific encoder.
  3. Following the second answer, `de[:3]` and `de[3:]` reconstruct the frozen and domain-specific encoders, respectively (see the sketch after this list).
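If it helps, here is a minimal, self-contained sketch of the concat/chunk indexing described in points 1 and 2. The tensor shapes and the stand-in decoder are assumptions for illustration only; only the indexing pattern mirrors the forward code.

```python
# Sketch of the batch-dim concat / chunk indexing (shapes are illustrative assumptions).
import torch

B, C, H, W = 4, 64, 16, 16
en        = [torch.randn(B, C, H, W) for _ in range(3)]  # domain-specific encoder features
en_freeze = [torch.randn(B, C, H, W) for _ in range(3)]  # frozen encoder features

# Concatenate along the batch dimension: first half from en, second half from en_freeze.
en_2 = [torch.cat([a, b], dim=0) for a, b in zip(en, en_freeze)]

# Stand-in for self.decoder(self.bottleneck(en_2)): it returns 6 feature maps, where
# outputs 0-2 come from the 1x1-conv group targeting the frozen encoder and
# outputs 3-5 come from the group targeting the domain-specific encoder.
de = [en_2[i % 3].clone() for i in range(6)]

# Split every output back into its (from-en, from-en_freeze) halves.
de = [a.chunk(dim=0, chunks=2) for a in de]

# Keep the cross pairs:
#   de[0..2][0]: computed from en (domain-specific input), trained to match en_freeze
#   de[3..5][1]: computed from en_freeze (frozen input), trained to match en
de = [de[0][0], de[1][0], de[2][0], de[3][1], de[4][1], de[5][1]]
print([tuple(t.shape) for t in de])  # each is (B, C, H, W)
```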

Sorry for the ambiguity. The paper is still under review. We will try to make the public code more readable after acceptance.

liwei33660 commented 10 months ago

Thank you for your reply. I would like to double-check your explanation to make sure I understand it correctly: ReContrast acts as a kind of cross-distillation, computing two losses separately: 1. between the features of the frozen branch encoder and the features of the domain-specific branch decoder, and 2. between the features of the domain-specific branch encoder and the features of the frozen branch decoder, where the domain-specific branch decoder and the frozen branch decoder are the same decoder.

Is that right?

guojiajeremy commented 10 months ago

> Thank you for your reply. I would like to double-check your explanation to make sure I understand it correctly: ReContrast acts as a kind of cross-distillation, computing two losses separately: 1. between the features of the frozen branch encoder and the features of the domain-specific branch decoder, and 2. between the features of the domain-specific branch encoder and the features of the frozen branch decoder, where the domain-specific branch decoder and the frozen branch decoder are the same decoder.
>
> Is that right?

Exactly! Except for the trivial 1x1 convs after the decoder.

liwei33660 commented 10 months ago

Thank you. But I am also having trouble with the loss: `loss = global_cosine_hm(en[:3], de[:3], alpha=alpha, factor=0.) / 2 + global_cosine_hm(en[3:], de[3:], alpha=alpha, factor=0.) / 2`

You say `de[:3]` and `de[3:]` reconstruct the frozen and domain-specific encoders, respectively. As `forward(self, x)` returns, `en[:3]` is `en_freeze` and `en[3:]` is `en`.

So, in the code, you calculate the loss between: 1. the features of the frozen branch encoder (`en[:3]`) and the features of the frozen branch decoder (`de[:3]`), and 2. the features of the domain-specific branch encoder (`en[3:]`) and the features of the domain-specific branch decoder (`de[3:]`).

Are there some small mistakes in the loss code?

guojiajeremy commented 10 months ago

> Thank you. But I am also having trouble with the loss: `loss = global_cosine_hm(en[:3], de[:3], alpha=alpha, factor=0.) / 2 + global_cosine_hm(en[3:], de[3:], alpha=alpha, factor=0.) / 2`
>
> You say `de[:3]` and `de[3:]` reconstruct the frozen and domain-specific encoders, respectively. As `forward(self, x)` returns, `en[:3]` is `en_freeze` and `en[3:]` is `en`.
>
> So, in the code, you calculate the loss between: 1. the features of the frozen branch encoder (`en[:3]`) and the features of the frozen branch decoder (`de[:3]`), and 2. the features of the domain-specific branch encoder (`en[3:]`) and the features of the domain-specific branch decoder (`de[3:]`).
>
> Are there some small mistakes in the loss code?

Actually, `de[:3]` are the decoder features computed from the input of the domain-specific encoder, and `de[3:]` are the decoder features computed from the input of the frozen encoder. So it is cross-reconstruction: `en[:3]` (frozen) is compared against features decoded from the domain-specific branch, and `en[3:]` (domain-specific) against features decoded from the frozen branch (see the sketch below).
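To make the pairing concrete, here is a hedged sketch of the loss computation. `global_cosine` is a simplified stand-in for `global_cosine_hm` (no hard mining, no `alpha`/`factor`), and the random tensors are only there to make it runnable; the slicing is the point.

```python
# Sketch of the cross-reconstruction pairing in the loss (simplified, illustrative).
import torch
import torch.nn.functional as F

def global_cosine(a_list, b_list):
    # Mean of (1 - cosine similarity) between flattened feature maps at each scale.
    loss = 0.0
    for a, b in zip(a_list, b_list):
        loss = loss + (1 - F.cosine_similarity(a.flatten(1), b.flatten(1), dim=1)).mean()
    return loss / len(a_list)

B, C, H, W = 4, 64, 16, 16
# forward() returns en_freeze + en, so en_all[:3] are frozen features, en_all[3:] are domain-specific.
en_all = [torch.randn(B, C, H, W) for _ in range(6)]
# de[:3] are decoded from the domain-specific input, de[3:] from the frozen input.
de = [torch.randn(B, C, H, W) for _ in range(6)]

# Each term compares features coming from *different* branches -> cross-reconstruction.
loss = global_cosine(en_all[:3], de[:3]) / 2 + global_cosine(en_all[3:], de[3:]) / 2
print(loss)
```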

liwei33660 commented 10 months ago

Thank you very much for your patience. I misread the word “reconstruct” and mistook it for “stands for”. Good luck with the acceptance of your work!

guojiajeremy commented 10 months ago

> Thank you very much for your patience. I misread the word “reconstruct” and mistook it for “stands for”. Good luck with the acceptance of your work!

My pleasure.

FYI, this version suffers from training instability (e.g. loss spikes) with different random seeds, which can be attributed to the denominator epsilon of BatchNorm and the Adam optimizer. We have made some adaptations to handle this instability, which also unify the BN mode of each category.
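For context only, here is a minimal sketch of where those two denominator epsilons live; the enlarged values below are hypothetical mitigations, not the authors' released fix.

```python
# Illustration only: the two "denominator epsilons" mentioned above.
# The enlarged eps values are hypothetical; they are NOT the authors' released fix.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64, eps=1e-3),  # BN divides by sqrt(var + eps); default eps is 1e-5
    nn.ReLU(),
)
# Adam divides by sqrt(v_hat) + eps; default eps is 1e-8.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-6)
```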

We will update the code ASAP after acceptance (or maybe after rejection).