Open liwei33660 opened 11 months ago
Sorry for the late reply!
Sorry for the ambiguity. The paper is still under review. We will try to make the public code more readable after acceptance.
Thank you for your reply. I would like to double-check your explanation to make sure I understand it correctly: ReContrast acts as a kind of cross-distillation, computing two losses separately: 1. between the features of the frozen branch encoder and the features of the domain-specific branch decoder, and 2. between the features of the domain-specific branch encoder and the features of the frozen branch decoder, where the domain-specific branch decoder and the frozen branch decoder are the same decoder.
Is that right?
Exactly! Except for the trivial 1x1 convs after the decoder.
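As a toy sketch of that crossed pairing (hypothetical names and values; plain Python lists stand in for feature tensors, and a simple cosine distance stands in for `global_cosine_hm`):

```python
import math

def cosine_dist(u, v):
    """1 - cosine similarity between two feature vectors (plain lists)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)

# Hypothetical per-stage features, one vector per branch:
en_freeze = [1.0, 0.0]          # frozen encoder
en = [0.8, 0.6]                 # domain-specific encoder
de_from_en = [0.9, 0.1]         # decoder output when fed the domain-specific branch
de_from_en_freeze = [0.7, 0.7]  # decoder output when fed the frozen branch

# Cross-distillation: each encoder is paired with the decoder output
# driven by the *other* branch's input.
loss = (cosine_dist(en_freeze, de_from_en) / 2
        + cosine_dist(en, de_from_en_freeze) / 2)
```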
Thank you. But I also have trouble with the loss:

```python
loss = global_cosine_hm(en[:3], de[:3], alpha=alpha, factor=0.) / 2 + global_cosine_hm(en[3:], de[3:], alpha=alpha, factor=0.) / 2
```

You say `de[:3]` and `de[3:]` reconstruct the frozen and domain-specific encoder, respectively. As `forward(self, x)` returns, `en[:3]` is `en_freeze` and `en[3:]` is `en`.
So, in the code, you calculate the loss between: 1. the features of the frozen branch encoder (`en[:3]`) and the features of the frozen branch decoder (`de[:3]`), and 2. the features of the domain-specific branch encoder (`en[3:]`) and the features of the domain-specific branch decoder (`de[3:]`).
Are there some small mistakes in the loss code?
Actually, `de[:3]` is the decoder feature with the input of the domain-specific encoder, and `de[3:]` is the decoder feature with the input of the frozen encoder. So it is cross-reconstruction.
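To make that pairing concrete, here is a toy sketch (plain strings stand in for feature tensors; the names are illustrative, not from the repo):

```python
# Each decoder stage returns a pair after chunk():
# index 0 = half fed by the domain-specific encoder, index 1 = frozen.
de = [("de_from_en", "de_from_en_freeze") for _ in range(6)]

# Mirror the selection in forward():
de = [de[0][0], de[1][0], de[2][0], de[3][1], de[4][1], de[5][1]]

# forward() returns en_freeze + en, so:
en = ["en_freeze"] * 3 + ["en"] * 3

# The loss pairs en[:3] with de[:3] and en[3:] with de[3:],
# which is exactly the crossed combination:
pairs = list(zip(en, de))
```

Printing `pairs` shows `en_freeze` matched with decoder outputs driven by the domain-specific encoder, and vice versa.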
Thank you very much for your patience. I misread the word "reconstruct" and mistook it for "stand for". Good luck with the acceptance of your work!
My pleasure.
FYI, this version suffers from training instability (e.g., loss spikes) with different random seeds, which can be attributed to the denominator epsilon of BatchNorm and the Adam optimizer. We have made some adaptations to handle this instability, which also unified the BN mode for each category.
We will update the code ASAP after acceptance (or maybe after rejection).
Thank you for your excellent work. I am having some trouble reading your code and would appreciate your answers!
```python
def forward(self, x):
    en = self.encoder(x)
    with torch.no_grad():
        en_freeze = self.encoder_freeze(x)
    en_2 = [torch.cat([a, b], dim=0) for a, b in zip(en, en_freeze)]
    de = self.decoder(self.bottleneck(en_2))
    de = [a.chunk(dim=0, chunks=2) for a in de]
    de = [de[0][0], de[1][0], de[2][0], de[3][1], de[4][1], de[5][1]]
    return en_freeze + en, de
```
This code is located in models/recontrast.py at line 62 and shows the forward processing of ReContrast. I'd like to get a better idea of how the process works. 1) Why is it necessary to cat `en` and `en_freeze` and feed them to the decoder? I think that `en_freeze` should not need to go through the decoder. 2) Also, what does `de = [de[0][0], de[1][0], de[2][0], de[3][1], de[4][1], de[5][1]]` stand for? 3) In calculating the loss, what do `de[:3]` and `de[3:]` stand for? (code:

```python
loss = global_cosine_hm(en[:3], de[:3], alpha=alpha, factor=0.) / 2 + global_cosine_hm(en[3:], de[3:], alpha=alpha, factor=0.) / 2
```

) Thanks again for your excellent work, it really inspires me a lot.
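For what it's worth, the cat/chunk pattern in `forward` can be mimicked with plain lists (a sketch, not the actual tensor ops): concatenating along the batch dimension lets a single decoder pass process both branches' features at once (so layers like BatchNorm see both), and chunking splits the result back per branch.

```python
# Plain-list stand-in for torch.cat([a, b], dim=0) and .chunk(dim=0, chunks=2);
# the element names are purely illustrative.
a = ["en_feat_1", "en_feat_2"]                # batch from the domain-specific encoder
b = ["en_freeze_feat_1", "en_freeze_feat_2"]  # batch from the frozen encoder

cat = a + b                        # one joint "batch" for the decoder
half = len(cat) // 2
chunks = (cat[:half], cat[half:])  # split back into the two branches
```

The round trip recovers each branch unchanged, which is why the selection `de[i][0]` / `de[i][1]` can pick out which encoder's input produced each decoder output.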