thuml / Universal-Domain-Adaptation

Code release for Universal Domain Adaptation (CVPR 2019)

Question about your code #5

Closed: zhangzp9970 closed this issue 4 years ago.

zhangzp9970 commented 4 years ago

Hello, after reading the paper and the code, I have a few questions. I hope you can answer them for me :) The UAN model proposed in the paper looks like this:

[Image: UAN architecture diagram from the paper]

In my view, the feature z extracted by F is fed into G, D, and D' separately. However, the model in the code is defined as follows:

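def forward(self, x):
    # Backbone features
    f = self.feature_extractor(x)
    # `_` is the output of the classifier's first (bottleneck) layer; `y` is the classifier's prediction
    f, _, __, y = self.classifier(f)
    # Both discriminators take the bottleneck output `_`, not the backbone feature `f`
    d = self.discriminator(_)
    d_0 = self.discriminator_separate(_)
    return y, d, d_0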

The feature f goes into the classifier first, and it is the output of the classifier's first layer, not f itself, that goes into the discriminators. Why?
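A minimal sketch of the data flow in the `forward` above, assuming (this is not confirmed in the thread) that the classifier is a bottleneck layer followed by a linear layer and softmax, and that it returns its input plus each intermediate output. The class names `BottleneckClassifier` and `ToyUAN`, the `nn.Identity` backbone stand-in, and all dimensions are hypothetical. Under this reading, the paper's z would correspond to the bottleneck output that both discriminators see, rather than to the raw backbone feature f.

```python
import torch
import torch.nn as nn


class BottleneckClassifier(nn.Module):
    """Hypothetical classifier: in_dim -> bottleneck -> logits -> probabilities.

    forward() returns the input plus every intermediate output, matching the
    four values unpacked as `f, _, __, y` in the snippet above (assumed layout).
    """

    def __init__(self, in_dim=2048, bottleneck_dim=256, num_classes=10):
        super().__init__()
        self.bottleneck = nn.Linear(in_dim, bottleneck_dim)
        self.fc = nn.Linear(bottleneck_dim, num_classes)

    def forward(self, x):
        z = self.bottleneck(x)                 # first layer: bottleneck features
        logits = self.fc(z)                    # class scores
        probs = torch.softmax(logits, dim=-1)  # class probabilities
        return x, z, logits, probs


class ToyUAN(nn.Module):
    def __init__(self, in_dim=2048, bottleneck_dim=256, num_classes=10):
        super().__init__()
        # Stand-in for a pretrained backbone (hypothetical)
        self.feature_extractor = nn.Identity()
        self.classifier = BottleneckClassifier(in_dim, bottleneck_dim, num_classes)
        # Both discriminators read the bottleneck features, as in the forward() above
        self.discriminator = nn.Sequential(nn.Linear(bottleneck_dim, 1), nn.Sigmoid())
        self.discriminator_separate = nn.Sequential(nn.Linear(bottleneck_dim, 1), nn.Sigmoid())

    def forward(self, x):
        f = self.feature_extractor(x)
        f, z, logits, y = self.classifier(f)
        return y, self.discriminator(z), self.discriminator_separate(z)


model = ToyUAN()
y, d, d_0 = model(torch.randn(4, 2048))
print(y.shape, d.shape, d_0.shape)  # torch.Size([4, 10]) torch.Size([4, 1]) torch.Size([4, 1])
```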

Besides, the paper formulates training as the minimax game in Equation (4). However, the code simply sums all the losses: loss = ce + adv_loss + adv_loss_separate. How should I understand this?
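A short derivation of how a single summed loss can still implement a minimax game, assuming a gradient reversal layer (GRL) with coefficient λ sits between the shared features and D, as in DANN. The notation (θ_F, θ_G, θ_D for parameters, μ for the learning rate, L_y for ce, L_d for adv_loss) is introduced here; this illustrates the general GRL mechanism, not the exact form of the paper's Equation (4), and the D' term (adv_loss_separate) is left out for brevity. Backpropagating the sum through the GRL yields these SGD updates:

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_y + \mathcal{L}_d \\
\theta_D &\leftarrow \theta_D - \mu \,\frac{\partial \mathcal{L}_d}{\partial \theta_D} \\
\theta_G &\leftarrow \theta_G - \mu \,\frac{\partial \mathcal{L}_y}{\partial \theta_G} \\
\theta_F &\leftarrow \theta_F - \mu \left( \frac{\partial \mathcal{L}_y}{\partial \theta_F} - \lambda \,\frac{\partial \mathcal{L}_d}{\partial \theta_F} \right)
\end{aligned}
```

These are gradient steps toward a saddle point where D minimizes L_d while F and G minimize L_y − λ L_d, so D and F play opposite sides of the adversarial loss even though only one summed loss is backpropagated.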

Thanks

nmakes commented 4 years ago

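@zhangzp9970: There is a Gradient Reversal Layer in the discriminator (line 157, here), so the summed loss still sets up an adversarial game between F and D: D is trained to minimize the adversarial loss, while the reversed gradient pushes F to maximize it.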
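A minimal PyTorch sketch of a gradient reversal layer of the kind described; the names `GradReverse`, `grad_reverse`, and `lambd` are hypothetical and not taken from the repo. The check backpropagates the same loss with and without the layer: the discriminator's gradient is identical in both cases, while the gradient reaching the feature is negated, which is why adding the losses together still yields the adversarial game.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambd in backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip (and scale) the gradient flowing back to the feature extractor;
        # the non-tensor lambd argument gets no gradient.
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


torch.manual_seed(0)
feature = torch.randn(4, 8, requires_grad=True)  # stand-in for F's output
disc = nn.Linear(8, 1)                            # stand-in for D
labels = torch.ones(4, 1)

# Pass 1: loss computed through the GRL
loss_rev = F.binary_cross_entropy_with_logits(disc(grad_reverse(feature)), labels)
loss_rev.backward()
feat_grad_rev = feature.grad.clone()
disc_grad_rev = disc.weight.grad.clone()

# Pass 2: same loss without the GRL
feature.grad = None
disc.zero_grad()
loss_plain = F.binary_cross_entropy_with_logits(disc(feature), labels)
loss_plain.backward()

print(torch.allclose(feat_grad_rev, -feature.grad))     # True: F's gradient is flipped
print(torch.allclose(disc_grad_rev, disc.weight.grad))  # True: D's gradient is unchanged
```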

I'm not sure about the inputs to the discriminators, though.

zhangzp9970 commented 4 years ago

@nmakes Thanks for the answer! I found that in the paper Domain-Adversarial Training of Neural Networks, the authors mention that the adaptation architecture is identical to Tzeng et al. (2014), which has "a 2-layer domain classifier (x->1024->1024->2) is attached to the 256-dimensional bottleneck of fc7." In Tzeng et al., Deep Domain Confusion: Maximizing for Domain Invariance, the authors tried different layers of AlexNet (fc6, fc7, fc8) and different bottleneck dimensions (from 64 to 4096), and found that the fc7 layer with a 256-dimensional bottleneck was a reasonable choice, since it gave the lowest MMD value. So I think this is why the 256-dimensional features are fed into the adversarial network.
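For reference, a sketch of the quoted domain classifier (x->1024->1024->2) attached to 256-dimensional bottleneck features. The layer sizes follow the quote; the ReLU activations, the omission of dropout, and the function name `make_domain_classifier` are assumptions, and this is not the discriminator defined in the UAN repo.

```python
import torch
import torch.nn as nn


def make_domain_classifier(in_dim=256, hidden=1024, num_domains=2):
    # x -> 1024 -> 1024 -> 2 as in the quote; ReLU activations are an assumption
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, num_domains),
    )


bottleneck_features = torch.randn(32, 256)  # stand-in for 256-d bottleneck outputs
domain_classifier = make_domain_classifier()
domain_logits = domain_classifier(bottleneck_features)
print(domain_logits.shape)  # torch.Size([32, 2])
```

Attaching the classifier to the low-dimensional bottleneck rather than to the wider backbone output keeps its first layer small: with these sizes it has 256 × 1024 weights instead of, for example, 4096 × 1024 for a 4096-dimensional fc7 output.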